diff --git a/.nojekyll b/.nojekyll
new file mode 100644
index 00000000..e69de29b
diff --git a/404.html b/404.html
new file mode 100644
index 00000000..72959673
--- /dev/null
+++ b/404.html
@@ -0,0 +1,1598 @@
[... 1,598 lines of generated HTML: the documentation theme's standard "404 - Not found" error page for the Odak site ...]
\ No newline at end of file
diff --git a/assets/_mkdocstrings.css b/assets/_mkdocstrings.css
new file mode 100644
index 00000000..b500381b
--- /dev/null
+++ b/assets/_mkdocstrings.css
@@ -0,0 +1,143 @@
+
+/* Avoid breaking parameter names, etc. in table cells. */
+.doc-contents td code {
+  word-break: normal !important;
+}
+
+/* No line break before first paragraph of descriptions. */
+.doc-md-description,
+.doc-md-description>p:first-child {
+  display: inline;
+}
+
+/* Max width for docstring sections tables. */
+.doc .md-typeset__table,
+.doc .md-typeset__table table {
+  display: table !important;
+  width: 100%;
+}
+
+.doc .md-typeset__table tr {
+  display: table-row;
+}
+
+/* Defaults in Spacy table style. */
+.doc-param-default {
+  float: right;
+}
+
+/* Parameter headings must be inline, not blocks. */
+.doc-heading-parameter {
+  display: inline;
+}
+
+/* Prefer space on the right, not the left of parameter permalinks. */
+.doc-heading-parameter .headerlink {
+  margin-left: 0 !important;
+  margin-right: 0.2rem;
+}
+
+/* Backward-compatibility: docstring section titles in bold. */
+.doc-section-title {
+  font-weight: bold;
+}
+
+/* Symbols in Navigation and ToC. */
+:root, :host,
+[data-md-color-scheme="default"] {
+  --doc-symbol-parameter-fg-color: #df50af;
+  --doc-symbol-attribute-fg-color: #953800;
+  --doc-symbol-function-fg-color: #8250df;
+  --doc-symbol-method-fg-color: #8250df;
+  --doc-symbol-class-fg-color: #0550ae;
+  --doc-symbol-module-fg-color: #5cad0f;
+
+  --doc-symbol-parameter-bg-color: #df50af1a;
+  --doc-symbol-attribute-bg-color: #9538001a;
+  --doc-symbol-function-bg-color: #8250df1a;
+  --doc-symbol-method-bg-color: #8250df1a;
+  --doc-symbol-class-bg-color: #0550ae1a;
+  --doc-symbol-module-bg-color: #5cad0f1a;
+}
+
+[data-md-color-scheme="slate"] {
+  --doc-symbol-parameter-fg-color: #ffa8cc;
+  --doc-symbol-attribute-fg-color: #ffa657;
+  --doc-symbol-function-fg-color: #d2a8ff;
+  --doc-symbol-method-fg-color: #d2a8ff;
+  --doc-symbol-class-fg-color: #79c0ff;
+  --doc-symbol-module-fg-color: #baff79;
+
+  --doc-symbol-parameter-bg-color: #ffa8cc1a;
+  --doc-symbol-attribute-bg-color: #ffa6571a;
+  --doc-symbol-function-bg-color: #d2a8ff1a;
+  --doc-symbol-method-bg-color: #d2a8ff1a;
+  --doc-symbol-class-bg-color: #79c0ff1a;
+  --doc-symbol-module-bg-color: #baff791a;
+}
+
+code.doc-symbol {
+  border-radius: .1rem;
+  font-size: .85em;
+  padding: 0 .3em;
+  font-weight: bold;
+}
+
+code.doc-symbol-parameter {
+  color: var(--doc-symbol-parameter-fg-color);
+  background-color: var(--doc-symbol-parameter-bg-color);
+}
+
+code.doc-symbol-parameter::after {
+  content: "param";
+}
+
+code.doc-symbol-attribute {
+  color: var(--doc-symbol-attribute-fg-color);
+  background-color: var(--doc-symbol-attribute-bg-color);
+}
+
+code.doc-symbol-attribute::after {
+  content: "attr";
+}
+
+code.doc-symbol-function {
+  color: var(--doc-symbol-function-fg-color);
+  background-color: var(--doc-symbol-function-bg-color);
+}
+
+code.doc-symbol-function::after {
+  content: "func";
+}
+
+code.doc-symbol-method {
+  color: var(--doc-symbol-method-fg-color);
+  background-color: var(--doc-symbol-method-bg-color);
+}
+
+code.doc-symbol-method::after {
+  content: "meth";
+}
+
+code.doc-symbol-class {
+  color: var(--doc-symbol-class-fg-color);
+  background-color: var(--doc-symbol-class-bg-color);
+}
+
+code.doc-symbol-class::after {
+  content: "class";
+}
+
+code.doc-symbol-module {
+  color: var(--doc-symbol-module-fg-color);
+  background-color: var(--doc-symbol-module-bg-color);
+}
+
+code.doc-symbol-module::after {
+  content: "mod";
+}
+
+.doc-signature .autorefs {
+  color: inherit;
+  border-bottom: 1px dotted currentcolor;
+}
diff --git a/assets/images/favicon.png b/assets/images/favicon.png
new file mode 100644
index 00000000..1cf13b9f
Binary files /dev/null and b/assets/images/favicon.png differ
diff --git a/assets/javascripts/bundle.83f73b43.min.js b/assets/javascripts/bundle.83f73b43.min.js
new file mode 100644
index 00000000..43d8b70f
--- /dev/null
+++ b/assets/javascripts/bundle.83f73b43.min.js
@@ -0,0 +1,16 @@
[... 16 lines of minified JavaScript: the theme's pre-built browser bundle, inlining the focus-visible polyfill, escape-html (MIT), clipboard.js v2.0.11 (MIT, © Zeno Rocha), and the site's minified search and navigation runtime ...]
f=="undefined"?S:h(f,"toggle").pipe(W(o),m(()=>f))})).subscribe(l=>{l.open===!1&&l.offsetTop<=i.scrollTop&&i.scrollTo({top:l.offsetTop})}),t.pipe(b(dr),m(({data:l})=>l)).pipe(w(l=>o.next(l)),_(()=>o.complete()),m(l=>$({ref:e},l)))}function ms(e,{query$:t}){return t.pipe(m(({value:r})=>{let o=ye();return o.hash="",r=r.replace(/\s+/g,"+").replace(/&/g,"%26").replace(/=/g,"%3D"),o.search=`q=${r}`,{url:o}}))}function mi(e,t){let r=new g,o=r.pipe(Z(),ie(!0));return r.subscribe(({url:n})=>{e.setAttribute("data-clipboard-text",e.href),e.href=`${n}`}),h(e,"click").pipe(W(o)).subscribe(n=>n.preventDefault()),ms(e,t).pipe(w(n=>r.next(n)),_(()=>r.complete()),m(n=>$({ref:e},n)))}function fi(e,{worker$:t,keyboard$:r}){let o=new g,n=Se("search-query"),i=O(h(n,"keydown"),h(n,"focus")).pipe(ve(se),m(()=>n.value),K());return o.pipe(He(i),m(([{suggest:s},p])=>{let c=p.split(/([\s-]+)/);if(s!=null&&s.length&&c[c.length-1]){let l=s[s.length-1];l.startsWith(c[c.length-1])&&(c[c.length-1]=l)}else c.length=0;return c})).subscribe(s=>e.innerHTML=s.join("").replace(/\s/g," ")),r.pipe(b(({mode:s})=>s==="search")).subscribe(s=>{switch(s.type){case"ArrowRight":e.innerText.length&&n.selectionStart===n.value.length&&(n.value=e.innerText);break}}),t.pipe(b(dr),m(({data:s})=>s)).pipe(w(s=>o.next(s)),_(()=>o.complete()),m(()=>({ref:e})))}function ui(e,{index$:t,keyboard$:r}){let o=xe();try{let n=ai(o.search,t),i=Se("search-query",e),a=Se("search-result",e);h(e,"click").pipe(b(({target:p})=>p instanceof Element&&!!p.closest("a"))).subscribe(()=>Je("search",!1)),r.pipe(b(({mode:p})=>p==="search")).subscribe(p=>{let c=Ie();switch(p.type){case"Enter":if(c===i){let l=new Map;for(let f of P(":first-child [href]",a)){let u=f.firstElementChild;l.set(f,parseFloat(u.getAttribute("data-md-score")))}if(l.size){let[[f]]=[...l].sort(([,u],[,d])=>d-u);f.click()}p.claim()}break;case"Escape":case"Tab":Je("search",!1),i.blur();break;case"ArrowUp":case"ArrowDown":if(typeof c=="undefined")i.focus();else{let l=[i,...P(":not(details) > [href], summary, details[open] [href]",a)],f=Math.max(0,(Math.max(0,l.indexOf(c))+l.length+(p.type==="ArrowUp"?-1:1))%l.length);l[f].focus()}p.claim();break;default:i!==Ie()&&i.focus()}}),r.pipe(b(({mode:p})=>p==="global")).subscribe(p=>{switch(p.type){case"f":case"s":case"/":i.focus(),i.select(),p.claim();break}});let s=pi(i,{worker$:n});return O(s,li(a,{worker$:n,query$:s})).pipe(Re(...ae("search-share",e).map(p=>mi(p,{query$:s})),...ae("search-suggest",e).map(p=>fi(p,{worker$:n,keyboard$:r}))))}catch(n){return e.hidden=!0,Ye}}function di(e,{index$:t,location$:r}){return z([t,r.pipe(Q(ye()),b(o=>!!o.searchParams.get("h")))]).pipe(m(([o,n])=>ii(o.config)(n.searchParams.get("h"))),m(o=>{var a;let n=new Map,i=document.createNodeIterator(e,NodeFilter.SHOW_TEXT);for(let s=i.nextNode();s;s=i.nextNode())if((a=s.parentElement)!=null&&a.offsetHeight){let p=s.textContent,c=o(p);c.length>p.length&&n.set(s,c)}for(let[s,p]of n){let{childNodes:c}=x("span",null,p);s.replaceWith(...Array.from(c))}return{ref:e,nodes:n}}))}function fs(e,{viewport$:t,main$:r}){let o=e.closest(".md-grid"),n=o.offsetTop-o.parentElement.offsetTop;return z([r,t]).pipe(m(([{offset:i,height:a},{offset:{y:s}}])=>(a=a+Math.min(n,Math.max(0,s-i))-n,{height:a,locked:s>=i+n})),K((i,a)=>i.height===a.height&&i.locked===a.locked))}function Zr(e,o){var n=o,{header$:t}=n,r=so(n,["header$"]);let i=R(".md-sidebar__scrollwrap",e),{y:a}=Ve(i);return C(()=>{let s=new g,p=s.pipe(Z(),ie(!0)),c=s.pipe(Me(0,me));return 
c.pipe(re(t)).subscribe({next([{height:l},{height:f}]){i.style.height=`${l-2*a}px`,e.style.top=`${f}px`},complete(){i.style.height="",e.style.top=""}}),c.pipe(Ae()).subscribe(()=>{for(let l of P(".md-nav__link--active[href]",e)){if(!l.clientHeight)continue;let f=l.closest(".md-sidebar__scrollwrap");if(typeof f!="undefined"){let u=l.offsetTop-f.offsetTop,{height:d}=ce(f);f.scrollTo({top:u-d/2})}}}),ue(P("label[tabindex]",e)).pipe(ne(l=>h(l,"click").pipe(ve(se),m(()=>l),W(p)))).subscribe(l=>{let f=R(`[id="${l.htmlFor}"]`);R(`[aria-labelledby="${l.id}"]`).setAttribute("aria-expanded",`${f.checked}`)}),fs(e,r).pipe(w(l=>s.next(l)),_(()=>s.complete()),m(l=>$({ref:e},l)))})}function hi(e,t){if(typeof t!="undefined"){let r=`https://api.github.com/repos/${e}/${t}`;return st(je(`${r}/releases/latest`).pipe(de(()=>S),m(o=>({version:o.tag_name})),De({})),je(r).pipe(de(()=>S),m(o=>({stars:o.stargazers_count,forks:o.forks_count})),De({}))).pipe(m(([o,n])=>$($({},o),n)))}else{let r=`https://api.github.com/users/${e}`;return je(r).pipe(m(o=>({repositories:o.public_repos})),De({}))}}function bi(e,t){let r=`https://${e}/api/v4/projects/${encodeURIComponent(t)}`;return st(je(`${r}/releases/permalink/latest`).pipe(de(()=>S),m(({tag_name:o})=>({version:o})),De({})),je(r).pipe(de(()=>S),m(({star_count:o,forks_count:n})=>({stars:o,forks:n})),De({}))).pipe(m(([o,n])=>$($({},o),n)))}function vi(e){let t=e.match(/^.+github\.com\/([^/]+)\/?([^/]+)?/i);if(t){let[,r,o]=t;return hi(r,o)}if(t=e.match(/^.+?([^/]*gitlab[^/]+)\/(.+?)\/?$/i),t){let[,r,o]=t;return bi(r,o)}return S}var us;function ds(e){return us||(us=C(()=>{let t=__md_get("__source",sessionStorage);if(t)return I(t);if(ae("consent").length){let o=__md_get("__consent");if(!(o&&o.github))return S}return vi(e.href).pipe(w(o=>__md_set("__source",o,sessionStorage)))}).pipe(de(()=>S),b(t=>Object.keys(t).length>0),m(t=>({facts:t})),G(1)))}function gi(e){let t=R(":scope > :last-child",e);return C(()=>{let r=new g;return r.subscribe(({facts:o})=>{t.appendChild(_n(o)),t.classList.add("md-source__repository--active")}),ds(e).pipe(w(o=>r.next(o)),_(()=>r.complete()),m(o=>$({ref:e},o)))})}function hs(e,{viewport$:t,header$:r}){return ge(document.body).pipe(v(()=>mr(e,{header$:r,viewport$:t})),m(({offset:{y:o}})=>({hidden:o>=10})),ee("hidden"))}function yi(e,t){return C(()=>{let r=new g;return r.subscribe({next({hidden:o}){e.hidden=o},complete(){e.hidden=!1}}),(B("navigation.tabs.sticky")?I({hidden:!1}):hs(e,t)).pipe(w(o=>r.next(o)),_(()=>r.complete()),m(o=>$({ref:e},o)))})}function bs(e,{viewport$:t,header$:r}){let o=new Map,n=P(".md-nav__link",e);for(let s of n){let p=decodeURIComponent(s.hash.substring(1)),c=fe(`[id="${p}"]`);typeof c!="undefined"&&o.set(s,c)}let i=r.pipe(ee("height"),m(({height:s})=>{let p=Se("main"),c=R(":scope > :first-child",p);return s+.8*(c.offsetTop-p.offsetTop)}),pe());return ge(document.body).pipe(ee("height"),v(s=>C(()=>{let p=[];return I([...o].reduce((c,[l,f])=>{for(;p.length&&o.get(p[p.length-1]).tagName>=f.tagName;)p.pop();let u=f.offsetTop;for(;!u&&f.parentElement;)f=f.parentElement,u=f.offsetTop;let d=f.offsetParent;for(;d;d=d.offsetParent)u+=d.offsetTop;return c.set([...p=[...p,l]].reverse(),u)},new Map))}).pipe(m(p=>new Map([...p].sort(([,c],[,l])=>c-l))),He(i),v(([p,c])=>t.pipe(Fr(([l,f],{offset:{y:u},size:d})=>{let y=u+d.height>=Math.floor(s.height);for(;f.length;){let[,L]=f[0];if(L-c=u&&!y)f=[l.pop(),...f];else 
break}return[l,f]},[[],[...p]]),K((l,f)=>l[0]===f[0]&&l[1]===f[1])))))).pipe(m(([s,p])=>({prev:s.map(([c])=>c),next:p.map(([c])=>c)})),Q({prev:[],next:[]}),Be(2,1),m(([s,p])=>s.prev.length{let i=new g,a=i.pipe(Z(),ie(!0));if(i.subscribe(({prev:s,next:p})=>{for(let[c]of p)c.classList.remove("md-nav__link--passed"),c.classList.remove("md-nav__link--active");for(let[c,[l]]of s.entries())l.classList.add("md-nav__link--passed"),l.classList.toggle("md-nav__link--active",c===s.length-1)}),B("toc.follow")){let s=O(t.pipe(_e(1),m(()=>{})),t.pipe(_e(250),m(()=>"smooth")));i.pipe(b(({prev:p})=>p.length>0),He(o.pipe(ve(se))),re(s)).subscribe(([[{prev:p}],c])=>{let[l]=p[p.length-1];if(l.offsetHeight){let f=cr(l);if(typeof f!="undefined"){let u=l.offsetTop-f.offsetTop,{height:d}=ce(f);f.scrollTo({top:u-d/2,behavior:c})}}})}return B("navigation.tracking")&&t.pipe(W(a),ee("offset"),_e(250),Ce(1),W(n.pipe(Ce(1))),ct({delay:250}),re(i)).subscribe(([,{prev:s}])=>{let p=ye(),c=s[s.length-1];if(c&&c.length){let[l]=c,{hash:f}=new URL(l.href);p.hash!==f&&(p.hash=f,history.replaceState({},"",`${p}`))}else p.hash="",history.replaceState({},"",`${p}`)}),bs(e,{viewport$:t,header$:r}).pipe(w(s=>i.next(s)),_(()=>i.complete()),m(s=>$({ref:e},s)))})}function vs(e,{viewport$:t,main$:r,target$:o}){let n=t.pipe(m(({offset:{y:a}})=>a),Be(2,1),m(([a,s])=>a>s&&s>0),K()),i=r.pipe(m(({active:a})=>a));return z([i,n]).pipe(m(([a,s])=>!(a&&s)),K(),W(o.pipe(Ce(1))),ie(!0),ct({delay:250}),m(a=>({hidden:a})))}function Ei(e,{viewport$:t,header$:r,main$:o,target$:n}){let i=new g,a=i.pipe(Z(),ie(!0));return i.subscribe({next({hidden:s}){e.hidden=s,s?(e.setAttribute("tabindex","-1"),e.blur()):e.removeAttribute("tabindex")},complete(){e.style.top="",e.hidden=!0,e.removeAttribute("tabindex")}}),r.pipe(W(a),ee("height")).subscribe(({height:s})=>{e.style.top=`${s+16}px`}),h(e,"click").subscribe(s=>{s.preventDefault(),window.scrollTo({top:0})}),vs(e,{viewport$:t,main$:o,target$:n}).pipe(w(s=>i.next(s)),_(()=>i.complete()),m(s=>$({ref:e},s)))}function wi({document$:e,viewport$:t}){e.pipe(v(()=>P(".md-ellipsis")),ne(r=>tt(r).pipe(W(e.pipe(Ce(1))),b(o=>o),m(()=>r),Te(1))),b(r=>r.offsetWidth{let o=r.innerText,n=r.closest("a")||r;return n.title=o,B("content.tooltips")?mt(n,{viewport$:t}).pipe(W(e.pipe(Ce(1))),_(()=>n.removeAttribute("title"))):S})).subscribe(),B("content.tooltips")&&e.pipe(v(()=>P(".md-status")),ne(r=>mt(r,{viewport$:t}))).subscribe()}function Ti({document$:e,tablet$:t}){e.pipe(v(()=>P(".md-toggle--indeterminate")),w(r=>{r.indeterminate=!0,r.checked=!1}),ne(r=>h(r,"change").pipe(Dr(()=>r.classList.contains("md-toggle--indeterminate")),m(()=>r))),re(t)).subscribe(([r,o])=>{r.classList.remove("md-toggle--indeterminate"),o&&(r.checked=!1)})}function gs(){return/(iPad|iPhone|iPod)/.test(navigator.userAgent)}function Si({document$:e}){e.pipe(v(()=>P("[data-md-scrollfix]")),w(t=>t.removeAttribute("data-md-scrollfix")),b(gs),ne(t=>h(t,"touchstart").pipe(m(()=>t)))).subscribe(t=>{let r=t.scrollTop;r===0?t.scrollTop=1:r+t.offsetHeight===t.scrollHeight&&(t.scrollTop=r-1)})}function Oi({viewport$:e,tablet$:t}){z([ze("search"),t]).pipe(m(([r,o])=>r&&!o),v(r=>I(r).pipe(Ge(r?400:100))),re(e)).subscribe(([r,{offset:{y:o}}])=>{if(r)document.body.setAttribute("data-md-scrolllock",""),document.body.style.top=`-${o}px`;else{let n=-1*parseInt(document.body.style.top,10);document.body.removeAttribute("data-md-scrolllock"),document.body.style.top="",n&&window.scrollTo(0,n)}})}Object.entries||(Object.entries=function(e){let t=[];for(let r of 
Object.keys(e))t.push([r,e[r]]);return t});Object.values||(Object.values=function(e){let t=[];for(let r of Object.keys(e))t.push(e[r]);return t});typeof Element!="undefined"&&(Element.prototype.scrollTo||(Element.prototype.scrollTo=function(e,t){typeof e=="object"?(this.scrollLeft=e.left,this.scrollTop=e.top):(this.scrollLeft=e,this.scrollTop=t)}),Element.prototype.replaceWith||(Element.prototype.replaceWith=function(...e){let t=this.parentNode;if(t){e.length===0&&t.removeChild(this);for(let r=e.length-1;r>=0;r--){let o=e[r];typeof o=="string"?o=document.createTextNode(o):o.parentNode&&o.parentNode.removeChild(o),r?t.insertBefore(this.previousSibling,o):t.replaceChild(o,this)}}}));function ys(){return location.protocol==="file:"?Tt(`${new URL("search/search_index.js",eo.base)}`).pipe(m(()=>__index),G(1)):je(new URL("search/search_index.json",eo.base))}document.documentElement.classList.remove("no-js");document.documentElement.classList.add("js");var ot=Go(),Ut=sn(),Lt=ln(Ut),to=an(),Oe=gn(),hr=Pt("(min-width: 960px)"),Mi=Pt("(min-width: 1220px)"),_i=mn(),eo=xe(),Ai=document.forms.namedItem("search")?ys():Ye,ro=new g;Zn({alert$:ro});var oo=new g;B("navigation.instant")&&oi({location$:Ut,viewport$:Oe,progress$:oo}).subscribe(ot);var Li;((Li=eo.version)==null?void 0:Li.provider)==="mike"&&ci({document$:ot});O(Ut,Lt).pipe(Ge(125)).subscribe(()=>{Je("drawer",!1),Je("search",!1)});to.pipe(b(({mode:e})=>e==="global")).subscribe(e=>{switch(e.type){case"p":case",":let t=fe("link[rel=prev]");typeof t!="undefined"&<(t);break;case"n":case".":let r=fe("link[rel=next]");typeof r!="undefined"&<(r);break;case"Enter":let o=Ie();o instanceof HTMLLabelElement&&o.click()}});wi({viewport$:Oe,document$:ot});Ti({document$:ot,tablet$:hr});Si({document$:ot});Oi({viewport$:Oe,tablet$:hr});var rt=Kn(Se("header"),{viewport$:Oe}),Ft=ot.pipe(m(()=>Se("main")),v(e=>Gn(e,{viewport$:Oe,header$:rt})),G(1)),xs=O(...ae("consent").map(e=>En(e,{target$:Lt})),...ae("dialog").map(e=>qn(e,{alert$:ro})),...ae("palette").map(e=>Jn(e)),...ae("progress").map(e=>Xn(e,{progress$:oo})),...ae("search").map(e=>ui(e,{index$:Ai,keyboard$:to})),...ae("source").map(e=>gi(e))),Es=C(()=>O(...ae("announce").map(e=>xn(e)),...ae("content").map(e=>zn(e,{viewport$:Oe,target$:Lt,print$:_i})),...ae("content").map(e=>B("search.highlight")?di(e,{index$:Ai,location$:Ut}):S),...ae("header").map(e=>Yn(e,{viewport$:Oe,header$:rt,main$:Ft})),...ae("header-title").map(e=>Bn(e,{viewport$:Oe,header$:rt})),...ae("sidebar").map(e=>e.getAttribute("data-md-type")==="navigation"?Nr(Mi,()=>Zr(e,{viewport$:Oe,header$:rt,main$:Ft})):Nr(hr,()=>Zr(e,{viewport$:Oe,header$:rt,main$:Ft}))),...ae("tabs").map(e=>yi(e,{viewport$:Oe,header$:rt})),...ae("toc").map(e=>xi(e,{viewport$:Oe,header$:rt,main$:Ft,target$:Lt})),...ae("top").map(e=>Ei(e,{viewport$:Oe,header$:rt,main$:Ft,target$:Lt})))),Ci=ot.pipe(v(()=>Es),Re(xs),G(1));Ci.subscribe();window.document$=ot;window.location$=Ut;window.target$=Lt;window.keyboard$=to;window.viewport$=Oe;window.tablet$=hr;window.screen$=Mi;window.print$=_i;window.alert$=ro;window.progress$=oo;window.component$=Ci;})(); +//# sourceMappingURL=bundle.83f73b43.min.js.map + diff --git a/assets/javascripts/bundle.83f73b43.min.js.map b/assets/javascripts/bundle.83f73b43.min.js.map new file mode 100644 index 00000000..fe920b7d --- /dev/null +++ b/assets/javascripts/bundle.83f73b43.min.js.map @@ -0,0 +1,7 @@ +{ + "version": 3, + "sources": ["node_modules/focus-visible/dist/focus-visible.js", "node_modules/escape-html/index.js", 
"node_modules/clipboard/dist/clipboard.js", "src/templates/assets/javascripts/bundle.ts", "node_modules/tslib/tslib.es6.mjs", "node_modules/rxjs/src/internal/util/isFunction.ts", "node_modules/rxjs/src/internal/util/createErrorClass.ts", "node_modules/rxjs/src/internal/util/UnsubscriptionError.ts", "node_modules/rxjs/src/internal/util/arrRemove.ts", "node_modules/rxjs/src/internal/Subscription.ts", "node_modules/rxjs/src/internal/config.ts", "node_modules/rxjs/src/internal/scheduler/timeoutProvider.ts", "node_modules/rxjs/src/internal/util/reportUnhandledError.ts", "node_modules/rxjs/src/internal/util/noop.ts", "node_modules/rxjs/src/internal/NotificationFactories.ts", "node_modules/rxjs/src/internal/util/errorContext.ts", "node_modules/rxjs/src/internal/Subscriber.ts", "node_modules/rxjs/src/internal/symbol/observable.ts", "node_modules/rxjs/src/internal/util/identity.ts", "node_modules/rxjs/src/internal/util/pipe.ts", "node_modules/rxjs/src/internal/Observable.ts", "node_modules/rxjs/src/internal/util/lift.ts", "node_modules/rxjs/src/internal/operators/OperatorSubscriber.ts", "node_modules/rxjs/src/internal/scheduler/animationFrameProvider.ts", "node_modules/rxjs/src/internal/util/ObjectUnsubscribedError.ts", "node_modules/rxjs/src/internal/Subject.ts", "node_modules/rxjs/src/internal/BehaviorSubject.ts", "node_modules/rxjs/src/internal/scheduler/dateTimestampProvider.ts", "node_modules/rxjs/src/internal/ReplaySubject.ts", "node_modules/rxjs/src/internal/scheduler/Action.ts", "node_modules/rxjs/src/internal/scheduler/intervalProvider.ts", "node_modules/rxjs/src/internal/scheduler/AsyncAction.ts", "node_modules/rxjs/src/internal/Scheduler.ts", "node_modules/rxjs/src/internal/scheduler/AsyncScheduler.ts", "node_modules/rxjs/src/internal/scheduler/async.ts", "node_modules/rxjs/src/internal/scheduler/QueueAction.ts", "node_modules/rxjs/src/internal/scheduler/QueueScheduler.ts", "node_modules/rxjs/src/internal/scheduler/queue.ts", "node_modules/rxjs/src/internal/scheduler/AnimationFrameAction.ts", "node_modules/rxjs/src/internal/scheduler/AnimationFrameScheduler.ts", "node_modules/rxjs/src/internal/scheduler/animationFrame.ts", "node_modules/rxjs/src/internal/observable/empty.ts", "node_modules/rxjs/src/internal/util/isScheduler.ts", "node_modules/rxjs/src/internal/util/args.ts", "node_modules/rxjs/src/internal/util/isArrayLike.ts", "node_modules/rxjs/src/internal/util/isPromise.ts", "node_modules/rxjs/src/internal/util/isInteropObservable.ts", "node_modules/rxjs/src/internal/util/isAsyncIterable.ts", "node_modules/rxjs/src/internal/util/throwUnobservableError.ts", "node_modules/rxjs/src/internal/symbol/iterator.ts", "node_modules/rxjs/src/internal/util/isIterable.ts", "node_modules/rxjs/src/internal/util/isReadableStreamLike.ts", "node_modules/rxjs/src/internal/observable/innerFrom.ts", "node_modules/rxjs/src/internal/util/executeSchedule.ts", "node_modules/rxjs/src/internal/operators/observeOn.ts", "node_modules/rxjs/src/internal/operators/subscribeOn.ts", "node_modules/rxjs/src/internal/scheduled/scheduleObservable.ts", "node_modules/rxjs/src/internal/scheduled/schedulePromise.ts", "node_modules/rxjs/src/internal/scheduled/scheduleArray.ts", "node_modules/rxjs/src/internal/scheduled/scheduleIterable.ts", "node_modules/rxjs/src/internal/scheduled/scheduleAsyncIterable.ts", "node_modules/rxjs/src/internal/scheduled/scheduleReadableStreamLike.ts", "node_modules/rxjs/src/internal/scheduled/scheduled.ts", "node_modules/rxjs/src/internal/observable/from.ts", 
"node_modules/rxjs/src/internal/observable/of.ts", "node_modules/rxjs/src/internal/observable/throwError.ts", "node_modules/rxjs/src/internal/util/EmptyError.ts", "node_modules/rxjs/src/internal/util/isDate.ts", "node_modules/rxjs/src/internal/operators/map.ts", "node_modules/rxjs/src/internal/util/mapOneOrManyArgs.ts", "node_modules/rxjs/src/internal/util/argsArgArrayOrObject.ts", "node_modules/rxjs/src/internal/util/createObject.ts", "node_modules/rxjs/src/internal/observable/combineLatest.ts", "node_modules/rxjs/src/internal/operators/mergeInternals.ts", "node_modules/rxjs/src/internal/operators/mergeMap.ts", "node_modules/rxjs/src/internal/operators/mergeAll.ts", "node_modules/rxjs/src/internal/operators/concatAll.ts", "node_modules/rxjs/src/internal/observable/concat.ts", "node_modules/rxjs/src/internal/observable/defer.ts", "node_modules/rxjs/src/internal/observable/fromEvent.ts", "node_modules/rxjs/src/internal/observable/fromEventPattern.ts", "node_modules/rxjs/src/internal/observable/timer.ts", "node_modules/rxjs/src/internal/observable/merge.ts", "node_modules/rxjs/src/internal/observable/never.ts", "node_modules/rxjs/src/internal/util/argsOrArgArray.ts", "node_modules/rxjs/src/internal/operators/filter.ts", "node_modules/rxjs/src/internal/observable/zip.ts", "node_modules/rxjs/src/internal/operators/audit.ts", "node_modules/rxjs/src/internal/operators/auditTime.ts", "node_modules/rxjs/src/internal/operators/bufferCount.ts", "node_modules/rxjs/src/internal/operators/catchError.ts", "node_modules/rxjs/src/internal/operators/scanInternals.ts", "node_modules/rxjs/src/internal/operators/combineLatest.ts", "node_modules/rxjs/src/internal/operators/combineLatestWith.ts", "node_modules/rxjs/src/internal/operators/debounce.ts", "node_modules/rxjs/src/internal/operators/debounceTime.ts", "node_modules/rxjs/src/internal/operators/defaultIfEmpty.ts", "node_modules/rxjs/src/internal/operators/take.ts", "node_modules/rxjs/src/internal/operators/ignoreElements.ts", "node_modules/rxjs/src/internal/operators/mapTo.ts", "node_modules/rxjs/src/internal/operators/delayWhen.ts", "node_modules/rxjs/src/internal/operators/delay.ts", "node_modules/rxjs/src/internal/operators/distinctUntilChanged.ts", "node_modules/rxjs/src/internal/operators/distinctUntilKeyChanged.ts", "node_modules/rxjs/src/internal/operators/throwIfEmpty.ts", "node_modules/rxjs/src/internal/operators/endWith.ts", "node_modules/rxjs/src/internal/operators/finalize.ts", "node_modules/rxjs/src/internal/operators/first.ts", "node_modules/rxjs/src/internal/operators/takeLast.ts", "node_modules/rxjs/src/internal/operators/merge.ts", "node_modules/rxjs/src/internal/operators/mergeWith.ts", "node_modules/rxjs/src/internal/operators/repeat.ts", "node_modules/rxjs/src/internal/operators/scan.ts", "node_modules/rxjs/src/internal/operators/share.ts", "node_modules/rxjs/src/internal/operators/shareReplay.ts", "node_modules/rxjs/src/internal/operators/skip.ts", "node_modules/rxjs/src/internal/operators/skipUntil.ts", "node_modules/rxjs/src/internal/operators/startWith.ts", "node_modules/rxjs/src/internal/operators/switchMap.ts", "node_modules/rxjs/src/internal/operators/takeUntil.ts", "node_modules/rxjs/src/internal/operators/takeWhile.ts", "node_modules/rxjs/src/internal/operators/tap.ts", "node_modules/rxjs/src/internal/operators/throttle.ts", "node_modules/rxjs/src/internal/operators/throttleTime.ts", "node_modules/rxjs/src/internal/operators/withLatestFrom.ts", "node_modules/rxjs/src/internal/operators/zip.ts", 
"node_modules/rxjs/src/internal/operators/zipWith.ts", "src/templates/assets/javascripts/browser/document/index.ts", "src/templates/assets/javascripts/browser/element/_/index.ts", "src/templates/assets/javascripts/browser/element/focus/index.ts", "src/templates/assets/javascripts/browser/element/hover/index.ts", "src/templates/assets/javascripts/utilities/h/index.ts", "src/templates/assets/javascripts/utilities/round/index.ts", "src/templates/assets/javascripts/browser/script/index.ts", "src/templates/assets/javascripts/browser/element/size/_/index.ts", "src/templates/assets/javascripts/browser/element/size/content/index.ts", "src/templates/assets/javascripts/browser/element/offset/_/index.ts", "src/templates/assets/javascripts/browser/element/offset/content/index.ts", "src/templates/assets/javascripts/browser/element/visibility/index.ts", "src/templates/assets/javascripts/browser/toggle/index.ts", "src/templates/assets/javascripts/browser/keyboard/index.ts", "src/templates/assets/javascripts/browser/location/_/index.ts", "src/templates/assets/javascripts/browser/location/hash/index.ts", "src/templates/assets/javascripts/browser/media/index.ts", "src/templates/assets/javascripts/browser/request/index.ts", "src/templates/assets/javascripts/browser/viewport/offset/index.ts", "src/templates/assets/javascripts/browser/viewport/size/index.ts", "src/templates/assets/javascripts/browser/viewport/_/index.ts", "src/templates/assets/javascripts/browser/viewport/at/index.ts", "src/templates/assets/javascripts/browser/worker/index.ts", "src/templates/assets/javascripts/_/index.ts", "src/templates/assets/javascripts/components/_/index.ts", "src/templates/assets/javascripts/components/announce/index.ts", "src/templates/assets/javascripts/components/consent/index.ts", "src/templates/assets/javascripts/templates/tooltip/index.tsx", "src/templates/assets/javascripts/templates/annotation/index.tsx", "src/templates/assets/javascripts/templates/clipboard/index.tsx", "src/templates/assets/javascripts/templates/search/index.tsx", "src/templates/assets/javascripts/templates/source/index.tsx", "src/templates/assets/javascripts/templates/tabbed/index.tsx", "src/templates/assets/javascripts/templates/table/index.tsx", "src/templates/assets/javascripts/templates/version/index.tsx", "src/templates/assets/javascripts/components/tooltip2/index.ts", "src/templates/assets/javascripts/components/content/annotation/_/index.ts", "src/templates/assets/javascripts/components/content/annotation/list/index.ts", "src/templates/assets/javascripts/components/content/annotation/block/index.ts", "src/templates/assets/javascripts/components/content/code/_/index.ts", "src/templates/assets/javascripts/components/content/details/index.ts", "src/templates/assets/javascripts/components/content/mermaid/index.css", "src/templates/assets/javascripts/components/content/mermaid/index.ts", "src/templates/assets/javascripts/components/content/table/index.ts", "src/templates/assets/javascripts/components/content/tabs/index.ts", "src/templates/assets/javascripts/components/content/_/index.ts", "src/templates/assets/javascripts/components/dialog/index.ts", "src/templates/assets/javascripts/components/tooltip/index.ts", "src/templates/assets/javascripts/components/header/_/index.ts", "src/templates/assets/javascripts/components/header/title/index.ts", "src/templates/assets/javascripts/components/main/index.ts", "src/templates/assets/javascripts/components/palette/index.ts", "src/templates/assets/javascripts/components/progress/index.ts", 
"src/templates/assets/javascripts/integrations/clipboard/index.ts", "src/templates/assets/javascripts/integrations/sitemap/index.ts", "src/templates/assets/javascripts/integrations/instant/index.ts", "src/templates/assets/javascripts/integrations/search/highlighter/index.ts", "src/templates/assets/javascripts/integrations/search/worker/message/index.ts", "src/templates/assets/javascripts/integrations/search/worker/_/index.ts", "src/templates/assets/javascripts/integrations/version/findurl/index.ts", "src/templates/assets/javascripts/integrations/version/index.ts", "src/templates/assets/javascripts/components/search/query/index.ts", "src/templates/assets/javascripts/components/search/result/index.ts", "src/templates/assets/javascripts/components/search/share/index.ts", "src/templates/assets/javascripts/components/search/suggest/index.ts", "src/templates/assets/javascripts/components/search/_/index.ts", "src/templates/assets/javascripts/components/search/highlight/index.ts", "src/templates/assets/javascripts/components/sidebar/index.ts", "src/templates/assets/javascripts/components/source/facts/github/index.ts", "src/templates/assets/javascripts/components/source/facts/gitlab/index.ts", "src/templates/assets/javascripts/components/source/facts/_/index.ts", "src/templates/assets/javascripts/components/source/_/index.ts", "src/templates/assets/javascripts/components/tabs/index.ts", "src/templates/assets/javascripts/components/toc/index.ts", "src/templates/assets/javascripts/components/top/index.ts", "src/templates/assets/javascripts/patches/ellipsis/index.ts", "src/templates/assets/javascripts/patches/indeterminate/index.ts", "src/templates/assets/javascripts/patches/scrollfix/index.ts", "src/templates/assets/javascripts/patches/scrolllock/index.ts", "src/templates/assets/javascripts/polyfills/index.ts"], + "sourcesContent": ["(function (global, factory) {\n typeof exports === 'object' && typeof module !== 'undefined' ? factory() :\n typeof define === 'function' && define.amd ? define(factory) :\n (factory());\n}(this, (function () { 'use strict';\n\n /**\n * Applies the :focus-visible polyfill at the given scope.\n * A scope in this case is either the top-level Document or a Shadow Root.\n *\n * @param {(Document|ShadowRoot)} scope\n * @see https://github.com/WICG/focus-visible\n */\n function applyFocusVisiblePolyfill(scope) {\n var hadKeyboardEvent = true;\n var hadFocusVisibleRecently = false;\n var hadFocusVisibleRecentlyTimeout = null;\n\n var inputTypesAllowlist = {\n text: true,\n search: true,\n url: true,\n tel: true,\n email: true,\n password: true,\n number: true,\n date: true,\n month: true,\n week: true,\n time: true,\n datetime: true,\n 'datetime-local': true\n };\n\n /**\n * Helper function for legacy browsers and iframes which sometimes focus\n * elements like document, body, and non-interactive SVG.\n * @param {Element} el\n */\n function isValidFocusTarget(el) {\n if (\n el &&\n el !== document &&\n el.nodeName !== 'HTML' &&\n el.nodeName !== 'BODY' &&\n 'classList' in el &&\n 'contains' in el.classList\n ) {\n return true;\n }\n return false;\n }\n\n /**\n * Computes whether the given element should automatically trigger the\n * `focus-visible` class being added, i.e. 
whether it should always match\n * `:focus-visible` when focused.\n * @param {Element} el\n * @return {boolean}\n */\n function focusTriggersKeyboardModality(el) {\n var type = el.type;\n var tagName = el.tagName;\n\n if (tagName === 'INPUT' && inputTypesAllowlist[type] && !el.readOnly) {\n return true;\n }\n\n if (tagName === 'TEXTAREA' && !el.readOnly) {\n return true;\n }\n\n if (el.isContentEditable) {\n return true;\n }\n\n return false;\n }\n\n /**\n * Add the `focus-visible` class to the given element if it was not added by\n * the author.\n * @param {Element} el\n */\n function addFocusVisibleClass(el) {\n if (el.classList.contains('focus-visible')) {\n return;\n }\n el.classList.add('focus-visible');\n el.setAttribute('data-focus-visible-added', '');\n }\n\n /**\n * Remove the `focus-visible` class from the given element if it was not\n * originally added by the author.\n * @param {Element} el\n */\n function removeFocusVisibleClass(el) {\n if (!el.hasAttribute('data-focus-visible-added')) {\n return;\n }\n el.classList.remove('focus-visible');\n el.removeAttribute('data-focus-visible-added');\n }\n\n /**\n * If the most recent user interaction was via the keyboard;\n * and the key press did not include a meta, alt/option, or control key;\n * then the modality is keyboard. Otherwise, the modality is not keyboard.\n * Apply `focus-visible` to any current active element and keep track\n * of our keyboard modality state with `hadKeyboardEvent`.\n * @param {KeyboardEvent} e\n */\n function onKeyDown(e) {\n if (e.metaKey || e.altKey || e.ctrlKey) {\n return;\n }\n\n if (isValidFocusTarget(scope.activeElement)) {\n addFocusVisibleClass(scope.activeElement);\n }\n\n hadKeyboardEvent = true;\n }\n\n /**\n * If at any point a user clicks with a pointing device, ensure that we change\n * the modality away from keyboard.\n * This avoids the situation where a user presses a key on an already focused\n * element, and then clicks on a different element, focusing it with a\n * pointing device, while we still think we're in keyboard modality.\n * @param {Event} e\n */\n function onPointerDown(e) {\n hadKeyboardEvent = false;\n }\n\n /**\n * On `focus`, add the `focus-visible` class to the target if:\n * - the target received focus as a result of keyboard navigation, or\n * - the event target is an element that will likely require interaction\n * via the keyboard (e.g. 
a text box)\n * @param {Event} e\n */\n function onFocus(e) {\n // Prevent IE from focusing the document or HTML element.\n if (!isValidFocusTarget(e.target)) {\n return;\n }\n\n if (hadKeyboardEvent || focusTriggersKeyboardModality(e.target)) {\n addFocusVisibleClass(e.target);\n }\n }\n\n /**\n * On `blur`, remove the `focus-visible` class from the target.\n * @param {Event} e\n */\n function onBlur(e) {\n if (!isValidFocusTarget(e.target)) {\n return;\n }\n\n if (\n e.target.classList.contains('focus-visible') ||\n e.target.hasAttribute('data-focus-visible-added')\n ) {\n // To detect a tab/window switch, we look for a blur event followed\n // rapidly by a visibility change.\n // If we don't see a visibility change within 100ms, it's probably a\n // regular focus change.\n hadFocusVisibleRecently = true;\n window.clearTimeout(hadFocusVisibleRecentlyTimeout);\n hadFocusVisibleRecentlyTimeout = window.setTimeout(function() {\n hadFocusVisibleRecently = false;\n }, 100);\n removeFocusVisibleClass(e.target);\n }\n }\n\n /**\n * If the user changes tabs, keep track of whether or not the previously\n * focused element had .focus-visible.\n * @param {Event} e\n */\n function onVisibilityChange(e) {\n if (document.visibilityState === 'hidden') {\n // If the tab becomes active again, the browser will handle calling focus\n // on the element (Safari actually calls it twice).\n // If this tab change caused a blur on an element with focus-visible,\n // re-apply the class when the user switches back to the tab.\n if (hadFocusVisibleRecently) {\n hadKeyboardEvent = true;\n }\n addInitialPointerMoveListeners();\n }\n }\n\n /**\n * Add a group of listeners to detect usage of any pointing devices.\n * These listeners will be added when the polyfill first loads, and anytime\n * the window is blurred, so that they are active when the window regains\n * focus.\n */\n function addInitialPointerMoveListeners() {\n document.addEventListener('mousemove', onInitialPointerMove);\n document.addEventListener('mousedown', onInitialPointerMove);\n document.addEventListener('mouseup', onInitialPointerMove);\n document.addEventListener('pointermove', onInitialPointerMove);\n document.addEventListener('pointerdown', onInitialPointerMove);\n document.addEventListener('pointerup', onInitialPointerMove);\n document.addEventListener('touchmove', onInitialPointerMove);\n document.addEventListener('touchstart', onInitialPointerMove);\n document.addEventListener('touchend', onInitialPointerMove);\n }\n\n function removeInitialPointerMoveListeners() {\n document.removeEventListener('mousemove', onInitialPointerMove);\n document.removeEventListener('mousedown', onInitialPointerMove);\n document.removeEventListener('mouseup', onInitialPointerMove);\n document.removeEventListener('pointermove', onInitialPointerMove);\n document.removeEventListener('pointerdown', onInitialPointerMove);\n document.removeEventListener('pointerup', onInitialPointerMove);\n document.removeEventListener('touchmove', onInitialPointerMove);\n document.removeEventListener('touchstart', onInitialPointerMove);\n document.removeEventListener('touchend', onInitialPointerMove);\n }\n\n /**\n * When the polfyill first loads, assume the user is in keyboard modality.\n * If any event is received from a pointing device (e.g. 
mouse, pointer,\n * touch), turn off keyboard modality.\n * This accounts for situations where focus enters the page from the URL bar.\n * @param {Event} e\n */\n function onInitialPointerMove(e) {\n // Work around a Safari quirk that fires a mousemove on whenever the\n // window blurs, even if you're tabbing out of the page. \u00AF\\_(\u30C4)_/\u00AF\n if (e.target.nodeName && e.target.nodeName.toLowerCase() === 'html') {\n return;\n }\n\n hadKeyboardEvent = false;\n removeInitialPointerMoveListeners();\n }\n\n // For some kinds of state, we are interested in changes at the global scope\n // only. For example, global pointer input, global key presses and global\n // visibility change should affect the state at every scope:\n document.addEventListener('keydown', onKeyDown, true);\n document.addEventListener('mousedown', onPointerDown, true);\n document.addEventListener('pointerdown', onPointerDown, true);\n document.addEventListener('touchstart', onPointerDown, true);\n document.addEventListener('visibilitychange', onVisibilityChange, true);\n\n addInitialPointerMoveListeners();\n\n // For focus and blur, we specifically care about state changes in the local\n // scope. This is because focus / blur events that originate from within a\n // shadow root are not re-dispatched from the host element if it was already\n // the active element in its own scope:\n scope.addEventListener('focus', onFocus, true);\n scope.addEventListener('blur', onBlur, true);\n\n // We detect that a node is a ShadowRoot by ensuring that it is a\n // DocumentFragment and also has a host property. This check covers native\n // implementation and polyfill implementation transparently. If we only cared\n // about the native implementation, we could just check if the scope was\n // an instance of a ShadowRoot.\n if (scope.nodeType === Node.DOCUMENT_FRAGMENT_NODE && scope.host) {\n // Since a ShadowRoot is a special kind of DocumentFragment, it does not\n // have a root element to add a class to. So, we add this attribute to the\n // host element instead:\n scope.host.setAttribute('data-js-focus-visible', '');\n } else if (scope.nodeType === Node.DOCUMENT_NODE) {\n document.documentElement.classList.add('js-focus-visible');\n document.documentElement.setAttribute('data-js-focus-visible', '');\n }\n }\n\n // It is important to wrap all references to global window and document in\n // these checks to support server-side rendering use cases\n // @see https://github.com/WICG/focus-visible/issues/199\n if (typeof window !== 'undefined' && typeof document !== 'undefined') {\n // Make the polyfill helper globally available. 
This can be used as a signal\n // to interested libraries that wish to coordinate with the polyfill for e.g.,\n // applying the polyfill to a shadow root:\n window.applyFocusVisiblePolyfill = applyFocusVisiblePolyfill;\n\n // Notify interested libraries of the polyfill's presence, in case the\n // polyfill was loaded lazily:\n var event;\n\n try {\n event = new CustomEvent('focus-visible-polyfill-ready');\n } catch (error) {\n // IE11 does not support using CustomEvent as a constructor directly:\n event = document.createEvent('CustomEvent');\n event.initCustomEvent('focus-visible-polyfill-ready', false, false, {});\n }\n\n window.dispatchEvent(event);\n }\n\n if (typeof document !== 'undefined') {\n // Apply the polyfill to the global document, so that no JavaScript\n // coordination is required to use the polyfill in the top-level document:\n applyFocusVisiblePolyfill(document);\n }\n\n})));\n", "/*!\n * escape-html\n * Copyright(c) 2012-2013 TJ Holowaychuk\n * Copyright(c) 2015 Andreas Lubbe\n * Copyright(c) 2015 Tiancheng \"Timothy\" Gu\n * MIT Licensed\n */\n\n'use strict';\n\n/**\n * Module variables.\n * @private\n */\n\nvar matchHtmlRegExp = /[\"'&<>]/;\n\n/**\n * Module exports.\n * @public\n */\n\nmodule.exports = escapeHtml;\n\n/**\n * Escape special characters in the given string of html.\n *\n * @param {string} string The string to escape for inserting into HTML\n * @return {string}\n * @public\n */\n\nfunction escapeHtml(string) {\n var str = '' + string;\n var match = matchHtmlRegExp.exec(str);\n\n if (!match) {\n return str;\n }\n\n var escape;\n var html = '';\n var index = 0;\n var lastIndex = 0;\n\n for (index = match.index; index < str.length; index++) {\n switch (str.charCodeAt(index)) {\n case 34: // \"\n escape = '"';\n break;\n case 38: // &\n escape = '&';\n break;\n case 39: // '\n escape = ''';\n break;\n case 60: // <\n escape = '<';\n break;\n case 62: // >\n escape = '>';\n break;\n default:\n continue;\n }\n\n if (lastIndex !== index) {\n html += str.substring(lastIndex, index);\n }\n\n lastIndex = index + 1;\n html += escape;\n }\n\n return lastIndex !== index\n ? 
html + str.substring(lastIndex, index)\n : html;\n}\n", "/*!\n * clipboard.js v2.0.11\n * https://clipboardjs.com/\n *\n * Licensed MIT \u00A9 Zeno Rocha\n */\n(function webpackUniversalModuleDefinition(root, factory) {\n\tif(typeof exports === 'object' && typeof module === 'object')\n\t\tmodule.exports = factory();\n\telse if(typeof define === 'function' && define.amd)\n\t\tdefine([], factory);\n\telse if(typeof exports === 'object')\n\t\texports[\"ClipboardJS\"] = factory();\n\telse\n\t\troot[\"ClipboardJS\"] = factory();\n})(this, function() {\nreturn /******/ (function() { // webpackBootstrap\n/******/ \tvar __webpack_modules__ = ({\n\n/***/ 686:\n/***/ (function(__unused_webpack_module, __webpack_exports__, __webpack_require__) {\n\n\"use strict\";\n\n// EXPORTS\n__webpack_require__.d(__webpack_exports__, {\n \"default\": function() { return /* binding */ clipboard; }\n});\n\n// EXTERNAL MODULE: ./node_modules/tiny-emitter/index.js\nvar tiny_emitter = __webpack_require__(279);\nvar tiny_emitter_default = /*#__PURE__*/__webpack_require__.n(tiny_emitter);\n// EXTERNAL MODULE: ./node_modules/good-listener/src/listen.js\nvar listen = __webpack_require__(370);\nvar listen_default = /*#__PURE__*/__webpack_require__.n(listen);\n// EXTERNAL MODULE: ./node_modules/select/src/select.js\nvar src_select = __webpack_require__(817);\nvar select_default = /*#__PURE__*/__webpack_require__.n(src_select);\n;// CONCATENATED MODULE: ./src/common/command.js\n/**\n * Executes a given operation type.\n * @param {String} type\n * @return {Boolean}\n */\nfunction command(type) {\n try {\n return document.execCommand(type);\n } catch (err) {\n return false;\n }\n}\n;// CONCATENATED MODULE: ./src/actions/cut.js\n\n\n/**\n * Cut action wrapper.\n * @param {String|HTMLElement} target\n * @return {String}\n */\n\nvar ClipboardActionCut = function ClipboardActionCut(target) {\n var selectedText = select_default()(target);\n command('cut');\n return selectedText;\n};\n\n/* harmony default export */ var actions_cut = (ClipboardActionCut);\n;// CONCATENATED MODULE: ./src/common/create-fake-element.js\n/**\n * Creates a fake textarea element with a value.\n * @param {String} value\n * @return {HTMLElement}\n */\nfunction createFakeElement(value) {\n var isRTL = document.documentElement.getAttribute('dir') === 'rtl';\n var fakeElement = document.createElement('textarea'); // Prevent zooming on iOS\n\n fakeElement.style.fontSize = '12pt'; // Reset box model\n\n fakeElement.style.border = '0';\n fakeElement.style.padding = '0';\n fakeElement.style.margin = '0'; // Move element out of screen horizontally\n\n fakeElement.style.position = 'absolute';\n fakeElement.style[isRTL ? 
'right' : 'left'] = '-9999px'; // Move element to the same position vertically\n\n var yPosition = window.pageYOffset || document.documentElement.scrollTop;\n fakeElement.style.top = \"\".concat(yPosition, \"px\");\n fakeElement.setAttribute('readonly', '');\n fakeElement.value = value;\n return fakeElement;\n}\n;// CONCATENATED MODULE: ./src/actions/copy.js\n\n\n\n/**\n * Create fake copy action wrapper using a fake element.\n * @param {String} target\n * @param {Object} options\n * @return {String}\n */\n\nvar fakeCopyAction = function fakeCopyAction(value, options) {\n var fakeElement = createFakeElement(value);\n options.container.appendChild(fakeElement);\n var selectedText = select_default()(fakeElement);\n command('copy');\n fakeElement.remove();\n return selectedText;\n};\n/**\n * Copy action wrapper.\n * @param {String|HTMLElement} target\n * @param {Object} options\n * @return {String}\n */\n\n\nvar ClipboardActionCopy = function ClipboardActionCopy(target) {\n var options = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {\n container: document.body\n };\n var selectedText = '';\n\n if (typeof target === 'string') {\n selectedText = fakeCopyAction(target, options);\n } else if (target instanceof HTMLInputElement && !['text', 'search', 'url', 'tel', 'password'].includes(target === null || target === void 0 ? void 0 : target.type)) {\n // If input type doesn't support `setSelectionRange`. Simulate it. https://developer.mozilla.org/en-US/docs/Web/API/HTMLInputElement/setSelectionRange\n selectedText = fakeCopyAction(target.value, options);\n } else {\n selectedText = select_default()(target);\n command('copy');\n }\n\n return selectedText;\n};\n\n/* harmony default export */ var actions_copy = (ClipboardActionCopy);\n;// CONCATENATED MODULE: ./src/actions/default.js\nfunction _typeof(obj) { \"@babel/helpers - typeof\"; if (typeof Symbol === \"function\" && typeof Symbol.iterator === \"symbol\") { _typeof = function _typeof(obj) { return typeof obj; }; } else { _typeof = function _typeof(obj) { return obj && typeof Symbol === \"function\" && obj.constructor === Symbol && obj !== Symbol.prototype ? \"symbol\" : typeof obj; }; } return _typeof(obj); }\n\n\n\n/**\n * Inner function which performs selection from either `text` or `target`\n * properties and then executes copy or cut operations.\n * @param {Object} options\n */\n\nvar ClipboardActionDefault = function ClipboardActionDefault() {\n var options = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : {};\n // Defines base properties passed from constructor.\n var _options$action = options.action,\n action = _options$action === void 0 ? 'copy' : _options$action,\n container = options.container,\n target = options.target,\n text = options.text; // Sets the `action` to be performed which can be either 'copy' or 'cut'.\n\n if (action !== 'copy' && action !== 'cut') {\n throw new Error('Invalid \"action\" value, use either \"copy\" or \"cut\"');\n } // Sets the `target` property using an element that will be have its content copied.\n\n\n if (target !== undefined) {\n if (target && _typeof(target) === 'object' && target.nodeType === 1) {\n if (action === 'copy' && target.hasAttribute('disabled')) {\n throw new Error('Invalid \"target\" attribute. Please use \"readonly\" instead of \"disabled\" attribute');\n }\n\n if (action === 'cut' && (target.hasAttribute('readonly') || target.hasAttribute('disabled'))) {\n throw new Error('Invalid \"target\" attribute. 
You can\\'t cut text from elements with \"readonly\" or \"disabled\" attributes');\n }\n } else {\n throw new Error('Invalid \"target\" value, use a valid Element');\n }\n } // Define selection strategy based on `text` property.\n\n\n if (text) {\n return actions_copy(text, {\n container: container\n });\n } // Defines which selection strategy based on `target` property.\n\n\n if (target) {\n return action === 'cut' ? actions_cut(target) : actions_copy(target, {\n container: container\n });\n }\n};\n\n/* harmony default export */ var actions_default = (ClipboardActionDefault);\n;// CONCATENATED MODULE: ./src/clipboard.js\nfunction clipboard_typeof(obj) { \"@babel/helpers - typeof\"; if (typeof Symbol === \"function\" && typeof Symbol.iterator === \"symbol\") { clipboard_typeof = function _typeof(obj) { return typeof obj; }; } else { clipboard_typeof = function _typeof(obj) { return obj && typeof Symbol === \"function\" && obj.constructor === Symbol && obj !== Symbol.prototype ? \"symbol\" : typeof obj; }; } return clipboard_typeof(obj); }\n\nfunction _classCallCheck(instance, Constructor) { if (!(instance instanceof Constructor)) { throw new TypeError(\"Cannot call a class as a function\"); } }\n\nfunction _defineProperties(target, props) { for (var i = 0; i < props.length; i++) { var descriptor = props[i]; descriptor.enumerable = descriptor.enumerable || false; descriptor.configurable = true; if (\"value\" in descriptor) descriptor.writable = true; Object.defineProperty(target, descriptor.key, descriptor); } }\n\nfunction _createClass(Constructor, protoProps, staticProps) { if (protoProps) _defineProperties(Constructor.prototype, protoProps); if (staticProps) _defineProperties(Constructor, staticProps); return Constructor; }\n\nfunction _inherits(subClass, superClass) { if (typeof superClass !== \"function\" && superClass !== null) { throw new TypeError(\"Super expression must either be null or a function\"); } subClass.prototype = Object.create(superClass && superClass.prototype, { constructor: { value: subClass, writable: true, configurable: true } }); if (superClass) _setPrototypeOf(subClass, superClass); }\n\nfunction _setPrototypeOf(o, p) { _setPrototypeOf = Object.setPrototypeOf || function _setPrototypeOf(o, p) { o.__proto__ = p; return o; }; return _setPrototypeOf(o, p); }\n\nfunction _createSuper(Derived) { var hasNativeReflectConstruct = _isNativeReflectConstruct(); return function _createSuperInternal() { var Super = _getPrototypeOf(Derived), result; if (hasNativeReflectConstruct) { var NewTarget = _getPrototypeOf(this).constructor; result = Reflect.construct(Super, arguments, NewTarget); } else { result = Super.apply(this, arguments); } return _possibleConstructorReturn(this, result); }; }\n\nfunction _possibleConstructorReturn(self, call) { if (call && (clipboard_typeof(call) === \"object\" || typeof call === \"function\")) { return call; } return _assertThisInitialized(self); }\n\nfunction _assertThisInitialized(self) { if (self === void 0) { throw new ReferenceError(\"this hasn't been initialised - super() hasn't been called\"); } return self; }\n\nfunction _isNativeReflectConstruct() { if (typeof Reflect === \"undefined\" || !Reflect.construct) return false; if (Reflect.construct.sham) return false; if (typeof Proxy === \"function\") return true; try { Date.prototype.toString.call(Reflect.construct(Date, [], function () {})); return true; } catch (e) { return false; } }\n\nfunction _getPrototypeOf(o) { _getPrototypeOf = Object.setPrototypeOf ? 
Object.getPrototypeOf : function _getPrototypeOf(o) { return o.__proto__ || Object.getPrototypeOf(o); }; return _getPrototypeOf(o); }\n\n\n\n\n\n\n/**\n * Helper function to retrieve attribute value.\n * @param {String} suffix\n * @param {Element} element\n */\n\nfunction getAttributeValue(suffix, element) {\n var attribute = \"data-clipboard-\".concat(suffix);\n\n if (!element.hasAttribute(attribute)) {\n return;\n }\n\n return element.getAttribute(attribute);\n}\n/**\n * Base class which takes one or more elements, adds event listeners to them,\n * and instantiates a new `ClipboardAction` on each click.\n */\n\n\nvar Clipboard = /*#__PURE__*/function (_Emitter) {\n _inherits(Clipboard, _Emitter);\n\n var _super = _createSuper(Clipboard);\n\n /**\n * @param {String|HTMLElement|HTMLCollection|NodeList} trigger\n * @param {Object} options\n */\n function Clipboard(trigger, options) {\n var _this;\n\n _classCallCheck(this, Clipboard);\n\n _this = _super.call(this);\n\n _this.resolveOptions(options);\n\n _this.listenClick(trigger);\n\n return _this;\n }\n /**\n * Defines if attributes would be resolved using internal setter functions\n * or custom functions that were passed in the constructor.\n * @param {Object} options\n */\n\n\n _createClass(Clipboard, [{\n key: \"resolveOptions\",\n value: function resolveOptions() {\n var options = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : {};\n this.action = typeof options.action === 'function' ? options.action : this.defaultAction;\n this.target = typeof options.target === 'function' ? options.target : this.defaultTarget;\n this.text = typeof options.text === 'function' ? options.text : this.defaultText;\n this.container = clipboard_typeof(options.container) === 'object' ? options.container : document.body;\n }\n /**\n * Adds a click event listener to the passed trigger.\n * @param {String|HTMLElement|HTMLCollection|NodeList} trigger\n */\n\n }, {\n key: \"listenClick\",\n value: function listenClick(trigger) {\n var _this2 = this;\n\n this.listener = listen_default()(trigger, 'click', function (e) {\n return _this2.onClick(e);\n });\n }\n /**\n * Defines a new `ClipboardAction` on each click event.\n * @param {Event} e\n */\n\n }, {\n key: \"onClick\",\n value: function onClick(e) {\n var trigger = e.delegateTarget || e.currentTarget;\n var action = this.action(trigger) || 'copy';\n var text = actions_default({\n action: action,\n container: this.container,\n target: this.target(trigger),\n text: this.text(trigger)\n }); // Fires an event based on the copy operation result.\n\n this.emit(text ? 
'success' : 'error', {\n action: action,\n text: text,\n trigger: trigger,\n clearSelection: function clearSelection() {\n if (trigger) {\n trigger.focus();\n }\n\n window.getSelection().removeAllRanges();\n }\n });\n }\n /**\n * Default `action` lookup function.\n * @param {Element} trigger\n */\n\n }, {\n key: \"defaultAction\",\n value: function defaultAction(trigger) {\n return getAttributeValue('action', trigger);\n }\n /**\n * Default `target` lookup function.\n * @param {Element} trigger\n */\n\n }, {\n key: \"defaultTarget\",\n value: function defaultTarget(trigger) {\n var selector = getAttributeValue('target', trigger);\n\n if (selector) {\n return document.querySelector(selector);\n }\n }\n /**\n * Allow fire programmatically a copy action\n * @param {String|HTMLElement} target\n * @param {Object} options\n * @returns Text copied.\n */\n\n }, {\n key: \"defaultText\",\n\n /**\n * Default `text` lookup function.\n * @param {Element} trigger\n */\n value: function defaultText(trigger) {\n return getAttributeValue('text', trigger);\n }\n /**\n * Destroy lifecycle.\n */\n\n }, {\n key: \"destroy\",\n value: function destroy() {\n this.listener.destroy();\n }\n }], [{\n key: \"copy\",\n value: function copy(target) {\n var options = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {\n container: document.body\n };\n return actions_copy(target, options);\n }\n /**\n * Allow fire programmatically a cut action\n * @param {String|HTMLElement} target\n * @returns Text cutted.\n */\n\n }, {\n key: \"cut\",\n value: function cut(target) {\n return actions_cut(target);\n }\n /**\n * Returns the support of the given action, or all actions if no action is\n * given.\n * @param {String} [action]\n */\n\n }, {\n key: \"isSupported\",\n value: function isSupported() {\n var action = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : ['copy', 'cut'];\n var actions = typeof action === 'string' ? 
[action] : action;\n var support = !!document.queryCommandSupported;\n actions.forEach(function (action) {\n support = support && !!document.queryCommandSupported(action);\n });\n return support;\n }\n }]);\n\n return Clipboard;\n}((tiny_emitter_default()));\n\n/* harmony default export */ var clipboard = (Clipboard);\n\n/***/ }),\n\n/***/ 828:\n/***/ (function(module) {\n\nvar DOCUMENT_NODE_TYPE = 9;\n\n/**\n * A polyfill for Element.matches()\n */\nif (typeof Element !== 'undefined' && !Element.prototype.matches) {\n var proto = Element.prototype;\n\n proto.matches = proto.matchesSelector ||\n proto.mozMatchesSelector ||\n proto.msMatchesSelector ||\n proto.oMatchesSelector ||\n proto.webkitMatchesSelector;\n}\n\n/**\n * Finds the closest parent that matches a selector.\n *\n * @param {Element} element\n * @param {String} selector\n * @return {Function}\n */\nfunction closest (element, selector) {\n while (element && element.nodeType !== DOCUMENT_NODE_TYPE) {\n if (typeof element.matches === 'function' &&\n element.matches(selector)) {\n return element;\n }\n element = element.parentNode;\n }\n}\n\nmodule.exports = closest;\n\n\n/***/ }),\n\n/***/ 438:\n/***/ (function(module, __unused_webpack_exports, __webpack_require__) {\n\nvar closest = __webpack_require__(828);\n\n/**\n * Delegates event to a selector.\n *\n * @param {Element} element\n * @param {String} selector\n * @param {String} type\n * @param {Function} callback\n * @param {Boolean} useCapture\n * @return {Object}\n */\nfunction _delegate(element, selector, type, callback, useCapture) {\n var listenerFn = listener.apply(this, arguments);\n\n element.addEventListener(type, listenerFn, useCapture);\n\n return {\n destroy: function() {\n element.removeEventListener(type, listenerFn, useCapture);\n }\n }\n}\n\n/**\n * Delegates event to a selector.\n *\n * @param {Element|String|Array} [elements]\n * @param {String} selector\n * @param {String} type\n * @param {Function} callback\n * @param {Boolean} useCapture\n * @return {Object}\n */\nfunction delegate(elements, selector, type, callback, useCapture) {\n // Handle the regular Element usage\n if (typeof elements.addEventListener === 'function') {\n return _delegate.apply(null, arguments);\n }\n\n // Handle Element-less usage, it defaults to global delegation\n if (typeof type === 'function') {\n // Use `document` as the first parameter, then apply arguments\n // This is a short way to .unshift `arguments` without running into deoptimizations\n return _delegate.bind(null, document).apply(null, arguments);\n }\n\n // Handle Selector-based usage\n if (typeof elements === 'string') {\n elements = document.querySelectorAll(elements);\n }\n\n // Handle Array-like based usage\n return Array.prototype.map.call(elements, function (element) {\n return _delegate(element, selector, type, callback, useCapture);\n });\n}\n\n/**\n * Finds closest match and invokes callback.\n *\n * @param {Element} element\n * @param {String} selector\n * @param {String} type\n * @param {Function} callback\n * @return {Function}\n */\nfunction listener(element, selector, type, callback) {\n return function(e) {\n e.delegateTarget = closest(e.target, selector);\n\n if (e.delegateTarget) {\n callback.call(element, e);\n }\n }\n}\n\nmodule.exports = delegate;\n\n\n/***/ }),\n\n/***/ 879:\n/***/ (function(__unused_webpack_module, exports) {\n\n/**\n * Check if argument is a HTML element.\n *\n * @param {Object} value\n * @return {Boolean}\n */\nexports.node = function(value) {\n return value !== undefined\n && 
value instanceof HTMLElement\n && value.nodeType === 1;\n};\n\n/**\n * Check if argument is a list of HTML elements.\n *\n * @param {Object} value\n * @return {Boolean}\n */\nexports.nodeList = function(value) {\n var type = Object.prototype.toString.call(value);\n\n return value !== undefined\n && (type === '[object NodeList]' || type === '[object HTMLCollection]')\n && ('length' in value)\n && (value.length === 0 || exports.node(value[0]));\n};\n\n/**\n * Check if argument is a string.\n *\n * @param {Object} value\n * @return {Boolean}\n */\nexports.string = function(value) {\n return typeof value === 'string'\n || value instanceof String;\n};\n\n/**\n * Check if argument is a function.\n *\n * @param {Object} value\n * @return {Boolean}\n */\nexports.fn = function(value) {\n var type = Object.prototype.toString.call(value);\n\n return type === '[object Function]';\n};\n\n\n/***/ }),\n\n/***/ 370:\n/***/ (function(module, __unused_webpack_exports, __webpack_require__) {\n\nvar is = __webpack_require__(879);\nvar delegate = __webpack_require__(438);\n\n/**\n * Validates all params and calls the right\n * listener function based on its target type.\n *\n * @param {String|HTMLElement|HTMLCollection|NodeList} target\n * @param {String} type\n * @param {Function} callback\n * @return {Object}\n */\nfunction listen(target, type, callback) {\n if (!target && !type && !callback) {\n throw new Error('Missing required arguments');\n }\n\n if (!is.string(type)) {\n throw new TypeError('Second argument must be a String');\n }\n\n if (!is.fn(callback)) {\n throw new TypeError('Third argument must be a Function');\n }\n\n if (is.node(target)) {\n return listenNode(target, type, callback);\n }\n else if (is.nodeList(target)) {\n return listenNodeList(target, type, callback);\n }\n else if (is.string(target)) {\n return listenSelector(target, type, callback);\n }\n else {\n throw new TypeError('First argument must be a String, HTMLElement, HTMLCollection, or NodeList');\n }\n}\n\n/**\n * Adds an event listener to a HTML element\n * and returns a remove listener function.\n *\n * @param {HTMLElement} node\n * @param {String} type\n * @param {Function} callback\n * @return {Object}\n */\nfunction listenNode(node, type, callback) {\n node.addEventListener(type, callback);\n\n return {\n destroy: function() {\n node.removeEventListener(type, callback);\n }\n }\n}\n\n/**\n * Add an event listener to a list of HTML elements\n * and returns a remove listener function.\n *\n * @param {NodeList|HTMLCollection} nodeList\n * @param {String} type\n * @param {Function} callback\n * @return {Object}\n */\nfunction listenNodeList(nodeList, type, callback) {\n Array.prototype.forEach.call(nodeList, function(node) {\n node.addEventListener(type, callback);\n });\n\n return {\n destroy: function() {\n Array.prototype.forEach.call(nodeList, function(node) {\n node.removeEventListener(type, callback);\n });\n }\n }\n}\n\n/**\n * Add an event listener to a selector\n * and returns a remove listener function.\n *\n * @param {String} selector\n * @param {String} type\n * @param {Function} callback\n * @return {Object}\n */\nfunction listenSelector(selector, type, callback) {\n return delegate(document.body, selector, type, callback);\n}\n\nmodule.exports = listen;\n\n\n/***/ }),\n\n/***/ 817:\n/***/ (function(module) {\n\nfunction select(element) {\n var selectedText;\n\n if (element.nodeName === 'SELECT') {\n element.focus();\n\n selectedText = element.value;\n }\n else if (element.nodeName === 'INPUT' || element.nodeName 
=== 'TEXTAREA') {\n var isReadOnly = element.hasAttribute('readonly');\n\n if (!isReadOnly) {\n element.setAttribute('readonly', '');\n }\n\n element.select();\n element.setSelectionRange(0, element.value.length);\n\n if (!isReadOnly) {\n element.removeAttribute('readonly');\n }\n\n selectedText = element.value;\n }\n else {\n if (element.hasAttribute('contenteditable')) {\n element.focus();\n }\n\n var selection = window.getSelection();\n var range = document.createRange();\n\n range.selectNodeContents(element);\n selection.removeAllRanges();\n selection.addRange(range);\n\n selectedText = selection.toString();\n }\n\n return selectedText;\n}\n\nmodule.exports = select;\n\n\n/***/ }),\n\n/***/ 279:\n/***/ (function(module) {\n\nfunction E () {\n // Keep this empty so it's easier to inherit from\n // (via https://github.com/lipsmack from https://github.com/scottcorgan/tiny-emitter/issues/3)\n}\n\nE.prototype = {\n on: function (name, callback, ctx) {\n var e = this.e || (this.e = {});\n\n (e[name] || (e[name] = [])).push({\n fn: callback,\n ctx: ctx\n });\n\n return this;\n },\n\n once: function (name, callback, ctx) {\n var self = this;\n function listener () {\n self.off(name, listener);\n callback.apply(ctx, arguments);\n };\n\n listener._ = callback\n return this.on(name, listener, ctx);\n },\n\n emit: function (name) {\n var data = [].slice.call(arguments, 1);\n var evtArr = ((this.e || (this.e = {}))[name] || []).slice();\n var i = 0;\n var len = evtArr.length;\n\n for (i; i < len; i++) {\n evtArr[i].fn.apply(evtArr[i].ctx, data);\n }\n\n return this;\n },\n\n off: function (name, callback) {\n var e = this.e || (this.e = {});\n var evts = e[name];\n var liveEvents = [];\n\n if (evts && callback) {\n for (var i = 0, len = evts.length; i < len; i++) {\n if (evts[i].fn !== callback && evts[i].fn._ !== callback)\n liveEvents.push(evts[i]);\n }\n }\n\n // Remove event from queue to prevent memory leak\n // Suggested by https://github.com/lazd\n // Ref: https://github.com/scottcorgan/tiny-emitter/commit/c6ebfaa9bc973b33d110a84a307742b7cf94c953#commitcomment-5024910\n\n (liveEvents.length)\n ? 
e[name] = liveEvents\n : delete e[name];\n\n return this;\n }\n};\n\nmodule.exports = E;\nmodule.exports.TinyEmitter = E;\n\n\n/***/ })\n\n/******/ \t});\n/************************************************************************/\n/******/ \t// The module cache\n/******/ \tvar __webpack_module_cache__ = {};\n/******/ \t\n/******/ \t// The require function\n/******/ \tfunction __webpack_require__(moduleId) {\n/******/ \t\t// Check if module is in cache\n/******/ \t\tif(__webpack_module_cache__[moduleId]) {\n/******/ \t\t\treturn __webpack_module_cache__[moduleId].exports;\n/******/ \t\t}\n/******/ \t\t// Create a new module (and put it into the cache)\n/******/ \t\tvar module = __webpack_module_cache__[moduleId] = {\n/******/ \t\t\t// no module.id needed\n/******/ \t\t\t// no module.loaded needed\n/******/ \t\t\texports: {}\n/******/ \t\t};\n/******/ \t\n/******/ \t\t// Execute the module function\n/******/ \t\t__webpack_modules__[moduleId](module, module.exports, __webpack_require__);\n/******/ \t\n/******/ \t\t// Return the exports of the module\n/******/ \t\treturn module.exports;\n/******/ \t}\n/******/ \t\n/************************************************************************/\n/******/ \t/* webpack/runtime/compat get default export */\n/******/ \t!function() {\n/******/ \t\t// getDefaultExport function for compatibility with non-harmony modules\n/******/ \t\t__webpack_require__.n = function(module) {\n/******/ \t\t\tvar getter = module && module.__esModule ?\n/******/ \t\t\t\tfunction() { return module['default']; } :\n/******/ \t\t\t\tfunction() { return module; };\n/******/ \t\t\t__webpack_require__.d(getter, { a: getter });\n/******/ \t\t\treturn getter;\n/******/ \t\t};\n/******/ \t}();\n/******/ \t\n/******/ \t/* webpack/runtime/define property getters */\n/******/ \t!function() {\n/******/ \t\t// define getter functions for harmony exports\n/******/ \t\t__webpack_require__.d = function(exports, definition) {\n/******/ \t\t\tfor(var key in definition) {\n/******/ \t\t\t\tif(__webpack_require__.o(definition, key) && !__webpack_require__.o(exports, key)) {\n/******/ \t\t\t\t\tObject.defineProperty(exports, key, { enumerable: true, get: definition[key] });\n/******/ \t\t\t\t}\n/******/ \t\t\t}\n/******/ \t\t};\n/******/ \t}();\n/******/ \t\n/******/ \t/* webpack/runtime/hasOwnProperty shorthand */\n/******/ \t!function() {\n/******/ \t\t__webpack_require__.o = function(obj, prop) { return Object.prototype.hasOwnProperty.call(obj, prop); }\n/******/ \t}();\n/******/ \t\n/************************************************************************/\n/******/ \t// module exports must be returned from runtime so entry inlining is disabled\n/******/ \t// startup\n/******/ \t// Load entry module and return exports\n/******/ \treturn __webpack_require__(686);\n/******/ })()\n.default;\n});", "/*\n * Copyright (c) 2016-2024 Martin Donath \n *\n * Permission is hereby granted, free of charge, to any person obtaining a copy\n * of this software and associated documentation files (the \"Software\"), to\n * deal in the Software without restriction, including without limitation the\n * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or\n * sell copies of the Software, and to permit persons to whom the Software is\n * furnished to do so, subject to the following conditions:\n *\n * The above copyright notice and this permission notice shall be included in\n * all copies or substantial portions of the Software.\n *\n * THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF 
ANY KIND, EXPRESS OR\n * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n * FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE\n * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS\n * IN THE SOFTWARE.\n */\n\nimport \"focus-visible\"\n\nimport {\n EMPTY,\n NEVER,\n Observable,\n Subject,\n defer,\n delay,\n filter,\n map,\n merge,\n mergeWith,\n shareReplay,\n switchMap\n} from \"rxjs\"\n\nimport { configuration, feature } from \"./_\"\nimport {\n at,\n getActiveElement,\n getOptionalElement,\n requestJSON,\n setLocation,\n setToggle,\n watchDocument,\n watchKeyboard,\n watchLocation,\n watchLocationTarget,\n watchMedia,\n watchPrint,\n watchScript,\n watchViewport\n} from \"./browser\"\nimport {\n getComponentElement,\n getComponentElements,\n mountAnnounce,\n mountBackToTop,\n mountConsent,\n mountContent,\n mountDialog,\n mountHeader,\n mountHeaderTitle,\n mountPalette,\n mountProgress,\n mountSearch,\n mountSearchHiglight,\n mountSidebar,\n mountSource,\n mountTableOfContents,\n mountTabs,\n watchHeader,\n watchMain\n} from \"./components\"\nimport {\n SearchIndex,\n setupClipboardJS,\n setupInstantNavigation,\n setupVersionSelector\n} from \"./integrations\"\nimport {\n patchEllipsis,\n patchIndeterminate,\n patchScrollfix,\n patchScrolllock\n} from \"./patches\"\nimport \"./polyfills\"\n\n/* ----------------------------------------------------------------------------\n * Functions - @todo refactor\n * ------------------------------------------------------------------------- */\n\n/**\n * Fetch search index\n *\n * @returns Search index observable\n */\nfunction fetchSearchIndex(): Observable {\n if (location.protocol === \"file:\") {\n return watchScript(\n `${new URL(\"search/search_index.js\", config.base)}`\n )\n .pipe(\n // @ts-ignore - @todo fix typings\n map(() => __index),\n shareReplay(1)\n )\n } else {\n return requestJSON(\n new URL(\"search/search_index.json\", config.base)\n )\n }\n}\n\n/* ----------------------------------------------------------------------------\n * Application\n * ------------------------------------------------------------------------- */\n\n/* Yay, JavaScript is available */\ndocument.documentElement.classList.remove(\"no-js\")\ndocument.documentElement.classList.add(\"js\")\n\n/* Set up navigation observables and subjects */\nconst document$ = watchDocument()\nconst location$ = watchLocation()\nconst target$ = watchLocationTarget(location$)\nconst keyboard$ = watchKeyboard()\n\n/* Set up media observables */\nconst viewport$ = watchViewport()\nconst tablet$ = watchMedia(\"(min-width: 960px)\")\nconst screen$ = watchMedia(\"(min-width: 1220px)\")\nconst print$ = watchPrint()\n\n/* Retrieve search index, if search is enabled */\nconst config = configuration()\nconst index$ = document.forms.namedItem(\"search\")\n ? 
fetchSearchIndex()\n : NEVER\n\n/* Set up Clipboard.js integration */\nconst alert$ = new Subject()\nsetupClipboardJS({ alert$ })\n\n/* Set up progress indicator */\nconst progress$ = new Subject()\n\n/* Set up instant navigation, if enabled */\nif (feature(\"navigation.instant\"))\n setupInstantNavigation({ location$, viewport$, progress$ })\n .subscribe(document$)\n\n/* Set up version selector */\nif (config.version?.provider === \"mike\")\n setupVersionSelector({ document$ })\n\n/* Always close drawer and search on navigation */\nmerge(location$, target$)\n .pipe(\n delay(125)\n )\n .subscribe(() => {\n setToggle(\"drawer\", false)\n setToggle(\"search\", false)\n })\n\n/* Set up global keyboard handlers */\nkeyboard$\n .pipe(\n filter(({ mode }) => mode === \"global\")\n )\n .subscribe(key => {\n switch (key.type) {\n\n /* Go to previous page */\n case \"p\":\n case \",\":\n const prev = getOptionalElement(\"link[rel=prev]\")\n if (typeof prev !== \"undefined\")\n setLocation(prev)\n break\n\n /* Go to next page */\n case \"n\":\n case \".\":\n const next = getOptionalElement(\"link[rel=next]\")\n if (typeof next !== \"undefined\")\n setLocation(next)\n break\n\n /* Expand navigation, see https://bit.ly/3ZjG5io */\n case \"Enter\":\n const active = getActiveElement()\n if (active instanceof HTMLLabelElement)\n active.click()\n }\n })\n\n/* Set up patches */\npatchEllipsis({ viewport$, document$ })\npatchIndeterminate({ document$, tablet$ })\npatchScrollfix({ document$ })\npatchScrolllock({ viewport$, tablet$ })\n\n/* Set up header and main area observable */\nconst header$ = watchHeader(getComponentElement(\"header\"), { viewport$ })\nconst main$ = document$\n .pipe(\n map(() => getComponentElement(\"main\")),\n switchMap(el => watchMain(el, { viewport$, header$ })),\n shareReplay(1)\n )\n\n/* Set up control component observables */\nconst control$ = merge(\n\n /* Consent */\n ...getComponentElements(\"consent\")\n .map(el => mountConsent(el, { target$ })),\n\n /* Dialog */\n ...getComponentElements(\"dialog\")\n .map(el => mountDialog(el, { alert$ })),\n\n /* Color palette */\n ...getComponentElements(\"palette\")\n .map(el => mountPalette(el)),\n\n /* Progress bar */\n ...getComponentElements(\"progress\")\n .map(el => mountProgress(el, { progress$ })),\n\n /* Search */\n ...getComponentElements(\"search\")\n .map(el => mountSearch(el, { index$, keyboard$ })),\n\n /* Repository information */\n ...getComponentElements(\"source\")\n .map(el => mountSource(el))\n)\n\n/* Set up content component observables */\nconst content$ = defer(() => merge(\n\n /* Announcement bar */\n ...getComponentElements(\"announce\")\n .map(el => mountAnnounce(el)),\n\n /* Content */\n ...getComponentElements(\"content\")\n .map(el => mountContent(el, { viewport$, target$, print$ })),\n\n /* Search highlighting */\n ...getComponentElements(\"content\")\n .map(el => feature(\"search.highlight\")\n ? mountSearchHiglight(el, { index$, location$ })\n : EMPTY\n ),\n\n /* Header */\n ...getComponentElements(\"header\")\n .map(el => mountHeader(el, { viewport$, header$, main$ })),\n\n /* Header title */\n ...getComponentElements(\"header-title\")\n .map(el => mountHeaderTitle(el, { viewport$, header$ })),\n\n /* Sidebar */\n ...getComponentElements(\"sidebar\")\n .map(el => el.getAttribute(\"data-md-type\") === \"navigation\"\n ? 
at(screen$, () => mountSidebar(el, { viewport$, header$, main$ }))\n : at(tablet$, () => mountSidebar(el, { viewport$, header$, main$ }))\n ),\n\n /* Navigation tabs */\n ...getComponentElements(\"tabs\")\n .map(el => mountTabs(el, { viewport$, header$ })),\n\n /* Table of contents */\n ...getComponentElements(\"toc\")\n .map(el => mountTableOfContents(el, {\n viewport$, header$, main$, target$\n })),\n\n /* Back-to-top button */\n ...getComponentElements(\"top\")\n .map(el => mountBackToTop(el, { viewport$, header$, main$, target$ }))\n))\n\n/* Set up component observables */\nconst component$ = document$\n .pipe(\n switchMap(() => content$),\n mergeWith(control$),\n shareReplay(1)\n )\n\n/* Subscribe to all components */\ncomponent$.subscribe()\n\n/* ----------------------------------------------------------------------------\n * Exports\n * ------------------------------------------------------------------------- */\n\nwindow.document$ = document$ /* Document observable */\nwindow.location$ = location$ /* Location subject */\nwindow.target$ = target$ /* Location target observable */\nwindow.keyboard$ = keyboard$ /* Keyboard observable */\nwindow.viewport$ = viewport$ /* Viewport observable */\nwindow.tablet$ = tablet$ /* Media tablet observable */\nwindow.screen$ = screen$ /* Media screen observable */\nwindow.print$ = print$ /* Media print observable */\nwindow.alert$ = alert$ /* Alert subject */\nwindow.progress$ = progress$ /* Progress indicator subject */\nwindow.component$ = component$ /* Component observable */\n", "/******************************************************************************\nCopyright (c) Microsoft Corporation.\n\nPermission to use, copy, modify, and/or distribute this software for any\npurpose with or without fee is hereby granted.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH\nREGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY\nAND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,\nINDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM\nLOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR\nOTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR\nPERFORMANCE OF THIS SOFTWARE.\n***************************************************************************** */\n/* global Reflect, Promise, SuppressedError, Symbol, Iterator */\n\nvar extendStatics = function(d, b) {\n extendStatics = Object.setPrototypeOf ||\n ({ __proto__: [] } instanceof Array && function (d, b) { d.__proto__ = b; }) ||\n function (d, b) { for (var p in b) if (Object.prototype.hasOwnProperty.call(b, p)) d[p] = b[p]; };\n return extendStatics(d, b);\n};\n\nexport function __extends(d, b) {\n if (typeof b !== \"function\" && b !== null)\n throw new TypeError(\"Class extends value \" + String(b) + \" is not a constructor or null\");\n extendStatics(d, b);\n function __() { this.constructor = d; }\n d.prototype = b === null ? 
Object.create(b) : (__.prototype = b.prototype, new __());\n}\n\nexport var __assign = function() {\n __assign = Object.assign || function __assign(t) {\n for (var s, i = 1, n = arguments.length; i < n; i++) {\n s = arguments[i];\n for (var p in s) if (Object.prototype.hasOwnProperty.call(s, p)) t[p] = s[p];\n }\n return t;\n }\n return __assign.apply(this, arguments);\n}\n\nexport function __rest(s, e) {\n var t = {};\n for (var p in s) if (Object.prototype.hasOwnProperty.call(s, p) && e.indexOf(p) < 0)\n t[p] = s[p];\n if (s != null && typeof Object.getOwnPropertySymbols === \"function\")\n for (var i = 0, p = Object.getOwnPropertySymbols(s); i < p.length; i++) {\n if (e.indexOf(p[i]) < 0 && Object.prototype.propertyIsEnumerable.call(s, p[i]))\n t[p[i]] = s[p[i]];\n }\n return t;\n}\n\nexport function __decorate(decorators, target, key, desc) {\n var c = arguments.length, r = c < 3 ? target : desc === null ? desc = Object.getOwnPropertyDescriptor(target, key) : desc, d;\n if (typeof Reflect === \"object\" && typeof Reflect.decorate === \"function\") r = Reflect.decorate(decorators, target, key, desc);\n else for (var i = decorators.length - 1; i >= 0; i--) if (d = decorators[i]) r = (c < 3 ? d(r) : c > 3 ? d(target, key, r) : d(target, key)) || r;\n return c > 3 && r && Object.defineProperty(target, key, r), r;\n}\n\nexport function __param(paramIndex, decorator) {\n return function (target, key) { decorator(target, key, paramIndex); }\n}\n\nexport function __esDecorate(ctor, descriptorIn, decorators, contextIn, initializers, extraInitializers) {\n function accept(f) { if (f !== void 0 && typeof f !== \"function\") throw new TypeError(\"Function expected\"); return f; }\n var kind = contextIn.kind, key = kind === \"getter\" ? \"get\" : kind === \"setter\" ? \"set\" : \"value\";\n var target = !descriptorIn && ctor ? contextIn[\"static\"] ? ctor : ctor.prototype : null;\n var descriptor = descriptorIn || (target ? Object.getOwnPropertyDescriptor(target, contextIn.name) : {});\n var _, done = false;\n for (var i = decorators.length - 1; i >= 0; i--) {\n var context = {};\n for (var p in contextIn) context[p] = p === \"access\" ? {} : contextIn[p];\n for (var p in contextIn.access) context.access[p] = contextIn.access[p];\n context.addInitializer = function (f) { if (done) throw new TypeError(\"Cannot add initializers after decoration has completed\"); extraInitializers.push(accept(f || null)); };\n var result = (0, decorators[i])(kind === \"accessor\" ? { get: descriptor.get, set: descriptor.set } : descriptor[key], context);\n if (kind === \"accessor\") {\n if (result === void 0) continue;\n if (result === null || typeof result !== \"object\") throw new TypeError(\"Object expected\");\n if (_ = accept(result.get)) descriptor.get = _;\n if (_ = accept(result.set)) descriptor.set = _;\n if (_ = accept(result.init)) initializers.unshift(_);\n }\n else if (_ = accept(result)) {\n if (kind === \"field\") initializers.unshift(_);\n else descriptor[key] = _;\n }\n }\n if (target) Object.defineProperty(target, contextIn.name, descriptor);\n done = true;\n};\n\nexport function __runInitializers(thisArg, initializers, value) {\n var useValue = arguments.length > 2;\n for (var i = 0; i < initializers.length; i++) {\n value = useValue ? initializers[i].call(thisArg, value) : initializers[i].call(thisArg);\n }\n return useValue ? value : void 0;\n};\n\nexport function __propKey(x) {\n return typeof x === \"symbol\" ? 
x : \"\".concat(x);\n};\n\nexport function __setFunctionName(f, name, prefix) {\n if (typeof name === \"symbol\") name = name.description ? \"[\".concat(name.description, \"]\") : \"\";\n return Object.defineProperty(f, \"name\", { configurable: true, value: prefix ? \"\".concat(prefix, \" \", name) : name });\n};\n\nexport function __metadata(metadataKey, metadataValue) {\n if (typeof Reflect === \"object\" && typeof Reflect.metadata === \"function\") return Reflect.metadata(metadataKey, metadataValue);\n}\n\nexport function __awaiter(thisArg, _arguments, P, generator) {\n function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); }\n return new (P || (P = Promise))(function (resolve, reject) {\n function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } }\n function rejected(value) { try { step(generator[\"throw\"](value)); } catch (e) { reject(e); } }\n function step(result) { result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected); }\n step((generator = generator.apply(thisArg, _arguments || [])).next());\n });\n}\n\nexport function __generator(thisArg, body) {\n var _ = { label: 0, sent: function() { if (t[0] & 1) throw t[1]; return t[1]; }, trys: [], ops: [] }, f, y, t, g = Object.create((typeof Iterator === \"function\" ? Iterator : Object).prototype);\n return g.next = verb(0), g[\"throw\"] = verb(1), g[\"return\"] = verb(2), typeof Symbol === \"function\" && (g[Symbol.iterator] = function() { return this; }), g;\n function verb(n) { return function (v) { return step([n, v]); }; }\n function step(op) {\n if (f) throw new TypeError(\"Generator is already executing.\");\n while (g && (g = 0, op[0] && (_ = 0)), _) try {\n if (f = 1, y && (t = op[0] & 2 ? y[\"return\"] : op[0] ? y[\"throw\"] || ((t = y[\"return\"]) && t.call(y), 0) : y.next) && !(t = t.call(y, op[1])).done) return t;\n if (y = 0, t) op = [op[0] & 2, t.value];\n switch (op[0]) {\n case 0: case 1: t = op; break;\n case 4: _.label++; return { value: op[1], done: false };\n case 5: _.label++; y = op[1]; op = [0]; continue;\n case 7: op = _.ops.pop(); _.trys.pop(); continue;\n default:\n if (!(t = _.trys, t = t.length > 0 && t[t.length - 1]) && (op[0] === 6 || op[0] === 2)) { _ = 0; continue; }\n if (op[0] === 3 && (!t || (op[1] > t[0] && op[1] < t[3]))) { _.label = op[1]; break; }\n if (op[0] === 6 && _.label < t[1]) { _.label = t[1]; t = op; break; }\n if (t && _.label < t[2]) { _.label = t[2]; _.ops.push(op); break; }\n if (t[2]) _.ops.pop();\n _.trys.pop(); continue;\n }\n op = body.call(thisArg, _);\n } catch (e) { op = [6, e]; y = 0; } finally { f = t = 0; }\n if (op[0] & 5) throw op[1]; return { value: op[0] ? op[1] : void 0, done: true };\n }\n}\n\nexport var __createBinding = Object.create ? (function(o, m, k, k2) {\n if (k2 === undefined) k2 = k;\n var desc = Object.getOwnPropertyDescriptor(m, k);\n if (!desc || (\"get\" in desc ? 
!m.__esModule : desc.writable || desc.configurable)) {\n desc = { enumerable: true, get: function() { return m[k]; } };\n }\n Object.defineProperty(o, k2, desc);\n}) : (function(o, m, k, k2) {\n if (k2 === undefined) k2 = k;\n o[k2] = m[k];\n});\n\nexport function __exportStar(m, o) {\n for (var p in m) if (p !== \"default\" && !Object.prototype.hasOwnProperty.call(o, p)) __createBinding(o, m, p);\n}\n\nexport function __values(o) {\n var s = typeof Symbol === \"function\" && Symbol.iterator, m = s && o[s], i = 0;\n if (m) return m.call(o);\n if (o && typeof o.length === \"number\") return {\n next: function () {\n if (o && i >= o.length) o = void 0;\n return { value: o && o[i++], done: !o };\n }\n };\n throw new TypeError(s ? \"Object is not iterable.\" : \"Symbol.iterator is not defined.\");\n}\n\nexport function __read(o, n) {\n var m = typeof Symbol === \"function\" && o[Symbol.iterator];\n if (!m) return o;\n var i = m.call(o), r, ar = [], e;\n try {\n while ((n === void 0 || n-- > 0) && !(r = i.next()).done) ar.push(r.value);\n }\n catch (error) { e = { error: error }; }\n finally {\n try {\n if (r && !r.done && (m = i[\"return\"])) m.call(i);\n }\n finally { if (e) throw e.error; }\n }\n return ar;\n}\n\n/** @deprecated */\nexport function __spread() {\n for (var ar = [], i = 0; i < arguments.length; i++)\n ar = ar.concat(__read(arguments[i]));\n return ar;\n}\n\n/** @deprecated */\nexport function __spreadArrays() {\n for (var s = 0, i = 0, il = arguments.length; i < il; i++) s += arguments[i].length;\n for (var r = Array(s), k = 0, i = 0; i < il; i++)\n for (var a = arguments[i], j = 0, jl = a.length; j < jl; j++, k++)\n r[k] = a[j];\n return r;\n}\n\nexport function __spreadArray(to, from, pack) {\n if (pack || arguments.length === 2) for (var i = 0, l = from.length, ar; i < l; i++) {\n if (ar || !(i in from)) {\n if (!ar) ar = Array.prototype.slice.call(from, 0, i);\n ar[i] = from[i];\n }\n }\n return to.concat(ar || Array.prototype.slice.call(from));\n}\n\nexport function __await(v) {\n return this instanceof __await ? (this.v = v, this) : new __await(v);\n}\n\nexport function __asyncGenerator(thisArg, _arguments, generator) {\n if (!Symbol.asyncIterator) throw new TypeError(\"Symbol.asyncIterator is not defined.\");\n var g = generator.apply(thisArg, _arguments || []), i, q = [];\n return i = Object.create((typeof AsyncIterator === \"function\" ? AsyncIterator : Object).prototype), verb(\"next\"), verb(\"throw\"), verb(\"return\", awaitReturn), i[Symbol.asyncIterator] = function () { return this; }, i;\n function awaitReturn(f) { return function (v) { return Promise.resolve(v).then(f, reject); }; }\n function verb(n, f) { if (g[n]) { i[n] = function (v) { return new Promise(function (a, b) { q.push([n, v, a, b]) > 1 || resume(n, v); }); }; if (f) i[n] = f(i[n]); } }\n function resume(n, v) { try { step(g[n](v)); } catch (e) { settle(q[0][3], e); } }\n function step(r) { r.value instanceof __await ? Promise.resolve(r.value.v).then(fulfill, reject) : settle(q[0][2], r); }\n function fulfill(value) { resume(\"next\", value); }\n function reject(value) { resume(\"throw\", value); }\n function settle(f, v) { if (f(v), q.shift(), q.length) resume(q[0][0], q[0][1]); }\n}\n\nexport function __asyncDelegator(o) {\n var i, p;\n return i = {}, verb(\"next\"), verb(\"throw\", function (e) { throw e; }), verb(\"return\"), i[Symbol.iterator] = function () { return this; }, i;\n function verb(n, f) { i[n] = o[n] ? function (v) { return (p = !p) ? 
{ value: __await(o[n](v)), done: false } : f ? f(v) : v; } : f; }\n}\n\nexport function __asyncValues(o) {\n if (!Symbol.asyncIterator) throw new TypeError(\"Symbol.asyncIterator is not defined.\");\n var m = o[Symbol.asyncIterator], i;\n return m ? m.call(o) : (o = typeof __values === \"function\" ? __values(o) : o[Symbol.iterator](), i = {}, verb(\"next\"), verb(\"throw\"), verb(\"return\"), i[Symbol.asyncIterator] = function () { return this; }, i);\n function verb(n) { i[n] = o[n] && function (v) { return new Promise(function (resolve, reject) { v = o[n](v), settle(resolve, reject, v.done, v.value); }); }; }\n function settle(resolve, reject, d, v) { Promise.resolve(v).then(function(v) { resolve({ value: v, done: d }); }, reject); }\n}\n\nexport function __makeTemplateObject(cooked, raw) {\n if (Object.defineProperty) { Object.defineProperty(cooked, \"raw\", { value: raw }); } else { cooked.raw = raw; }\n return cooked;\n};\n\nvar __setModuleDefault = Object.create ? (function(o, v) {\n Object.defineProperty(o, \"default\", { enumerable: true, value: v });\n}) : function(o, v) {\n o[\"default\"] = v;\n};\n\nexport function __importStar(mod) {\n if (mod && mod.__esModule) return mod;\n var result = {};\n if (mod != null) for (var k in mod) if (k !== \"default\" && Object.prototype.hasOwnProperty.call(mod, k)) __createBinding(result, mod, k);\n __setModuleDefault(result, mod);\n return result;\n}\n\nexport function __importDefault(mod) {\n return (mod && mod.__esModule) ? mod : { default: mod };\n}\n\nexport function __classPrivateFieldGet(receiver, state, kind, f) {\n if (kind === \"a\" && !f) throw new TypeError(\"Private accessor was defined without a getter\");\n if (typeof state === \"function\" ? receiver !== state || !f : !state.has(receiver)) throw new TypeError(\"Cannot read private member from an object whose class did not declare it\");\n return kind === \"m\" ? f : kind === \"a\" ? f.call(receiver) : f ? f.value : state.get(receiver);\n}\n\nexport function __classPrivateFieldSet(receiver, state, value, kind, f) {\n if (kind === \"m\") throw new TypeError(\"Private method is not writable\");\n if (kind === \"a\" && !f) throw new TypeError(\"Private accessor was defined without a setter\");\n if (typeof state === \"function\" ? receiver !== state || !f : !state.has(receiver)) throw new TypeError(\"Cannot write private member to an object whose class did not declare it\");\n return (kind === \"a\" ? f.call(receiver, value) : f ? f.value = value : state.set(receiver, value)), value;\n}\n\nexport function __classPrivateFieldIn(state, receiver) {\n if (receiver === null || (typeof receiver !== \"object\" && typeof receiver !== \"function\")) throw new TypeError(\"Cannot use 'in' operator on non-object\");\n return typeof state === \"function\" ? 
receiver === state : state.has(receiver);\n}\n\nexport function __addDisposableResource(env, value, async) {\n if (value !== null && value !== void 0) {\n if (typeof value !== \"object\" && typeof value !== \"function\") throw new TypeError(\"Object expected.\");\n var dispose, inner;\n if (async) {\n if (!Symbol.asyncDispose) throw new TypeError(\"Symbol.asyncDispose is not defined.\");\n dispose = value[Symbol.asyncDispose];\n }\n if (dispose === void 0) {\n if (!Symbol.dispose) throw new TypeError(\"Symbol.dispose is not defined.\");\n dispose = value[Symbol.dispose];\n if (async) inner = dispose;\n }\n if (typeof dispose !== \"function\") throw new TypeError(\"Object not disposable.\");\n if (inner) dispose = function() { try { inner.call(this); } catch (e) { return Promise.reject(e); } };\n env.stack.push({ value: value, dispose: dispose, async: async });\n }\n else if (async) {\n env.stack.push({ async: true });\n }\n return value;\n}\n\nvar _SuppressedError = typeof SuppressedError === \"function\" ? SuppressedError : function (error, suppressed, message) {\n var e = new Error(message);\n return e.name = \"SuppressedError\", e.error = error, e.suppressed = suppressed, e;\n};\n\nexport function __disposeResources(env) {\n function fail(e) {\n env.error = env.hasError ? new _SuppressedError(e, env.error, \"An error was suppressed during disposal.\") : e;\n env.hasError = true;\n }\n var r, s = 0;\n function next() {\n while (r = env.stack.pop()) {\n try {\n if (!r.async && s === 1) return s = 0, env.stack.push(r), Promise.resolve().then(next);\n if (r.dispose) {\n var result = r.dispose.call(r.value);\n if (r.async) return s |= 2, Promise.resolve(result).then(next, function(e) { fail(e); return next(); });\n }\n else s |= 1;\n }\n catch (e) {\n fail(e);\n }\n }\n if (s === 1) return env.hasError ? Promise.reject(env.error) : Promise.resolve();\n if (env.hasError) throw env.error;\n }\n return next();\n}\n\nexport default {\n __extends,\n __assign,\n __rest,\n __decorate,\n __param,\n __metadata,\n __awaiter,\n __generator,\n __createBinding,\n __exportStar,\n __values,\n __read,\n __spread,\n __spreadArrays,\n __spreadArray,\n __await,\n __asyncGenerator,\n __asyncDelegator,\n __asyncValues,\n __makeTemplateObject,\n __importStar,\n __importDefault,\n __classPrivateFieldGet,\n __classPrivateFieldSet,\n __classPrivateFieldIn,\n __addDisposableResource,\n __disposeResources,\n};\n", "/**\n * Returns true if the object is a function.\n * @param value The value to check\n */\nexport function isFunction(value: any): value is (...args: any[]) => any {\n return typeof value === 'function';\n}\n", "/**\n * Used to create Error subclasses until the community moves away from ES5.\n *\n * This is because compiling from TypeScript down to ES5 has issues with subclassing Errors\n * as well as other built-in types: https://github.com/Microsoft/TypeScript/issues/12123\n *\n * @param createImpl A factory function to create the actual constructor implementation. 
The returned\n * function should be a named function that calls `_super` internally.\n */\nexport function createErrorClass(createImpl: (_super: any) => any): T {\n const _super = (instance: any) => {\n Error.call(instance);\n instance.stack = new Error().stack;\n };\n\n const ctorFunc = createImpl(_super);\n ctorFunc.prototype = Object.create(Error.prototype);\n ctorFunc.prototype.constructor = ctorFunc;\n return ctorFunc;\n}\n", "import { createErrorClass } from './createErrorClass';\n\nexport interface UnsubscriptionError extends Error {\n readonly errors: any[];\n}\n\nexport interface UnsubscriptionErrorCtor {\n /**\n * @deprecated Internal implementation detail. Do not construct error instances.\n * Cannot be tagged as internal: https://github.com/ReactiveX/rxjs/issues/6269\n */\n new (errors: any[]): UnsubscriptionError;\n}\n\n/**\n * An error thrown when one or more errors have occurred during the\n * `unsubscribe` of a {@link Subscription}.\n */\nexport const UnsubscriptionError: UnsubscriptionErrorCtor = createErrorClass(\n (_super) =>\n function UnsubscriptionErrorImpl(this: any, errors: (Error | string)[]) {\n _super(this);\n this.message = errors\n ? `${errors.length} errors occurred during unsubscription:\n${errors.map((err, i) => `${i + 1}) ${err.toString()}`).join('\\n ')}`\n : '';\n this.name = 'UnsubscriptionError';\n this.errors = errors;\n }\n);\n", "/**\n * Removes an item from an array, mutating it.\n * @param arr The array to remove the item from\n * @param item The item to remove\n */\nexport function arrRemove(arr: T[] | undefined | null, item: T) {\n if (arr) {\n const index = arr.indexOf(item);\n 0 <= index && arr.splice(index, 1);\n }\n}\n", "import { isFunction } from './util/isFunction';\nimport { UnsubscriptionError } from './util/UnsubscriptionError';\nimport { SubscriptionLike, TeardownLogic, Unsubscribable } from './types';\nimport { arrRemove } from './util/arrRemove';\n\n/**\n * Represents a disposable resource, such as the execution of an Observable. A\n * Subscription has one important method, `unsubscribe`, that takes no argument\n * and just disposes the resource held by the subscription.\n *\n * Additionally, subscriptions may be grouped together through the `add()`\n * method, which will attach a child Subscription to the current Subscription.\n * When a Subscription is unsubscribed, all its children (and its grandchildren)\n * will be unsubscribed as well.\n *\n * @class Subscription\n */\nexport class Subscription implements SubscriptionLike {\n /** @nocollapse */\n public static EMPTY = (() => {\n const empty = new Subscription();\n empty.closed = true;\n return empty;\n })();\n\n /**\n * A flag to indicate whether this Subscription has already been unsubscribed.\n */\n public closed = false;\n\n private _parentage: Subscription[] | Subscription | null = null;\n\n /**\n * The list of registered finalizers to execute upon unsubscription. Adding and removing from this\n * list occurs in the {@link #add} and {@link #remove} methods.\n */\n private _finalizers: Exclude[] | null = null;\n\n /**\n * @param initialTeardown A function executed first as part of the finalization\n * process that is kicked off when {@link #unsubscribe} is called.\n */\n constructor(private initialTeardown?: () => void) {}\n\n /**\n * Disposes the resources held by the subscription. 
May, for instance, cancel\n * an ongoing Observable execution or cancel any other type of work that\n * started when the Subscription was created.\n * @return {void}\n */\n unsubscribe(): void {\n let errors: any[] | undefined;\n\n if (!this.closed) {\n this.closed = true;\n\n // Remove this from it's parents.\n const { _parentage } = this;\n if (_parentage) {\n this._parentage = null;\n if (Array.isArray(_parentage)) {\n for (const parent of _parentage) {\n parent.remove(this);\n }\n } else {\n _parentage.remove(this);\n }\n }\n\n const { initialTeardown: initialFinalizer } = this;\n if (isFunction(initialFinalizer)) {\n try {\n initialFinalizer();\n } catch (e) {\n errors = e instanceof UnsubscriptionError ? e.errors : [e];\n }\n }\n\n const { _finalizers } = this;\n if (_finalizers) {\n this._finalizers = null;\n for (const finalizer of _finalizers) {\n try {\n execFinalizer(finalizer);\n } catch (err) {\n errors = errors ?? [];\n if (err instanceof UnsubscriptionError) {\n errors = [...errors, ...err.errors];\n } else {\n errors.push(err);\n }\n }\n }\n }\n\n if (errors) {\n throw new UnsubscriptionError(errors);\n }\n }\n }\n\n /**\n * Adds a finalizer to this subscription, so that finalization will be unsubscribed/called\n * when this subscription is unsubscribed. If this subscription is already {@link #closed},\n * because it has already been unsubscribed, then whatever finalizer is passed to it\n * will automatically be executed (unless the finalizer itself is also a closed subscription).\n *\n * Closed Subscriptions cannot be added as finalizers to any subscription. Adding a closed\n * subscription to a any subscription will result in no operation. (A noop).\n *\n * Adding a subscription to itself, or adding `null` or `undefined` will not perform any\n * operation at all. (A noop).\n *\n * `Subscription` instances that are added to this instance will automatically remove themselves\n * if they are unsubscribed. Functions and {@link Unsubscribable} objects that you wish to remove\n * will need to be removed manually with {@link #remove}\n *\n * @param teardown The finalization logic to add to this subscription.\n */\n add(teardown: TeardownLogic): void {\n // Only add the finalizer if it's not undefined\n // and don't add a subscription to itself.\n if (teardown && teardown !== this) {\n if (this.closed) {\n // If this subscription is already closed,\n // execute whatever finalizer is handed to it automatically.\n execFinalizer(teardown);\n } else {\n if (teardown instanceof Subscription) {\n // We don't add closed subscriptions, and we don't add the same subscription\n // twice. Subscription unsubscribe is idempotent.\n if (teardown.closed || teardown._hasParent(this)) {\n return;\n }\n teardown._addParent(this);\n }\n (this._finalizers = this._finalizers ?? 
[]).push(teardown);\n }\n }\n }\n\n /**\n * Checks to see if a this subscription already has a particular parent.\n * This will signal that this subscription has already been added to the parent in question.\n * @param parent the parent to check for\n */\n private _hasParent(parent: Subscription) {\n const { _parentage } = this;\n return _parentage === parent || (Array.isArray(_parentage) && _parentage.includes(parent));\n }\n\n /**\n * Adds a parent to this subscription so it can be removed from the parent if it\n * unsubscribes on it's own.\n *\n * NOTE: THIS ASSUMES THAT {@link _hasParent} HAS ALREADY BEEN CHECKED.\n * @param parent The parent subscription to add\n */\n private _addParent(parent: Subscription) {\n const { _parentage } = this;\n this._parentage = Array.isArray(_parentage) ? (_parentage.push(parent), _parentage) : _parentage ? [_parentage, parent] : parent;\n }\n\n /**\n * Called on a child when it is removed via {@link #remove}.\n * @param parent The parent to remove\n */\n private _removeParent(parent: Subscription) {\n const { _parentage } = this;\n if (_parentage === parent) {\n this._parentage = null;\n } else if (Array.isArray(_parentage)) {\n arrRemove(_parentage, parent);\n }\n }\n\n /**\n * Removes a finalizer from this subscription that was previously added with the {@link #add} method.\n *\n * Note that `Subscription` instances, when unsubscribed, will automatically remove themselves\n * from every other `Subscription` they have been added to. This means that using the `remove` method\n * is not a common thing and should be used thoughtfully.\n *\n * If you add the same finalizer instance of a function or an unsubscribable object to a `Subscription` instance\n * more than once, you will need to call `remove` the same number of times to remove all instances.\n *\n * All finalizer instances are removed to free up memory upon unsubscription.\n *\n * @param teardown The finalizer to remove from this subscription\n */\n remove(teardown: Exclude): void {\n const { _finalizers } = this;\n _finalizers && arrRemove(_finalizers, teardown);\n\n if (teardown instanceof Subscription) {\n teardown._removeParent(this);\n }\n }\n}\n\nexport const EMPTY_SUBSCRIPTION = Subscription.EMPTY;\n\nexport function isSubscription(value: any): value is Subscription {\n return (\n value instanceof Subscription ||\n (value && 'closed' in value && isFunction(value.remove) && isFunction(value.add) && isFunction(value.unsubscribe))\n );\n}\n\nfunction execFinalizer(finalizer: Unsubscribable | (() => void)) {\n if (isFunction(finalizer)) {\n finalizer();\n } else {\n finalizer.unsubscribe();\n }\n}\n", "import { Subscriber } from './Subscriber';\nimport { ObservableNotification } from './types';\n\n/**\n * The {@link GlobalConfig} object for RxJS. It is used to configure things\n * like how to react on unhandled errors.\n */\nexport const config: GlobalConfig = {\n onUnhandledError: null,\n onStoppedNotification: null,\n Promise: undefined,\n useDeprecatedSynchronousErrorHandling: false,\n useDeprecatedNextContext: false,\n};\n\n/**\n * The global configuration object for RxJS, used to configure things\n * like how to react on unhandled errors. Accessible via {@link config}\n * object.\n */\nexport interface GlobalConfig {\n /**\n * A registration point for unhandled errors from RxJS. These are errors that\n * cannot were not handled by consuming code in the usual subscription path. 
For\n * example, if you have this configured, and you subscribe to an observable without\n * providing an error handler, errors from that subscription will end up here. This\n * will _always_ be called asynchronously on another job in the runtime. This is because\n * we do not want errors thrown in this user-configured handler to interfere with the\n * behavior of the library.\n */\n onUnhandledError: ((err: any) => void) | null;\n\n /**\n * A registration point for notifications that cannot be sent to subscribers because they\n * have completed, errored or have been explicitly unsubscribed. By default, next, complete\n * and error notifications sent to stopped subscribers are noops. However, sometimes callers\n * might want a different behavior. For example, with sources that attempt to report errors\n * to stopped subscribers, a caller can configure RxJS to throw an unhandled error instead.\n * This will _always_ be called asynchronously on another job in the runtime. This is because\n * we do not want errors thrown in this user-configured handler to interfere with the\n * behavior of the library.\n */\n onStoppedNotification: ((notification: ObservableNotification, subscriber: Subscriber) => void) | null;\n\n /**\n * The promise constructor used by default for {@link Observable#toPromise toPromise} and {@link Observable#forEach forEach}\n * methods.\n *\n * @deprecated As of version 8, RxJS will no longer support this sort of injection of a\n * Promise constructor. If you need a Promise implementation other than native promises,\n * please polyfill/patch Promise as you see appropriate. Will be removed in v8.\n */\n Promise?: PromiseConstructorLike;\n\n /**\n * If true, turns on synchronous error rethrowing, which is a deprecated behavior\n * in v6 and higher. This behavior enables bad patterns like wrapping a subscribe\n * call in a try/catch block. It also enables producer interference, a nasty bug\n * where a multicast can be broken for all observers by a downstream consumer with\n * an unhandled error. DO NOT USE THIS FLAG UNLESS IT'S NEEDED TO BUY TIME\n * FOR MIGRATION REASONS.\n *\n * @deprecated As of version 8, RxJS will no longer support synchronous throwing\n * of unhandled errors. All errors will be thrown on a separate call stack to prevent bad\n * behaviors described above. Will be removed in v8.\n */\n useDeprecatedSynchronousErrorHandling: boolean;\n\n /**\n * If true, enables an as-of-yet undocumented feature from v5: The ability to access\n * `unsubscribe()` via `this` context in `next` functions created in observers passed\n * to `subscribe`.\n *\n * This is being removed because the performance was severely problematic, and it could also cause\n * issues when types other than POJOs are passed to subscribe as subscribers, as they will likely have\n * their `this` context overwritten.\n *\n * @deprecated As of version 8, RxJS will no longer support altering the\n * context of next functions provided as part of an observer to Subscribe. Instead,\n * you will have access to a subscription or a signal or token that will allow you to do things like\n * unsubscribe and test closed status. 
Will be removed in v8.\n */\n useDeprecatedNextContext: boolean;\n}\n", "import type { TimerHandle } from './timerHandle';\ntype SetTimeoutFunction = (handler: () => void, timeout?: number, ...args: any[]) => TimerHandle;\ntype ClearTimeoutFunction = (handle: TimerHandle) => void;\n\ninterface TimeoutProvider {\n setTimeout: SetTimeoutFunction;\n clearTimeout: ClearTimeoutFunction;\n delegate:\n | {\n setTimeout: SetTimeoutFunction;\n clearTimeout: ClearTimeoutFunction;\n }\n | undefined;\n}\n\nexport const timeoutProvider: TimeoutProvider = {\n // When accessing the delegate, use the variable rather than `this` so that\n // the functions can be called without being bound to the provider.\n setTimeout(handler: () => void, timeout?: number, ...args) {\n const { delegate } = timeoutProvider;\n if (delegate?.setTimeout) {\n return delegate.setTimeout(handler, timeout, ...args);\n }\n return setTimeout(handler, timeout, ...args);\n },\n clearTimeout(handle) {\n const { delegate } = timeoutProvider;\n return (delegate?.clearTimeout || clearTimeout)(handle as any);\n },\n delegate: undefined,\n};\n", "import { config } from '../config';\nimport { timeoutProvider } from '../scheduler/timeoutProvider';\n\n/**\n * Handles an error on another job either with the user-configured {@link onUnhandledError},\n * or by throwing it on that new job so it can be picked up by `window.onerror`, `process.on('error')`, etc.\n *\n * This should be called whenever there is an error that is out-of-band with the subscription\n * or when an error hits a terminal boundary of the subscription and no error handler was provided.\n *\n * @param err the error to report\n */\nexport function reportUnhandledError(err: any) {\n timeoutProvider.setTimeout(() => {\n const { onUnhandledError } = config;\n if (onUnhandledError) {\n // Execute the user-configured error handler.\n onUnhandledError(err);\n } else {\n // Throw so it is picked up by the runtime's uncaught error mechanism.\n throw err;\n }\n });\n}\n", "/* tslint:disable:no-empty */\nexport function noop() { }\n", "import { CompleteNotification, NextNotification, ErrorNotification } from './types';\n\n/**\n * A completion object optimized for memory use and created to be the\n * same \"shape\" as other notifications in v8.\n * @internal\n */\nexport const COMPLETE_NOTIFICATION = (() => createNotification('C', undefined, undefined) as CompleteNotification)();\n\n/**\n * Internal use only. Creates an optimized error notification that is the same \"shape\"\n * as other notifications.\n * @internal\n */\nexport function errorNotification(error: any): ErrorNotification {\n return createNotification('E', undefined, error) as any;\n}\n\n/**\n * Internal use only. Creates an optimized next notification that is the same \"shape\"\n * as other notifications.\n * @internal\n */\nexport function nextNotification(value: T) {\n return createNotification('N', value, undefined) as NextNotification;\n}\n\n/**\n * Ensures that all notifications created internally have the same \"shape\" in v8.\n *\n * TODO: This is only exported to support a crazy legacy test in `groupBy`.\n * @internal\n */\nexport function createNotification(kind: 'N' | 'E' | 'C', value: any, error: any) {\n return {\n kind,\n value,\n error,\n };\n}\n", "import { config } from '../config';\n\nlet context: { errorThrown: boolean; error: any } | null = null;\n\n/**\n * Handles dealing with errors for super-gross mode. 
Creates a context, in which\n * any synchronously thrown errors will be passed to {@link captureError}. Which\n * will record the error such that it will be rethrown after the call back is complete.\n * TODO: Remove in v8\n * @param cb An immediately executed function.\n */\nexport function errorContext(cb: () => void) {\n if (config.useDeprecatedSynchronousErrorHandling) {\n const isRoot = !context;\n if (isRoot) {\n context = { errorThrown: false, error: null };\n }\n cb();\n if (isRoot) {\n const { errorThrown, error } = context!;\n context = null;\n if (errorThrown) {\n throw error;\n }\n }\n } else {\n // This is the general non-deprecated path for everyone that\n // isn't crazy enough to use super-gross mode (useDeprecatedSynchronousErrorHandling)\n cb();\n }\n}\n\n/**\n * Captures errors only in super-gross mode.\n * @param err the error to capture\n */\nexport function captureError(err: any) {\n if (config.useDeprecatedSynchronousErrorHandling && context) {\n context.errorThrown = true;\n context.error = err;\n }\n}\n", "import { isFunction } from './util/isFunction';\nimport { Observer, ObservableNotification } from './types';\nimport { isSubscription, Subscription } from './Subscription';\nimport { config } from './config';\nimport { reportUnhandledError } from './util/reportUnhandledError';\nimport { noop } from './util/noop';\nimport { nextNotification, errorNotification, COMPLETE_NOTIFICATION } from './NotificationFactories';\nimport { timeoutProvider } from './scheduler/timeoutProvider';\nimport { captureError } from './util/errorContext';\n\n/**\n * Implements the {@link Observer} interface and extends the\n * {@link Subscription} class. While the {@link Observer} is the public API for\n * consuming the values of an {@link Observable}, all Observers get converted to\n * a Subscriber, in order to provide Subscription-like capabilities such as\n * `unsubscribe`. Subscriber is a common type in RxJS, and crucial for\n * implementing operators, but it is rarely used as a public API.\n *\n * @class Subscriber\n */\nexport class Subscriber extends Subscription implements Observer {\n /**\n * A static factory for a Subscriber, given a (potentially partial) definition\n * of an Observer.\n * @param next The `next` callback of an Observer.\n * @param error The `error` callback of an\n * Observer.\n * @param complete The `complete` callback of an\n * Observer.\n * @return A Subscriber wrapping the (partially defined)\n * Observer represented by the given arguments.\n * @nocollapse\n * @deprecated Do not use. Will be removed in v8. There is no replacement for this\n * method, and there is no reason to be creating instances of `Subscriber` directly.\n * If you have a specific use case, please file an issue.\n */\n static create(next?: (x?: T) => void, error?: (e?: any) => void, complete?: () => void): Subscriber {\n return new SafeSubscriber(next, error, complete);\n }\n\n /** @deprecated Internal implementation detail, do not use directly. Will be made internal in v8. */\n protected isStopped: boolean = false;\n /** @deprecated Internal implementation detail, do not use directly. Will be made internal in v8. */\n protected destination: Subscriber | Observer; // this `any` is the escape hatch to erase extra type param (e.g. R)\n\n /**\n * @deprecated Internal implementation detail, do not use directly. Will be made internal in v8.\n * There is no reason to directly create an instance of Subscriber. 
This type is exported for typings reasons.\n */\n constructor(destination?: Subscriber | Observer) {\n super();\n if (destination) {\n this.destination = destination;\n // Automatically chain subscriptions together here.\n // if destination is a Subscription, then it is a Subscriber.\n if (isSubscription(destination)) {\n destination.add(this);\n }\n } else {\n this.destination = EMPTY_OBSERVER;\n }\n }\n\n /**\n * The {@link Observer} callback to receive notifications of type `next` from\n * the Observable, with a value. The Observable may call this method 0 or more\n * times.\n * @param {T} [value] The `next` value.\n * @return {void}\n */\n next(value?: T): void {\n if (this.isStopped) {\n handleStoppedNotification(nextNotification(value), this);\n } else {\n this._next(value!);\n }\n }\n\n /**\n * The {@link Observer} callback to receive notifications of type `error` from\n * the Observable, with an attached `Error`. Notifies the Observer that\n * the Observable has experienced an error condition.\n * @param {any} [err] The `error` exception.\n * @return {void}\n */\n error(err?: any): void {\n if (this.isStopped) {\n handleStoppedNotification(errorNotification(err), this);\n } else {\n this.isStopped = true;\n this._error(err);\n }\n }\n\n /**\n * The {@link Observer} callback to receive a valueless notification of type\n * `complete` from the Observable. Notifies the Observer that the Observable\n * has finished sending push-based notifications.\n * @return {void}\n */\n complete(): void {\n if (this.isStopped) {\n handleStoppedNotification(COMPLETE_NOTIFICATION, this);\n } else {\n this.isStopped = true;\n this._complete();\n }\n }\n\n unsubscribe(): void {\n if (!this.closed) {\n this.isStopped = true;\n super.unsubscribe();\n this.destination = null!;\n }\n }\n\n protected _next(value: T): void {\n this.destination.next(value);\n }\n\n protected _error(err: any): void {\n try {\n this.destination.error(err);\n } finally {\n this.unsubscribe();\n }\n }\n\n protected _complete(): void {\n try {\n this.destination.complete();\n } finally {\n this.unsubscribe();\n }\n }\n}\n\n/**\n * This bind is captured here because we want to be able to have\n * compatibility with monoid libraries that tend to use a method named\n * `bind`. 
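Once a Subscriber is stopped, further notifications are routed to `handleStoppedNotification`, which only acts if `config.onStoppedNotification` is set. A sketch of observing that, assuming the public `config` and `Observable` exports:

```ts
import { config, Observable } from 'rxjs';

// Notifications delivered after error/complete hit a stopped Subscriber and
// are dropped; this optional hook sees them (asynchronously, via a timeout).
config.onStoppedNotification = (notification) =>
  console.warn('notification after stop:', notification.kind);

new Observable<number>((subscriber) => {
  subscriber.next(1);
  subscriber.complete();
  subscriber.next(2); // the subscriber is already stopped here
}).subscribe((value) => console.log(value));
```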
In particular, a library called Monio requires this.\n */\nconst _bind = Function.prototype.bind;\n\nfunction bind any>(fn: Fn, thisArg: any): Fn {\n return _bind.call(fn, thisArg);\n}\n\n/**\n * Internal optimization only, DO NOT EXPOSE.\n * @internal\n */\nclass ConsumerObserver implements Observer {\n constructor(private partialObserver: Partial>) {}\n\n next(value: T): void {\n const { partialObserver } = this;\n if (partialObserver.next) {\n try {\n partialObserver.next(value);\n } catch (error) {\n handleUnhandledError(error);\n }\n }\n }\n\n error(err: any): void {\n const { partialObserver } = this;\n if (partialObserver.error) {\n try {\n partialObserver.error(err);\n } catch (error) {\n handleUnhandledError(error);\n }\n } else {\n handleUnhandledError(err);\n }\n }\n\n complete(): void {\n const { partialObserver } = this;\n if (partialObserver.complete) {\n try {\n partialObserver.complete();\n } catch (error) {\n handleUnhandledError(error);\n }\n }\n }\n}\n\nexport class SafeSubscriber extends Subscriber {\n constructor(\n observerOrNext?: Partial> | ((value: T) => void) | null,\n error?: ((e?: any) => void) | null,\n complete?: (() => void) | null\n ) {\n super();\n\n let partialObserver: Partial>;\n if (isFunction(observerOrNext) || !observerOrNext) {\n // The first argument is a function, not an observer. The next\n // two arguments *could* be observers, or they could be empty.\n partialObserver = {\n next: (observerOrNext ?? undefined) as (((value: T) => void) | undefined),\n error: error ?? undefined,\n complete: complete ?? undefined,\n };\n } else {\n // The first argument is a partial observer.\n let context: any;\n if (this && config.useDeprecatedNextContext) {\n // This is a deprecated path that made `this.unsubscribe()` available in\n // next handler functions passed to subscribe. This only exists behind a flag\n // now, as it is *very* slow.\n context = Object.create(observerOrNext);\n context.unsubscribe = () => this.unsubscribe();\n partialObserver = {\n next: observerOrNext.next && bind(observerOrNext.next, context),\n error: observerOrNext.error && bind(observerOrNext.error, context),\n complete: observerOrNext.complete && bind(observerOrNext.complete, context),\n };\n } else {\n // The \"normal\" path. 
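`ConsumerObserver` above is what makes partial observers safe to hand to `subscribe`. A small sketch, assuming the public `of` export:

```ts
import { of } from 'rxjs';

// A partial observer: callbacks that are not provided are simply skipped;
// omitting `error` means an error would be reported as unhandled instead.
of(1, 2, 3).subscribe({
  next: (value) => console.log('value', value),
  complete: () => console.log('done'),
});
```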
Just use the partial observer directly.\n partialObserver = observerOrNext;\n }\n }\n\n // Wrap the partial observer to ensure it's a full observer, and\n // make sure proper error handling is accounted for.\n this.destination = new ConsumerObserver(partialObserver);\n }\n}\n\nfunction handleUnhandledError(error: any) {\n if (config.useDeprecatedSynchronousErrorHandling) {\n captureError(error);\n } else {\n // Ideal path, we report this as an unhandled error,\n // which is thrown on a new call stack.\n reportUnhandledError(error);\n }\n}\n\n/**\n * An error handler used when no error handler was supplied\n * to the SafeSubscriber -- meaning no error handler was supplied\n * do the `subscribe` call on our observable.\n * @param err The error to handle\n */\nfunction defaultErrorHandler(err: any) {\n throw err;\n}\n\n/**\n * A handler for notifications that cannot be sent to a stopped subscriber.\n * @param notification The notification being sent\n * @param subscriber The stopped subscriber\n */\nfunction handleStoppedNotification(notification: ObservableNotification, subscriber: Subscriber) {\n const { onStoppedNotification } = config;\n onStoppedNotification && timeoutProvider.setTimeout(() => onStoppedNotification(notification, subscriber));\n}\n\n/**\n * The observer used as a stub for subscriptions where the user did not\n * pass any arguments to `subscribe`. Comes with the default error handling\n * behavior.\n */\nexport const EMPTY_OBSERVER: Readonly> & { closed: true } = {\n closed: true,\n next: noop,\n error: defaultErrorHandler,\n complete: noop,\n};\n", "/**\n * Symbol.observable or a string \"@@observable\". Used for interop\n *\n * @deprecated We will no longer be exporting this symbol in upcoming versions of RxJS.\n * Instead polyfill and use Symbol.observable directly *or* use https://www.npmjs.com/package/symbol-observable\n */\nexport const observable: string | symbol = (() => (typeof Symbol === 'function' && Symbol.observable) || '@@observable')();\n", "/**\n * This function takes one parameter and just returns it. Simply put,\n * this is like `(x: T): T => x`.\n *\n * ## Examples\n *\n * This is useful in some cases when using things like `mergeMap`\n *\n * ```ts\n * import { interval, take, map, range, mergeMap, identity } from 'rxjs';\n *\n * const source$ = interval(1000).pipe(take(5));\n *\n * const result$ = source$.pipe(\n * map(i => range(i)),\n * mergeMap(identity) // same as mergeMap(x => x)\n * );\n *\n * result$.subscribe({\n * next: console.log\n * });\n * ```\n *\n * Or when you want to selectively apply an operator\n *\n * ```ts\n * import { interval, take, identity } from 'rxjs';\n *\n * const shouldLimit = () => Math.random() < 0.5;\n *\n * const source$ = interval(1000);\n *\n * const result$ = source$.pipe(shouldLimit() ? 
take(5) : identity);\n *\n * result$.subscribe({\n * next: console.log\n * });\n * ```\n *\n * @param x Any value that is returned by this function\n * @returns The value passed as the first parameter to this function\n */\nexport function identity(x: T): T {\n return x;\n}\n", "import { identity } from './identity';\nimport { UnaryFunction } from '../types';\n\nexport function pipe(): typeof identity;\nexport function pipe(fn1: UnaryFunction): UnaryFunction;\nexport function pipe(fn1: UnaryFunction, fn2: UnaryFunction): UnaryFunction;\nexport function pipe(fn1: UnaryFunction, fn2: UnaryFunction, fn3: UnaryFunction): UnaryFunction;\nexport function pipe(\n fn1: UnaryFunction,\n fn2: UnaryFunction,\n fn3: UnaryFunction,\n fn4: UnaryFunction\n): UnaryFunction;\nexport function pipe(\n fn1: UnaryFunction,\n fn2: UnaryFunction,\n fn3: UnaryFunction,\n fn4: UnaryFunction,\n fn5: UnaryFunction\n): UnaryFunction;\nexport function pipe(\n fn1: UnaryFunction,\n fn2: UnaryFunction,\n fn3: UnaryFunction,\n fn4: UnaryFunction,\n fn5: UnaryFunction,\n fn6: UnaryFunction\n): UnaryFunction;\nexport function pipe(\n fn1: UnaryFunction,\n fn2: UnaryFunction,\n fn3: UnaryFunction,\n fn4: UnaryFunction,\n fn5: UnaryFunction,\n fn6: UnaryFunction,\n fn7: UnaryFunction\n): UnaryFunction;\nexport function pipe(\n fn1: UnaryFunction,\n fn2: UnaryFunction,\n fn3: UnaryFunction,\n fn4: UnaryFunction,\n fn5: UnaryFunction,\n fn6: UnaryFunction,\n fn7: UnaryFunction,\n fn8: UnaryFunction\n): UnaryFunction;\nexport function pipe(\n fn1: UnaryFunction,\n fn2: UnaryFunction,\n fn3: UnaryFunction,\n fn4: UnaryFunction,\n fn5: UnaryFunction,\n fn6: UnaryFunction,\n fn7: UnaryFunction,\n fn8: UnaryFunction,\n fn9: UnaryFunction\n): UnaryFunction;\nexport function pipe(\n fn1: UnaryFunction,\n fn2: UnaryFunction,\n fn3: UnaryFunction,\n fn4: UnaryFunction,\n fn5: UnaryFunction,\n fn6: UnaryFunction,\n fn7: UnaryFunction,\n fn8: UnaryFunction,\n fn9: UnaryFunction,\n ...fns: UnaryFunction[]\n): UnaryFunction;\n\n/**\n * pipe() can be called on one or more functions, each of which can take one argument (\"UnaryFunction\")\n * and uses it to return a value.\n * It returns a function that takes one argument, passes it to the first UnaryFunction, and then\n * passes the result to the next one, passes that result to the next one, and so on. \n */\nexport function pipe(...fns: Array>): UnaryFunction {\n return pipeFromArray(fns);\n}\n\n/** @internal */\nexport function pipeFromArray(fns: Array>): UnaryFunction {\n if (fns.length === 0) {\n return identity as UnaryFunction;\n }\n\n if (fns.length === 1) {\n return fns[0];\n }\n\n return function piped(input: T): R {\n return fns.reduce((prev: any, fn: UnaryFunction) => fn(prev), input as any);\n };\n}\n", "import { Operator } from './Operator';\nimport { SafeSubscriber, Subscriber } from './Subscriber';\nimport { isSubscription, Subscription } from './Subscription';\nimport { TeardownLogic, OperatorFunction, Subscribable, Observer } from './types';\nimport { observable as Symbol_observable } from './symbol/observable';\nimport { pipeFromArray } from './util/pipe';\nimport { config } from './config';\nimport { isFunction } from './util/isFunction';\nimport { errorContext } from './util/errorContext';\n\n/**\n * A representation of any set of values over any amount of time. This is the most basic building block\n * of RxJS.\n *\n * @class Observable\n */\nexport class Observable implements Subscribable {\n /**\n * @deprecated Internal implementation detail, do not use directly. 
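`pipe`/`pipeFromArray` above compose unary functions left to right, falling back to `identity` for an empty list. The standalone helper works on plain functions as well as operators; a sketch:

```ts
import { pipe } from 'rxjs';

// pipeFromArray reduces over the supplied unary functions in order.
const shout = pipe(
  (s: string) => s.trim(),
  (s: string) => s.toUpperCase(),
  (s: string) => `${s}!`
);

console.log(shout('  hello ')); // "HELLO!"
```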
Will be made internal in v8.\n */\n source: Observable | undefined;\n\n /**\n * @deprecated Internal implementation detail, do not use directly. Will be made internal in v8.\n */\n operator: Operator | undefined;\n\n /**\n * @constructor\n * @param {Function} subscribe the function that is called when the Observable is\n * initially subscribed to. This function is given a Subscriber, to which new values\n * can be `next`ed, or an `error` method can be called to raise an error, or\n * `complete` can be called to notify of a successful completion.\n */\n constructor(subscribe?: (this: Observable, subscriber: Subscriber) => TeardownLogic) {\n if (subscribe) {\n this._subscribe = subscribe;\n }\n }\n\n // HACK: Since TypeScript inherits static properties too, we have to\n // fight against TypeScript here so Subject can have a different static create signature\n /**\n * Creates a new Observable by calling the Observable constructor\n * @owner Observable\n * @method create\n * @param {Function} subscribe? the subscriber function to be passed to the Observable constructor\n * @return {Observable} a new observable\n * @nocollapse\n * @deprecated Use `new Observable()` instead. Will be removed in v8.\n */\n static create: (...args: any[]) => any = (subscribe?: (subscriber: Subscriber) => TeardownLogic) => {\n return new Observable(subscribe);\n };\n\n /**\n * Creates a new Observable, with this Observable instance as the source, and the passed\n * operator defined as the new observable's operator.\n * @method lift\n * @param operator the operator defining the operation to take on the observable\n * @return a new observable with the Operator applied\n * @deprecated Internal implementation detail, do not use directly. Will be made internal in v8.\n * If you have implemented an operator using `lift`, it is recommended that you create an\n * operator by simply returning `new Observable()` directly. See \"Creating new operators from\n * scratch\" section here: https://rxjs.dev/guide/operators\n */\n lift(operator?: Operator): Observable {\n const observable = new Observable();\n observable.source = this;\n observable.operator = operator;\n return observable;\n }\n\n subscribe(observerOrNext?: Partial> | ((value: T) => void)): Subscription;\n /** @deprecated Instead of passing separate callback arguments, use an observer argument. Signatures taking separate callback arguments will be removed in v8. Details: https://rxjs.dev/deprecations/subscribe-arguments */\n subscribe(next?: ((value: T) => void) | null, error?: ((error: any) => void) | null, complete?: (() => void) | null): Subscription;\n /**\n * Invokes an execution of an Observable and registers Observer handlers for notifications it will emit.\n *\n * Use it when you have all these Observables, but still nothing is happening.\n *\n * `subscribe` is not a regular operator, but a method that calls Observable's internal `subscribe` function. It\n * might be for example a function that you passed to Observable's constructor, but most of the time it is\n * a library implementation, which defines what will be emitted by an Observable, and when it be will emitted. This means\n * that calling `subscribe` is actually the moment when Observable starts its work, not when it is created, as it is often\n * the thought.\n *\n * Apart from starting the execution of an Observable, this method allows you to listen for values\n * that an Observable emits, as well as for when it completes or errors. 
You can achieve this in two\n * of the following ways.\n *\n * The first way is creating an object that implements {@link Observer} interface. It should have methods\n * defined by that interface, but note that it should be just a regular JavaScript object, which you can create\n * yourself in any way you want (ES6 class, classic function constructor, object literal etc.). In particular, do\n * not attempt to use any RxJS implementation details to create Observers - you don't need them. Remember also\n * that your object does not have to implement all methods. If you find yourself creating a method that doesn't\n * do anything, you can simply omit it. Note however, if the `error` method is not provided and an error happens,\n * it will be thrown asynchronously. Errors thrown asynchronously cannot be caught using `try`/`catch`. Instead,\n * use the {@link onUnhandledError} configuration option or use a runtime handler (like `window.onerror` or\n * `process.on('error)`) to be notified of unhandled errors. Because of this, it's recommended that you provide\n * an `error` method to avoid missing thrown errors.\n *\n * The second way is to give up on Observer object altogether and simply provide callback functions in place of its methods.\n * This means you can provide three functions as arguments to `subscribe`, where the first function is equivalent\n * of a `next` method, the second of an `error` method and the third of a `complete` method. Just as in case of an Observer,\n * if you do not need to listen for something, you can omit a function by passing `undefined` or `null`,\n * since `subscribe` recognizes these functions by where they were placed in function call. When it comes\n * to the `error` function, as with an Observer, if not provided, errors emitted by an Observable will be thrown asynchronously.\n *\n * You can, however, subscribe with no parameters at all. This may be the case where you're not interested in terminal events\n * and you also handled emissions internally by using operators (e.g. using `tap`).\n *\n * Whichever style of calling `subscribe` you use, in both cases it returns a Subscription object.\n * This object allows you to call `unsubscribe` on it, which in turn will stop the work that an Observable does and will clean\n * up all resources that an Observable used. Note that cancelling a subscription will not call `complete` callback\n * provided to `subscribe` function, which is reserved for a regular completion signal that comes from an Observable.\n *\n * Remember that callbacks provided to `subscribe` are not guaranteed to be called asynchronously.\n * It is an Observable itself that decides when these functions will be called. For example {@link of}\n * by default emits all its values synchronously. 
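Tying the constructor described earlier to these subscribe/unsubscribe semantics, a minimal sketch (the interval-based source is illustrative):

```ts
import { Observable } from 'rxjs';

// The subscribe function passed to the constructor receives a Subscriber;
// the function it returns is the teardown logic run on unsubscribe.
const ticks$ = new Observable<number>((subscriber) => {
  let n = 0;
  const id = setInterval(() => subscriber.next(n++), 1000);
  return () => clearInterval(id);
});

const sub = ticks$.subscribe((n) => console.log(n));
// Unsubscribing runs the teardown (clearInterval) but does not call complete.
setTimeout(() => sub.unsubscribe(), 3500);
```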
Always check documentation for how given Observable\n * will behave when subscribed and if its default behavior can be modified with a `scheduler`.\n *\n * #### Examples\n *\n * Subscribe with an {@link guide/observer Observer}\n *\n * ```ts\n * import { of } from 'rxjs';\n *\n * const sumObserver = {\n * sum: 0,\n * next(value) {\n * console.log('Adding: ' + value);\n * this.sum = this.sum + value;\n * },\n * error() {\n * // We actually could just remove this method,\n * // since we do not really care about errors right now.\n * },\n * complete() {\n * console.log('Sum equals: ' + this.sum);\n * }\n * };\n *\n * of(1, 2, 3) // Synchronously emits 1, 2, 3 and then completes.\n * .subscribe(sumObserver);\n *\n * // Logs:\n * // 'Adding: 1'\n * // 'Adding: 2'\n * // 'Adding: 3'\n * // 'Sum equals: 6'\n * ```\n *\n * Subscribe with functions ({@link deprecations/subscribe-arguments deprecated})\n *\n * ```ts\n * import { of } from 'rxjs'\n *\n * let sum = 0;\n *\n * of(1, 2, 3).subscribe(\n * value => {\n * console.log('Adding: ' + value);\n * sum = sum + value;\n * },\n * undefined,\n * () => console.log('Sum equals: ' + sum)\n * );\n *\n * // Logs:\n * // 'Adding: 1'\n * // 'Adding: 2'\n * // 'Adding: 3'\n * // 'Sum equals: 6'\n * ```\n *\n * Cancel a subscription\n *\n * ```ts\n * import { interval } from 'rxjs';\n *\n * const subscription = interval(1000).subscribe({\n * next(num) {\n * console.log(num)\n * },\n * complete() {\n * // Will not be called, even when cancelling subscription.\n * console.log('completed!');\n * }\n * });\n *\n * setTimeout(() => {\n * subscription.unsubscribe();\n * console.log('unsubscribed!');\n * }, 2500);\n *\n * // Logs:\n * // 0 after 1s\n * // 1 after 2s\n * // 'unsubscribed!' after 2.5s\n * ```\n *\n * @param {Observer|Function} observerOrNext (optional) Either an observer with methods to be called,\n * or the first of three possible handlers, which is the handler for each value emitted from the subscribed\n * Observable.\n * @param {Function} error (optional) A handler for a terminal event resulting from an error. If no error handler is provided,\n * the error will be thrown asynchronously as unhandled.\n * @param {Function} complete (optional) A handler for a terminal event resulting from successful completion.\n * @return {Subscription} a subscription reference to the registered handlers\n * @method subscribe\n */\n subscribe(\n observerOrNext?: Partial> | ((value: T) => void) | null,\n error?: ((error: any) => void) | null,\n complete?: (() => void) | null\n ): Subscription {\n const subscriber = isSubscriber(observerOrNext) ? observerOrNext : new SafeSubscriber(observerOrNext, error, complete);\n\n errorContext(() => {\n const { operator, source } = this;\n subscriber.add(\n operator\n ? // We're dealing with a subscription in the\n // operator chain to one of our lifted operators.\n operator.call(subscriber, source)\n : source\n ? // If `source` has a value, but `operator` does not, something that\n // had intimate knowledge of our API, like our `Subject`, must have\n // set it. 
We're going to just call `_subscribe` directly.\n this._subscribe(subscriber)\n : // In all other cases, we're likely wrapping a user-provided initializer\n // function, so we need to catch errors and handle them appropriately.\n this._trySubscribe(subscriber)\n );\n });\n\n return subscriber;\n }\n\n /** @internal */\n protected _trySubscribe(sink: Subscriber): TeardownLogic {\n try {\n return this._subscribe(sink);\n } catch (err) {\n // We don't need to return anything in this case,\n // because it's just going to try to `add()` to a subscription\n // above.\n sink.error(err);\n }\n }\n\n /**\n * Used as a NON-CANCELLABLE means of subscribing to an observable, for use with\n * APIs that expect promises, like `async/await`. You cannot unsubscribe from this.\n *\n * **WARNING**: Only use this with observables you *know* will complete. If the source\n * observable does not complete, you will end up with a promise that is hung up, and\n * potentially all of the state of an async function hanging out in memory. To avoid\n * this situation, look into adding something like {@link timeout}, {@link take},\n * {@link takeWhile}, or {@link takeUntil} amongst others.\n *\n * #### Example\n *\n * ```ts\n * import { interval, take } from 'rxjs';\n *\n * const source$ = interval(1000).pipe(take(4));\n *\n * async function getTotal() {\n * let total = 0;\n *\n * await source$.forEach(value => {\n * total += value;\n * console.log('observable -> ' + value);\n * });\n *\n * return total;\n * }\n *\n * getTotal().then(\n * total => console.log('Total: ' + total)\n * );\n *\n * // Expected:\n * // 'observable -> 0'\n * // 'observable -> 1'\n * // 'observable -> 2'\n * // 'observable -> 3'\n * // 'Total: 6'\n * ```\n *\n * @param next a handler for each value emitted by the observable\n * @return a promise that either resolves on observable completion or\n * rejects with the handled error\n */\n forEach(next: (value: T) => void): Promise;\n\n /**\n * @param next a handler for each value emitted by the observable\n * @param promiseCtor a constructor function used to instantiate the Promise\n * @return a promise that either resolves on observable completion or\n * rejects with the handled error\n * @deprecated Passing a Promise constructor will no longer be available\n * in upcoming versions of RxJS. This is because it adds weight to the library, for very\n * little benefit. If you need this functionality, it is recommended that you either\n * polyfill Promise, or you create an adapter to convert the returned native promise\n * to whatever promise implementation you wanted. 
Will be removed in v8.\n */\n forEach(next: (value: T) => void, promiseCtor: PromiseConstructorLike): Promise;\n\n forEach(next: (value: T) => void, promiseCtor?: PromiseConstructorLike): Promise {\n promiseCtor = getPromiseCtor(promiseCtor);\n\n return new promiseCtor((resolve, reject) => {\n const subscriber = new SafeSubscriber({\n next: (value) => {\n try {\n next(value);\n } catch (err) {\n reject(err);\n subscriber.unsubscribe();\n }\n },\n error: reject,\n complete: resolve,\n });\n this.subscribe(subscriber);\n }) as Promise;\n }\n\n /** @internal */\n protected _subscribe(subscriber: Subscriber): TeardownLogic {\n return this.source?.subscribe(subscriber);\n }\n\n /**\n * An interop point defined by the es7-observable spec https://github.com/zenparsing/es-observable\n * @method Symbol.observable\n * @return {Observable} this instance of the observable\n */\n [Symbol_observable]() {\n return this;\n }\n\n /* tslint:disable:max-line-length */\n pipe(): Observable;\n pipe(op1: OperatorFunction): Observable;\n pipe(op1: OperatorFunction, op2: OperatorFunction): Observable;\n pipe(op1: OperatorFunction, op2: OperatorFunction, op3: OperatorFunction): Observable;\n pipe(\n op1: OperatorFunction,\n op2: OperatorFunction,\n op3: OperatorFunction,\n op4: OperatorFunction\n ): Observable;\n pipe(\n op1: OperatorFunction,\n op2: OperatorFunction,\n op3: OperatorFunction,\n op4: OperatorFunction,\n op5: OperatorFunction\n ): Observable;\n pipe(\n op1: OperatorFunction,\n op2: OperatorFunction,\n op3: OperatorFunction,\n op4: OperatorFunction,\n op5: OperatorFunction,\n op6: OperatorFunction\n ): Observable;\n pipe(\n op1: OperatorFunction,\n op2: OperatorFunction,\n op3: OperatorFunction,\n op4: OperatorFunction,\n op5: OperatorFunction,\n op6: OperatorFunction,\n op7: OperatorFunction\n ): Observable;\n pipe(\n op1: OperatorFunction,\n op2: OperatorFunction,\n op3: OperatorFunction,\n op4: OperatorFunction,\n op5: OperatorFunction,\n op6: OperatorFunction,\n op7: OperatorFunction,\n op8: OperatorFunction\n ): Observable;\n pipe(\n op1: OperatorFunction,\n op2: OperatorFunction,\n op3: OperatorFunction,\n op4: OperatorFunction,\n op5: OperatorFunction,\n op6: OperatorFunction,\n op7: OperatorFunction,\n op8: OperatorFunction,\n op9: OperatorFunction\n ): Observable;\n pipe(\n op1: OperatorFunction,\n op2: OperatorFunction,\n op3: OperatorFunction,\n op4: OperatorFunction,\n op5: OperatorFunction,\n op6: OperatorFunction,\n op7: OperatorFunction,\n op8: OperatorFunction,\n op9: OperatorFunction,\n ...operations: OperatorFunction[]\n ): Observable;\n /* tslint:enable:max-line-length */\n\n /**\n * Used to stitch together functional operators into a chain.\n * @method pipe\n * @return {Observable} the Observable result of all of the operators having\n * been called in the order they were passed in.\n *\n * ## Example\n *\n * ```ts\n * import { interval, filter, map, scan } from 'rxjs';\n *\n * interval(1000)\n * .pipe(\n * filter(x => x % 2 === 0),\n * map(x => x + x),\n * scan((acc, x) => acc + x)\n * )\n * .subscribe(x => console.log(x));\n * ```\n */\n pipe(...operations: OperatorFunction[]): Observable {\n return pipeFromArray(operations)(this);\n }\n\n /* tslint:disable:max-line-length */\n /** @deprecated Replaced with {@link firstValueFrom} and {@link lastValueFrom}. Will be removed in v8. Details: https://rxjs.dev/deprecations/to-promise */\n toPromise(): Promise;\n /** @deprecated Replaced with {@link firstValueFrom} and {@link lastValueFrom}. Will be removed in v8. 
Details: https://rxjs.dev/deprecations/to-promise */\n toPromise(PromiseCtor: typeof Promise): Promise;\n /** @deprecated Replaced with {@link firstValueFrom} and {@link lastValueFrom}. Will be removed in v8. Details: https://rxjs.dev/deprecations/to-promise */\n toPromise(PromiseCtor: PromiseConstructorLike): Promise;\n /* tslint:enable:max-line-length */\n\n /**\n * Subscribe to this Observable and get a Promise resolving on\n * `complete` with the last emission (if any).\n *\n * **WARNING**: Only use this with observables you *know* will complete. If the source\n * observable does not complete, you will end up with a promise that is hung up, and\n * potentially all of the state of an async function hanging out in memory. To avoid\n * this situation, look into adding something like {@link timeout}, {@link take},\n * {@link takeWhile}, or {@link takeUntil} amongst others.\n *\n * @method toPromise\n * @param [promiseCtor] a constructor function used to instantiate\n * the Promise\n * @return A Promise that resolves with the last value emit, or\n * rejects on an error. If there were no emissions, Promise\n * resolves with undefined.\n * @deprecated Replaced with {@link firstValueFrom} and {@link lastValueFrom}. Will be removed in v8. Details: https://rxjs.dev/deprecations/to-promise\n */\n toPromise(promiseCtor?: PromiseConstructorLike): Promise {\n promiseCtor = getPromiseCtor(promiseCtor);\n\n return new promiseCtor((resolve, reject) => {\n let value: T | undefined;\n this.subscribe(\n (x: T) => (value = x),\n (err: any) => reject(err),\n () => resolve(value)\n );\n }) as Promise;\n }\n}\n\n/**\n * Decides between a passed promise constructor from consuming code,\n * A default configured promise constructor, and the native promise\n * constructor and returns it. If nothing can be found, it will throw\n * an error.\n * @param promiseCtor The optional promise constructor to passed by consuming code\n */\nfunction getPromiseCtor(promiseCtor: PromiseConstructorLike | undefined) {\n return promiseCtor ?? config.Promise ?? Promise;\n}\n\nfunction isObserver(value: any): value is Observer {\n return value && isFunction(value.next) && isFunction(value.error) && isFunction(value.complete);\n}\n\nfunction isSubscriber(value: any): value is Subscriber {\n return (value && value instanceof Subscriber) || (isObserver(value) && isSubscription(value));\n}\n", "import { Observable } from '../Observable';\nimport { Subscriber } from '../Subscriber';\nimport { OperatorFunction } from '../types';\nimport { isFunction } from './isFunction';\n\n/**\n * Used to determine if an object is an Observable with a lift function.\n */\nexport function hasLift(source: any): source is { lift: InstanceType['lift'] } {\n return isFunction(source?.lift);\n}\n\n/**\n * Creates an `OperatorFunction`. 
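`toPromise` above is deprecated in favor of `firstValueFrom`/`lastValueFrom`. A sketch of the replacement, assuming those public RxJS 7 exports:

```ts
import { firstValueFrom, lastValueFrom, interval, take } from 'rxjs';

// Explicit Observable-to-Promise conversion for the first or last emission.
async function main() {
  const source$ = interval(10).pipe(take(3));
  console.log(await firstValueFrom(source$)); // 0
  console.log(await lastValueFrom(source$));  // 2
}

main();
```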
Used to define operators throughout the library in a concise way.\n * @param init The logic to connect the liftedSource to the subscriber at the moment of subscription.\n */\nexport function operate(\n init: (liftedSource: Observable, subscriber: Subscriber) => (() => void) | void\n): OperatorFunction {\n return (source: Observable) => {\n if (hasLift(source)) {\n return source.lift(function (this: Subscriber, liftedSource: Observable) {\n try {\n return init(liftedSource, this);\n } catch (err) {\n this.error(err);\n }\n });\n }\n throw new TypeError('Unable to lift unknown Observable type');\n };\n}\n", "import { Subscriber } from '../Subscriber';\n\n/**\n * Creates an instance of an `OperatorSubscriber`.\n * @param destination The downstream subscriber.\n * @param onNext Handles next values, only called if this subscriber is not stopped or closed. Any\n * error that occurs in this function is caught and sent to the `error` method of this subscriber.\n * @param onError Handles errors from the subscription, any errors that occur in this handler are caught\n * and send to the `destination` error handler.\n * @param onComplete Handles completion notification from the subscription. Any errors that occur in\n * this handler are sent to the `destination` error handler.\n * @param onFinalize Additional teardown logic here. This will only be called on teardown if the\n * subscriber itself is not already closed. This is called after all other teardown logic is executed.\n */\nexport function createOperatorSubscriber(\n destination: Subscriber,\n onNext?: (value: T) => void,\n onComplete?: () => void,\n onError?: (err: any) => void,\n onFinalize?: () => void\n): Subscriber {\n return new OperatorSubscriber(destination, onNext, onComplete, onError, onFinalize);\n}\n\n/**\n * A generic helper for allowing operators to be created with a Subscriber and\n * use closures to capture necessary state from the operator function itself.\n */\nexport class OperatorSubscriber extends Subscriber {\n /**\n * Creates an instance of an `OperatorSubscriber`.\n * @param destination The downstream subscriber.\n * @param onNext Handles next values, only called if this subscriber is not stopped or closed. Any\n * error that occurs in this function is caught and sent to the `error` method of this subscriber.\n * @param onError Handles errors from the subscription, any errors that occur in this handler are caught\n * and send to the `destination` error handler.\n * @param onComplete Handles completion notification from the subscription. Any errors that occur in\n * this handler are sent to the `destination` error handler.\n * @param onFinalize Additional finalization logic here. This will only be called on finalization if the\n * subscriber itself is not already closed. This is called after all other finalization logic is executed.\n * @param shouldUnsubscribe An optional check to see if an unsubscribe call should truly unsubscribe.\n * NOTE: This currently **ONLY** exists to support the strange behavior of {@link groupBy}, where unsubscription\n * to the resulting observable does not actually disconnect from the source if there are active subscriptions\n * to any grouped observable. 
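`operate` above is internal plumbing that threads operator logic through `lift`. For user code, the `lift` deprecation note recommends building operators by returning a `new Observable` directly; a hedged sketch of that pattern (the `double` operator is hypothetical):

```ts
import { Observable, OperatorFunction } from 'rxjs';

// An operator built from scratch: a function from a source Observable to a
// new Observable that forwards transformed notifications downstream.
function double(): OperatorFunction<number, number> {
  return (source) =>
    new Observable<number>((subscriber) =>
      source.subscribe({
        next: (value) => subscriber.next(value * 2),
        error: (err) => subscriber.error(err),
        complete: () => subscriber.complete(),
      })
    );
}

// Usage: of(1, 2, 3).pipe(double()) emits 2, 4, 6.
```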
(DO NOT EXPOSE OR USE EXTERNALLY!!!)\n */\n constructor(\n destination: Subscriber,\n onNext?: (value: T) => void,\n onComplete?: () => void,\n onError?: (err: any) => void,\n private onFinalize?: () => void,\n private shouldUnsubscribe?: () => boolean\n ) {\n // It's important - for performance reasons - that all of this class's\n // members are initialized and that they are always initialized in the same\n // order. This will ensure that all OperatorSubscriber instances have the\n // same hidden class in V8. This, in turn, will help keep the number of\n // hidden classes involved in property accesses within the base class as\n // low as possible. If the number of hidden classes involved exceeds four,\n // the property accesses will become megamorphic and performance penalties\n // will be incurred - i.e. inline caches won't be used.\n //\n // The reasons for ensuring all instances have the same hidden class are\n // further discussed in this blog post from Benedikt Meurer:\n // https://benediktmeurer.de/2018/03/23/impact-of-polymorphism-on-component-based-frameworks-like-react/\n super(destination);\n this._next = onNext\n ? function (this: OperatorSubscriber, value: T) {\n try {\n onNext(value);\n } catch (err) {\n destination.error(err);\n }\n }\n : super._next;\n this._error = onError\n ? function (this: OperatorSubscriber, err: any) {\n try {\n onError(err);\n } catch (err) {\n // Send any errors that occur down stream.\n destination.error(err);\n } finally {\n // Ensure finalization.\n this.unsubscribe();\n }\n }\n : super._error;\n this._complete = onComplete\n ? function (this: OperatorSubscriber) {\n try {\n onComplete();\n } catch (err) {\n // Send any errors that occur down stream.\n destination.error(err);\n } finally {\n // Ensure finalization.\n this.unsubscribe();\n }\n }\n : super._complete;\n }\n\n unsubscribe() {\n if (!this.shouldUnsubscribe || this.shouldUnsubscribe()) {\n const { closed } = this;\n super.unsubscribe();\n // Execute additional teardown if we have any and we didn't already do so.\n !closed && this.onFinalize?.();\n }\n }\n}\n", "import { Subscription } from '../Subscription';\n\ninterface AnimationFrameProvider {\n schedule(callback: FrameRequestCallback): Subscription;\n requestAnimationFrame: typeof requestAnimationFrame;\n cancelAnimationFrame: typeof cancelAnimationFrame;\n delegate:\n | {\n requestAnimationFrame: typeof requestAnimationFrame;\n cancelAnimationFrame: typeof cancelAnimationFrame;\n }\n | undefined;\n}\n\nexport const animationFrameProvider: AnimationFrameProvider = {\n // When accessing the delegate, use the variable rather than `this` so that\n // the functions can be called without being bound to the provider.\n schedule(callback) {\n let request = requestAnimationFrame;\n let cancel: typeof cancelAnimationFrame | undefined = cancelAnimationFrame;\n const { delegate } = animationFrameProvider;\n if (delegate) {\n request = delegate.requestAnimationFrame;\n cancel = delegate.cancelAnimationFrame;\n }\n const handle = request((timestamp) => {\n // Clear the cancel function. 
The request has been fulfilled, so\n // attempting to cancel the request upon unsubscription would be\n // pointless.\n cancel = undefined;\n callback(timestamp);\n });\n return new Subscription(() => cancel?.(handle));\n },\n requestAnimationFrame(...args) {\n const { delegate } = animationFrameProvider;\n return (delegate?.requestAnimationFrame || requestAnimationFrame)(...args);\n },\n cancelAnimationFrame(...args) {\n const { delegate } = animationFrameProvider;\n return (delegate?.cancelAnimationFrame || cancelAnimationFrame)(...args);\n },\n delegate: undefined,\n};\n", "import { createErrorClass } from './createErrorClass';\n\nexport interface ObjectUnsubscribedError extends Error {}\n\nexport interface ObjectUnsubscribedErrorCtor {\n /**\n * @deprecated Internal implementation detail. Do not construct error instances.\n * Cannot be tagged as internal: https://github.com/ReactiveX/rxjs/issues/6269\n */\n new (): ObjectUnsubscribedError;\n}\n\n/**\n * An error thrown when an action is invalid because the object has been\n * unsubscribed.\n *\n * @see {@link Subject}\n * @see {@link BehaviorSubject}\n *\n * @class ObjectUnsubscribedError\n */\nexport const ObjectUnsubscribedError: ObjectUnsubscribedErrorCtor = createErrorClass(\n (_super) =>\n function ObjectUnsubscribedErrorImpl(this: any) {\n _super(this);\n this.name = 'ObjectUnsubscribedError';\n this.message = 'object unsubscribed';\n }\n);\n", "import { Operator } from './Operator';\nimport { Observable } from './Observable';\nimport { Subscriber } from './Subscriber';\nimport { Subscription, EMPTY_SUBSCRIPTION } from './Subscription';\nimport { Observer, SubscriptionLike, TeardownLogic } from './types';\nimport { ObjectUnsubscribedError } from './util/ObjectUnsubscribedError';\nimport { arrRemove } from './util/arrRemove';\nimport { errorContext } from './util/errorContext';\n\n/**\n * A Subject is a special type of Observable that allows values to be\n * multicasted to many Observers. Subjects are like EventEmitters.\n *\n * Every Subject is an Observable and an Observer. You can subscribe to a\n * Subject, and you can call next to feed values as well as error and complete.\n */\nexport class Subject extends Observable implements SubscriptionLike {\n closed = false;\n\n private currentObservers: Observer[] | null = null;\n\n /** @deprecated Internal implementation detail, do not use directly. Will be made internal in v8. */\n observers: Observer[] = [];\n /** @deprecated Internal implementation detail, do not use directly. Will be made internal in v8. */\n isStopped = false;\n /** @deprecated Internal implementation detail, do not use directly. Will be made internal in v8. */\n hasError = false;\n /** @deprecated Internal implementation detail, do not use directly. Will be made internal in v8. */\n thrownError: any = null;\n\n /**\n * Creates a \"subject\" by basically gluing an observer to an observable.\n *\n * @nocollapse\n * @deprecated Recommended you do not use. Will be removed at some point in the future. Plans for replacement still under discussion.\n */\n static create: (...args: any[]) => any = (destination: Observer, source: Observable): AnonymousSubject => {\n return new AnonymousSubject(destination, source);\n };\n\n constructor() {\n // NOTE: This must be here to obscure Observable's constructor.\n super();\n }\n\n /** @deprecated Internal implementation detail, do not use directly. Will be made internal in v8. 
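`ObjectUnsubscribedError` above is what a closed `Subject` throws when it is used after `unsubscribe`. A sketch, assuming both are imported from the public API:

```ts
import { Subject, ObjectUnsubscribedError } from 'rxjs';

const subject = new Subject<number>();
subject.subscribe((v) => console.log(v));
subject.next(1); // delivered

subject.unsubscribe(); // closes the subject itself

try {
  subject.next(2); // guarded by _throwIfClosed()
} catch (err) {
  console.log(err instanceof ObjectUnsubscribedError); // true
}
```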
*/\n lift(operator: Operator): Observable {\n const subject = new AnonymousSubject(this, this);\n subject.operator = operator as any;\n return subject as any;\n }\n\n /** @internal */\n protected _throwIfClosed() {\n if (this.closed) {\n throw new ObjectUnsubscribedError();\n }\n }\n\n next(value: T) {\n errorContext(() => {\n this._throwIfClosed();\n if (!this.isStopped) {\n if (!this.currentObservers) {\n this.currentObservers = Array.from(this.observers);\n }\n for (const observer of this.currentObservers) {\n observer.next(value);\n }\n }\n });\n }\n\n error(err: any) {\n errorContext(() => {\n this._throwIfClosed();\n if (!this.isStopped) {\n this.hasError = this.isStopped = true;\n this.thrownError = err;\n const { observers } = this;\n while (observers.length) {\n observers.shift()!.error(err);\n }\n }\n });\n }\n\n complete() {\n errorContext(() => {\n this._throwIfClosed();\n if (!this.isStopped) {\n this.isStopped = true;\n const { observers } = this;\n while (observers.length) {\n observers.shift()!.complete();\n }\n }\n });\n }\n\n unsubscribe() {\n this.isStopped = this.closed = true;\n this.observers = this.currentObservers = null!;\n }\n\n get observed() {\n return this.observers?.length > 0;\n }\n\n /** @internal */\n protected _trySubscribe(subscriber: Subscriber): TeardownLogic {\n this._throwIfClosed();\n return super._trySubscribe(subscriber);\n }\n\n /** @internal */\n protected _subscribe(subscriber: Subscriber): Subscription {\n this._throwIfClosed();\n this._checkFinalizedStatuses(subscriber);\n return this._innerSubscribe(subscriber);\n }\n\n /** @internal */\n protected _innerSubscribe(subscriber: Subscriber) {\n const { hasError, isStopped, observers } = this;\n if (hasError || isStopped) {\n return EMPTY_SUBSCRIPTION;\n }\n this.currentObservers = null;\n observers.push(subscriber);\n return new Subscription(() => {\n this.currentObservers = null;\n arrRemove(observers, subscriber);\n });\n }\n\n /** @internal */\n protected _checkFinalizedStatuses(subscriber: Subscriber) {\n const { hasError, thrownError, isStopped } = this;\n if (hasError) {\n subscriber.error(thrownError);\n } else if (isStopped) {\n subscriber.complete();\n }\n }\n\n /**\n * Creates a new Observable with this Subject as the source. You can do this\n * to create custom Observer-side logic of the Subject and conceal it from\n * code that uses the Observable.\n * @return {Observable} Observable that the Subject casts to\n */\n asObservable(): Observable {\n const observable: any = new Observable();\n observable.source = this;\n return observable;\n }\n}\n\n/**\n * @class AnonymousSubject\n */\nexport class AnonymousSubject extends Subject {\n constructor(\n /** @deprecated Internal implementation detail, do not use directly. Will be made internal in v8. */\n public destination?: Observer,\n source?: Observable\n ) {\n super();\n this.source = source;\n }\n\n next(value: T) {\n this.destination?.next?.(value);\n }\n\n error(err: any) {\n this.destination?.error?.(err);\n }\n\n complete() {\n this.destination?.complete?.();\n }\n\n /** @internal */\n protected _subscribe(subscriber: Subscriber): Subscription {\n return this.source?.subscribe(subscriber) ?? 
EMPTY_SUBSCRIPTION;\n }\n}\n", "import { Subject } from './Subject';\nimport { Subscriber } from './Subscriber';\nimport { Subscription } from './Subscription';\n\n/**\n * A variant of Subject that requires an initial value and emits its current\n * value whenever it is subscribed to.\n *\n * @class BehaviorSubject\n */\nexport class BehaviorSubject extends Subject {\n constructor(private _value: T) {\n super();\n }\n\n get value(): T {\n return this.getValue();\n }\n\n /** @internal */\n protected _subscribe(subscriber: Subscriber): Subscription {\n const subscription = super._subscribe(subscriber);\n !subscription.closed && subscriber.next(this._value);\n return subscription;\n }\n\n getValue(): T {\n const { hasError, thrownError, _value } = this;\n if (hasError) {\n throw thrownError;\n }\n this._throwIfClosed();\n return _value;\n }\n\n next(value: T): void {\n super.next((this._value = value));\n }\n}\n", "import { TimestampProvider } from '../types';\n\ninterface DateTimestampProvider extends TimestampProvider {\n delegate: TimestampProvider | undefined;\n}\n\nexport const dateTimestampProvider: DateTimestampProvider = {\n now() {\n // Use the variable rather than `this` so that the function can be called\n // without being bound to the provider.\n return (dateTimestampProvider.delegate || Date).now();\n },\n delegate: undefined,\n};\n", "import { Subject } from './Subject';\nimport { TimestampProvider } from './types';\nimport { Subscriber } from './Subscriber';\nimport { Subscription } from './Subscription';\nimport { dateTimestampProvider } from './scheduler/dateTimestampProvider';\n\n/**\n * A variant of {@link Subject} that \"replays\" old values to new subscribers by emitting them when they first subscribe.\n *\n * `ReplaySubject` has an internal buffer that will store a specified number of values that it has observed. Like `Subject`,\n * `ReplaySubject` \"observes\" values by having them passed to its `next` method. When it observes a value, it will store that\n * value for a time determined by the configuration of the `ReplaySubject`, as passed to its constructor.\n *\n * When a new subscriber subscribes to the `ReplaySubject` instance, it will synchronously emit all values in its buffer in\n * a First-In-First-Out (FIFO) manner. The `ReplaySubject` will also complete, if it has observed completion; and it will\n * error if it has observed an error.\n *\n * There are two main configuration items to be concerned with:\n *\n * 1. `bufferSize` - This will determine how many items are stored in the buffer, defaults to infinite.\n * 2. `windowTime` - The amount of time to hold a value in the buffer before removing it from the buffer.\n *\n * Both configurations may exist simultaneously. So if you would like to buffer a maximum of 3 values, as long as the values\n * are less than 2 seconds old, you could do so with a `new ReplaySubject(3, 2000)`.\n *\n * ### Differences with BehaviorSubject\n *\n * `BehaviorSubject` is similar to `new ReplaySubject(1)`, with a couple of exceptions:\n *\n * 1. `BehaviorSubject` comes \"primed\" with a single value upon construction.\n * 2. 
`ReplaySubject` will replay values, even after observing an error, where `BehaviorSubject` will not.\n *\n * @see {@link Subject}\n * @see {@link BehaviorSubject}\n * @see {@link shareReplay}\n */\nexport class ReplaySubject extends Subject {\n private _buffer: (T | number)[] = [];\n private _infiniteTimeWindow = true;\n\n /**\n * @param bufferSize The size of the buffer to replay on subscription\n * @param windowTime The amount of time the buffered items will stay buffered\n * @param timestampProvider An object with a `now()` method that provides the current timestamp. This is used to\n * calculate the amount of time something has been buffered.\n */\n constructor(\n private _bufferSize = Infinity,\n private _windowTime = Infinity,\n private _timestampProvider: TimestampProvider = dateTimestampProvider\n ) {\n super();\n this._infiniteTimeWindow = _windowTime === Infinity;\n this._bufferSize = Math.max(1, _bufferSize);\n this._windowTime = Math.max(1, _windowTime);\n }\n\n next(value: T): void {\n const { isStopped, _buffer, _infiniteTimeWindow, _timestampProvider, _windowTime } = this;\n if (!isStopped) {\n _buffer.push(value);\n !_infiniteTimeWindow && _buffer.push(_timestampProvider.now() + _windowTime);\n }\n this._trimBuffer();\n super.next(value);\n }\n\n /** @internal */\n protected _subscribe(subscriber: Subscriber): Subscription {\n this._throwIfClosed();\n this._trimBuffer();\n\n const subscription = this._innerSubscribe(subscriber);\n\n const { _infiniteTimeWindow, _buffer } = this;\n // We use a copy here, so reentrant code does not mutate our array while we're\n // emitting it to a new subscriber.\n const copy = _buffer.slice();\n for (let i = 0; i < copy.length && !subscriber.closed; i += _infiniteTimeWindow ? 1 : 2) {\n subscriber.next(copy[i] as T);\n }\n\n this._checkFinalizedStatuses(subscriber);\n\n return subscription;\n }\n\n private _trimBuffer() {\n const { _bufferSize, _timestampProvider, _buffer, _infiniteTimeWindow } = this;\n // If we don't have an infinite buffer size, and we're over the length,\n // use splice to truncate the old buffer values off. Note that we have to\n // double the size for instances where we're not using an infinite time window\n // because we're storing the values and the timestamps in the same array.\n const adjustedBufferSize = (_infiniteTimeWindow ? 1 : 2) * _bufferSize;\n _bufferSize < Infinity && adjustedBufferSize < _buffer.length && _buffer.splice(0, _buffer.length - adjustedBufferSize);\n\n // Now, if we're not in an infinite time window, remove all values where the time is\n // older than what is allowed.\n if (!_infiniteTimeWindow) {\n const now = _timestampProvider.now();\n let last = 0;\n // Search the array for the first timestamp that isn't expired and\n // truncate the buffer up to that point.\n for (let i = 1; i < _buffer.length && (_buffer[i] as number) <= now; i += 2) {\n last = i;\n }\n last && _buffer.splice(0, last + 1);\n }\n }\n}\n", "import { Scheduler } from '../Scheduler';\nimport { Subscription } from '../Subscription';\nimport { SchedulerAction } from '../types';\n\n/**\n * A unit of work to be executed in a `scheduler`. 
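A sketch of the `new ReplaySubject(3, 2000)` configuration mentioned above:

```ts
import { ReplaySubject } from 'rxjs';

// Keep at most 3 values, each for at most 2000 ms, and replay the buffer to
// late subscribers in FIFO order.
const subject = new ReplaySubject<number>(3, 2000);

subject.next(1);
subject.next(2);
subject.next(3);
subject.next(4); // 1 is trimmed out of the size-3 buffer

subject.subscribe((v) => console.log(v)); // logs 2, 3, 4
```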
An action is typically\n * created from within a {@link SchedulerLike} and an RxJS user does not need to concern\n * themselves about creating and manipulating an Action.\n *\n * ```ts\n * class Action extends Subscription {\n * new (scheduler: Scheduler, work: (state?: T) => void);\n * schedule(state?: T, delay: number = 0): Subscription;\n * }\n * ```\n *\n * @class Action\n */\nexport class Action extends Subscription {\n constructor(scheduler: Scheduler, work: (this: SchedulerAction, state?: T) => void) {\n super();\n }\n /**\n * Schedules this action on its parent {@link SchedulerLike} for execution. May be passed\n * some context object, `state`. May happen at some point in the future,\n * according to the `delay` parameter, if specified.\n * @param {T} [state] Some contextual data that the `work` function uses when\n * called by the Scheduler.\n * @param {number} [delay] Time to wait before executing the work, where the\n * time unit is implicit and defined by the Scheduler.\n * @return {void}\n */\n public schedule(state?: T, delay: number = 0): Subscription {\n return this;\n }\n}\n", "import type { TimerHandle } from './timerHandle';\ntype SetIntervalFunction = (handler: () => void, timeout?: number, ...args: any[]) => TimerHandle;\ntype ClearIntervalFunction = (handle: TimerHandle) => void;\n\ninterface IntervalProvider {\n setInterval: SetIntervalFunction;\n clearInterval: ClearIntervalFunction;\n delegate:\n | {\n setInterval: SetIntervalFunction;\n clearInterval: ClearIntervalFunction;\n }\n | undefined;\n}\n\nexport const intervalProvider: IntervalProvider = {\n // When accessing the delegate, use the variable rather than `this` so that\n // the functions can be called without being bound to the provider.\n setInterval(handler: () => void, timeout?: number, ...args) {\n const { delegate } = intervalProvider;\n if (delegate?.setInterval) {\n return delegate.setInterval(handler, timeout, ...args);\n }\n return setInterval(handler, timeout, ...args);\n },\n clearInterval(handle) {\n const { delegate } = intervalProvider;\n return (delegate?.clearInterval || clearInterval)(handle as any);\n },\n delegate: undefined,\n};\n", "import { Action } from './Action';\nimport { SchedulerAction } from '../types';\nimport { Subscription } from '../Subscription';\nimport { AsyncScheduler } from './AsyncScheduler';\nimport { intervalProvider } from './intervalProvider';\nimport { arrRemove } from '../util/arrRemove';\nimport { TimerHandle } from './timerHandle';\n\nexport class AsyncAction extends Action {\n public id: TimerHandle | undefined;\n public state?: T;\n // @ts-ignore: Property has no initializer and is not definitely assigned\n public delay: number;\n protected pending: boolean = false;\n\n constructor(protected scheduler: AsyncScheduler, protected work: (this: SchedulerAction, state?: T) => void) {\n super(scheduler, work);\n }\n\n public schedule(state?: T, delay: number = 0): Subscription {\n if (this.closed) {\n return this;\n }\n\n // Always replace the current state with the new state.\n this.state = state;\n\n const id = this.id;\n const scheduler = this.scheduler;\n\n //\n // Important implementation note:\n //\n // Actions only execute once by default, unless rescheduled from within the\n // scheduled callback. 
This allows us to implement single and repeat\n // actions via the same code path, without adding API surface area, as well\n // as mimic traditional recursion but across asynchronous boundaries.\n //\n // However, JS runtimes and timers distinguish between intervals achieved by\n // serial `setTimeout` calls vs. a single `setInterval` call. An interval of\n // serial `setTimeout` calls can be individually delayed, which delays\n // scheduling the next `setTimeout`, and so on. `setInterval` attempts to\n // guarantee the interval callback will be invoked more precisely to the\n // interval period, regardless of load.\n //\n // Therefore, we use `setInterval` to schedule single and repeat actions.\n // If the action reschedules itself with the same delay, the interval is not\n // canceled. If the action doesn't reschedule, or reschedules with a\n // different delay, the interval will be canceled after scheduled callback\n // execution.\n //\n if (id != null) {\n this.id = this.recycleAsyncId(scheduler, id, delay);\n }\n\n // Set the pending flag indicating that this action has been scheduled, or\n // has recursively rescheduled itself.\n this.pending = true;\n\n this.delay = delay;\n // If this action has already an async Id, don't request a new one.\n this.id = this.id ?? this.requestAsyncId(scheduler, this.id, delay);\n\n return this;\n }\n\n protected requestAsyncId(scheduler: AsyncScheduler, _id?: TimerHandle, delay: number = 0): TimerHandle {\n return intervalProvider.setInterval(scheduler.flush.bind(scheduler, this), delay);\n }\n\n protected recycleAsyncId(_scheduler: AsyncScheduler, id?: TimerHandle, delay: number | null = 0): TimerHandle | undefined {\n // If this action is rescheduled with the same delay time, don't clear the interval id.\n if (delay != null && this.delay === delay && this.pending === false) {\n return id;\n }\n // Otherwise, if the action's delay time is different from the current delay,\n // or the action has been rescheduled before it's executed, clear the interval id\n if (id != null) {\n intervalProvider.clearInterval(id);\n }\n\n return undefined;\n }\n\n /**\n * Immediately executes this action and the `work` it contains.\n * @return {any}\n */\n public execute(state: T, delay: number): any {\n if (this.closed) {\n return new Error('executing a cancelled action');\n }\n\n this.pending = false;\n const error = this._execute(state, delay);\n if (error) {\n return error;\n } else if (this.pending === false && this.id != null) {\n // Dequeue if the action didn't reschedule itself. Don't call\n // unsubscribe(), because the action could reschedule later.\n // For example:\n // ```\n // scheduler.schedule(function doWork(counter) {\n // /* ... I'm a busy worker bee ... */\n // var originalAction = this;\n // /* wait 100ms before rescheduling the action */\n // setTimeout(function () {\n // originalAction.schedule(counter + 1);\n // }, 100);\n // }, 1000);\n // ```\n this.id = this.recycleAsyncId(this.scheduler, this.id, null);\n }\n }\n\n protected _execute(state: T, _delay: number): any {\n let errored: boolean = false;\n let errorValue: any;\n try {\n this.work(state);\n } catch (e) {\n errored = true;\n // HACK: Since code elsewhere is relying on the \"truthiness\" of the\n // return here, we can't have it return \"\" or 0 or false.\n // TODO: Clean this up when we refactor schedulers mid-version-8 or so.\n errorValue = e ? 
e : new Error('Scheduled action threw falsy error');\n }\n if (errored) {\n this.unsubscribe();\n return errorValue;\n }\n }\n\n unsubscribe() {\n if (!this.closed) {\n const { id, scheduler } = this;\n const { actions } = scheduler;\n\n this.work = this.state = this.scheduler = null!;\n this.pending = false;\n\n arrRemove(actions, this);\n if (id != null) {\n this.id = this.recycleAsyncId(scheduler, id, null);\n }\n\n this.delay = null!;\n super.unsubscribe();\n }\n }\n}\n", "import { Action } from './scheduler/Action';\nimport { Subscription } from './Subscription';\nimport { SchedulerLike, SchedulerAction } from './types';\nimport { dateTimestampProvider } from './scheduler/dateTimestampProvider';\n\n/**\n * An execution context and a data structure to order tasks and schedule their\n * execution. Provides a notion of (potentially virtual) time, through the\n * `now()` getter method.\n *\n * Each unit of work in a Scheduler is called an `Action`.\n *\n * ```ts\n * class Scheduler {\n * now(): number;\n * schedule(work, delay?, state?): Subscription;\n * }\n * ```\n *\n * @class Scheduler\n * @deprecated Scheduler is an internal implementation detail of RxJS, and\n * should not be used directly. Rather, create your own class and implement\n * {@link SchedulerLike}. Will be made internal in v8.\n */\nexport class Scheduler implements SchedulerLike {\n public static now: () => number = dateTimestampProvider.now;\n\n constructor(private schedulerActionCtor: typeof Action, now: () => number = Scheduler.now) {\n this.now = now;\n }\n\n /**\n * A getter method that returns a number representing the current time\n * (at the time this function was called) according to the scheduler's own\n * internal clock.\n * @return {number} A number that represents the current time. May or may not\n * have a relation to wall-clock time. May or may not refer to a time unit\n * (e.g. milliseconds).\n */\n public now: () => number;\n\n /**\n * Schedules a function, `work`, for execution. May happen at some point in\n * the future, according to the `delay` parameter, if specified. 
May be passed\n * some context object, `state`, which will be passed to the `work` function.\n *\n * The given arguments will be processed an stored as an Action object in a\n * queue of actions.\n *\n * @param {function(state: ?T): ?Subscription} work A function representing a\n * task, or some unit of work to be executed by the Scheduler.\n * @param {number} [delay] Time to wait before executing the work, where the\n * time unit is implicit and defined by the Scheduler itself.\n * @param {T} [state] Some contextual data that the `work` function uses when\n * called by the Scheduler.\n * @return {Subscription} A subscription in order to be able to unsubscribe\n * the scheduled work.\n */\n public schedule(work: (this: SchedulerAction, state?: T) => void, delay: number = 0, state?: T): Subscription {\n return new this.schedulerActionCtor(this, work).schedule(state, delay);\n }\n}\n", "import { Scheduler } from '../Scheduler';\nimport { Action } from './Action';\nimport { AsyncAction } from './AsyncAction';\nimport { TimerHandle } from './timerHandle';\n\nexport class AsyncScheduler extends Scheduler {\n public actions: Array> = [];\n /**\n * A flag to indicate whether the Scheduler is currently executing a batch of\n * queued actions.\n * @type {boolean}\n * @internal\n */\n public _active: boolean = false;\n /**\n * An internal ID used to track the latest asynchronous task such as those\n * coming from `setTimeout`, `setInterval`, `requestAnimationFrame`, and\n * others.\n * @type {any}\n * @internal\n */\n public _scheduled: TimerHandle | undefined;\n\n constructor(SchedulerAction: typeof Action, now: () => number = Scheduler.now) {\n super(SchedulerAction, now);\n }\n\n public flush(action: AsyncAction): void {\n const { actions } = this;\n\n if (this._active) {\n actions.push(action);\n return;\n }\n\n let error: any;\n this._active = true;\n\n do {\n if ((error = action.execute(action.state, action.delay))) {\n break;\n }\n } while ((action = actions.shift()!)); // exhaust the scheduler queue\n\n this._active = false;\n\n if (error) {\n while ((action = actions.shift()!)) {\n action.unsubscribe();\n }\n throw error;\n }\n }\n}\n", "import { AsyncAction } from './AsyncAction';\nimport { AsyncScheduler } from './AsyncScheduler';\n\n/**\n *\n * Async Scheduler\n *\n * Schedule task as if you used setTimeout(task, duration)\n *\n * `async` scheduler schedules tasks asynchronously, by putting them on the JavaScript\n * event loop queue. 
It is best used to delay tasks in time or to schedule tasks repeating\n * in intervals.\n *\n * If you just want to \"defer\" task, that is to perform it right after currently\n * executing synchronous code ends (commonly achieved by `setTimeout(deferredTask, 0)`),\n * better choice will be the {@link asapScheduler} scheduler.\n *\n * ## Examples\n * Use async scheduler to delay task\n * ```ts\n * import { asyncScheduler } from 'rxjs';\n *\n * const task = () => console.log('it works!');\n *\n * asyncScheduler.schedule(task, 2000);\n *\n * // After 2 seconds logs:\n * // \"it works!\"\n * ```\n *\n * Use async scheduler to repeat task in intervals\n * ```ts\n * import { asyncScheduler } from 'rxjs';\n *\n * function task(state) {\n * console.log(state);\n * this.schedule(state + 1, 1000); // `this` references currently executing Action,\n * // which we reschedule with new state and delay\n * }\n *\n * asyncScheduler.schedule(task, 3000, 0);\n *\n * // Logs:\n * // 0 after 3s\n * // 1 after 4s\n * // 2 after 5s\n * // 3 after 6s\n * ```\n */\n\nexport const asyncScheduler = new AsyncScheduler(AsyncAction);\n\n/**\n * @deprecated Renamed to {@link asyncScheduler}. Will be removed in v8.\n */\nexport const async = asyncScheduler;\n", "import { AsyncAction } from './AsyncAction';\nimport { Subscription } from '../Subscription';\nimport { QueueScheduler } from './QueueScheduler';\nimport { SchedulerAction } from '../types';\nimport { TimerHandle } from './timerHandle';\n\nexport class QueueAction extends AsyncAction {\n constructor(protected scheduler: QueueScheduler, protected work: (this: SchedulerAction, state?: T) => void) {\n super(scheduler, work);\n }\n\n public schedule(state?: T, delay: number = 0): Subscription {\n if (delay > 0) {\n return super.schedule(state, delay);\n }\n this.delay = delay;\n this.state = state;\n this.scheduler.flush(this);\n return this;\n }\n\n public execute(state: T, delay: number): any {\n return delay > 0 || this.closed ? super.execute(state, delay) : this._execute(state, delay);\n }\n\n protected requestAsyncId(scheduler: QueueScheduler, id?: TimerHandle, delay: number = 0): TimerHandle {\n // If delay exists and is greater than 0, or if the delay is null (the\n // action wasn't rescheduled) but was originally scheduled as an async\n // action, then recycle as an async action.\n\n if ((delay != null && delay > 0) || (delay == null && this.delay > 0)) {\n return super.requestAsyncId(scheduler, id, delay);\n }\n\n // Otherwise flush the scheduler starting with this action.\n scheduler.flush(this);\n\n // HACK: In the past, this was returning `void`. However, `void` isn't a valid\n // `TimerHandle`, and generally the return value here isn't really used. So the\n // compromise is to return `0` which is both \"falsy\" and a valid `TimerHandle`,\n // as opposed to refactoring every other instanceo of `requestAsyncId`.\n return 0;\n }\n}\n", "import { AsyncScheduler } from './AsyncScheduler';\n\nexport class QueueScheduler extends AsyncScheduler {\n}\n", "import { QueueAction } from './QueueAction';\nimport { QueueScheduler } from './QueueScheduler';\n\n/**\n *\n * Queue Scheduler\n *\n * Put every next task on a queue, instead of executing it immediately\n *\n * `queue` scheduler, when used with delay, behaves the same as {@link asyncScheduler} scheduler.\n *\n * When used without delay, it schedules given task synchronously - executes it right when\n * it is scheduled. 
However when called recursively, that is when inside the scheduled task,\n * another task is scheduled with queue scheduler, instead of executing immediately as well,\n * that task will be put on a queue and wait for current one to finish.\n *\n * This means that when you execute task with `queue` scheduler, you are sure it will end\n * before any other task scheduled with that scheduler will start.\n *\n * ## Examples\n * Schedule recursively first, then do something\n * ```ts\n * import { queueScheduler } from 'rxjs';\n *\n * queueScheduler.schedule(() => {\n * queueScheduler.schedule(() => console.log('second')); // will not happen now, but will be put on a queue\n *\n * console.log('first');\n * });\n *\n * // Logs:\n * // \"first\"\n * // \"second\"\n * ```\n *\n * Reschedule itself recursively\n * ```ts\n * import { queueScheduler } from 'rxjs';\n *\n * queueScheduler.schedule(function(state) {\n * if (state !== 0) {\n * console.log('before', state);\n * this.schedule(state - 1); // `this` references currently executing Action,\n * // which we reschedule with new state\n * console.log('after', state);\n * }\n * }, 0, 3);\n *\n * // In scheduler that runs recursively, you would expect:\n * // \"before\", 3\n * // \"before\", 2\n * // \"before\", 1\n * // \"after\", 1\n * // \"after\", 2\n * // \"after\", 3\n *\n * // But with queue it logs:\n * // \"before\", 3\n * // \"after\", 3\n * // \"before\", 2\n * // \"after\", 2\n * // \"before\", 1\n * // \"after\", 1\n * ```\n */\n\nexport const queueScheduler = new QueueScheduler(QueueAction);\n\n/**\n * @deprecated Renamed to {@link queueScheduler}. Will be removed in v8.\n */\nexport const queue = queueScheduler;\n", "import { AsyncAction } from './AsyncAction';\nimport { AnimationFrameScheduler } from './AnimationFrameScheduler';\nimport { SchedulerAction } from '../types';\nimport { animationFrameProvider } from './animationFrameProvider';\nimport { TimerHandle } from './timerHandle';\n\nexport class AnimationFrameAction extends AsyncAction {\n constructor(protected scheduler: AnimationFrameScheduler, protected work: (this: SchedulerAction, state?: T) => void) {\n super(scheduler, work);\n }\n\n protected requestAsyncId(scheduler: AnimationFrameScheduler, id?: TimerHandle, delay: number = 0): TimerHandle {\n // If delay is greater than 0, request as an async action.\n if (delay !== null && delay > 0) {\n return super.requestAsyncId(scheduler, id, delay);\n }\n // Push the action to the end of the scheduler queue.\n scheduler.actions.push(this);\n // If an animation frame has already been requested, don't request another\n // one. If an animation frame hasn't been requested yet, request one. Return\n // the current animation frame request id.\n return scheduler._scheduled || (scheduler._scheduled = animationFrameProvider.requestAnimationFrame(() => scheduler.flush(undefined)));\n }\n\n protected recycleAsyncId(scheduler: AnimationFrameScheduler, id?: TimerHandle, delay: number = 0): TimerHandle | undefined {\n // If delay exists and is greater than 0, or if the delay is null (the\n // action wasn't rescheduled) but was originally scheduled as an async\n // action, then recycle as an async action.\n if (delay != null ? 
delay > 0 : this.delay > 0) {\n return super.recycleAsyncId(scheduler, id, delay);\n }\n // If the scheduler queue has no remaining actions with the same async id,\n // cancel the requested animation frame and set the scheduled flag to\n // undefined so the next AnimationFrameAction will request its own.\n const { actions } = scheduler;\n if (id != null && actions[actions.length - 1]?.id !== id) {\n animationFrameProvider.cancelAnimationFrame(id as number);\n scheduler._scheduled = undefined;\n }\n // Return undefined so the action knows to request a new async id if it's rescheduled.\n return undefined;\n }\n}\n", "import { AsyncAction } from './AsyncAction';\nimport { AsyncScheduler } from './AsyncScheduler';\n\nexport class AnimationFrameScheduler extends AsyncScheduler {\n public flush(action?: AsyncAction): void {\n this._active = true;\n // The async id that effects a call to flush is stored in _scheduled.\n // Before executing an action, it's necessary to check the action's async\n // id to determine whether it's supposed to be executed in the current\n // flush.\n // Previous implementations of this method used a count to determine this,\n // but that was unsound, as actions that are unsubscribed - i.e. cancelled -\n // are removed from the actions array and that can shift actions that are\n // scheduled to be executed in a subsequent flush into positions at which\n // they are executed within the current flush.\n const flushId = this._scheduled;\n this._scheduled = undefined;\n\n const { actions } = this;\n let error: any;\n action = action || actions.shift()!;\n\n do {\n if ((error = action.execute(action.state, action.delay))) {\n break;\n }\n } while ((action = actions[0]) && action.id === flushId && actions.shift());\n\n this._active = false;\n\n if (error) {\n while ((action = actions[0]) && action.id === flushId && actions.shift()) {\n action.unsubscribe();\n }\n throw error;\n }\n }\n}\n", "import { AnimationFrameAction } from './AnimationFrameAction';\nimport { AnimationFrameScheduler } from './AnimationFrameScheduler';\n\n/**\n *\n * Animation Frame Scheduler\n *\n * Perform task when `window.requestAnimationFrame` would fire\n *\n * When `animationFrame` scheduler is used with delay, it will fall back to {@link asyncScheduler} scheduler\n * behaviour.\n *\n * Without delay, `animationFrame` scheduler can be used to create smooth browser animations.\n * It makes sure scheduled task will happen just before next browser content repaint,\n * thus performing animations as efficiently as possible.\n *\n * ## Example\n * Schedule div height animation\n * ```ts\n * // html:
\n * import { animationFrameScheduler } from 'rxjs';\n *\n * const div = document.querySelector('div');\n *\n * animationFrameScheduler.schedule(function(height) {\n * div.style.height = height + \"px\";\n *\n * this.schedule(height + 1); // `this` references currently executing Action,\n * // which we reschedule with new state\n * }, 0, 0);\n *\n * // You will see a div element growing in height\n * ```\n */\n\nexport const animationFrameScheduler = new AnimationFrameScheduler(AnimationFrameAction);\n\n/**\n * @deprecated Renamed to {@link animationFrameScheduler}. Will be removed in v8.\n */\nexport const animationFrame = animationFrameScheduler;\n", "import { Observable } from '../Observable';\nimport { SchedulerLike } from '../types';\n\n/**\n * A simple Observable that emits no items to the Observer and immediately\n * emits a complete notification.\n *\n * Just emits 'complete', and nothing else.\n *\n * ![](empty.png)\n *\n * A simple Observable that only emits the complete notification. It can be used\n * for composing with other Observables, such as in a {@link mergeMap}.\n *\n * ## Examples\n *\n * Log complete notification\n *\n * ```ts\n * import { EMPTY } from 'rxjs';\n *\n * EMPTY.subscribe({\n * next: () => console.log('Next'),\n * complete: () => console.log('Complete!')\n * });\n *\n * // Outputs\n * // Complete!\n * ```\n *\n * Emit the number 7, then complete\n *\n * ```ts\n * import { EMPTY, startWith } from 'rxjs';\n *\n * const result = EMPTY.pipe(startWith(7));\n * result.subscribe(x => console.log(x));\n *\n * // Outputs\n * // 7\n * ```\n *\n * Map and flatten only odd numbers to the sequence `'a'`, `'b'`, `'c'`\n *\n * ```ts\n * import { interval, mergeMap, of, EMPTY } from 'rxjs';\n *\n * const interval$ = interval(1000);\n * const result = interval$.pipe(\n * mergeMap(x => x % 2 === 1 ? of('a', 'b', 'c') : EMPTY),\n * );\n * result.subscribe(x => console.log(x));\n *\n * // Results in the following to the console:\n * // x is equal to the count on the interval, e.g. (0, 1, 2, 3, ...)\n * // x will occur every 1000ms\n * // if x % 2 is equal to 1, print a, b, c (each on its own)\n * // if x % 2 is not equal to 1, nothing will be output\n * ```\n *\n * @see {@link Observable}\n * @see {@link NEVER}\n * @see {@link of}\n * @see {@link throwError}\n */\nexport const EMPTY = new Observable((subscriber) => subscriber.complete());\n\n/**\n * @param scheduler A {@link SchedulerLike} to use for scheduling\n * the emission of the complete notification.\n * @deprecated Replaced with the {@link EMPTY} constant or {@link scheduled} (e.g. `scheduled([], scheduler)`). Will be removed in v8.\n */\nexport function empty(scheduler?: SchedulerLike) {\n return scheduler ? emptyScheduled(scheduler) : EMPTY;\n}\n\nfunction emptyScheduled(scheduler: SchedulerLike) {\n return new Observable((subscriber) => scheduler.schedule(() => subscriber.complete()));\n}\n", "import { SchedulerLike } from '../types';\nimport { isFunction } from './isFunction';\n\nexport function isScheduler(value: any): value is SchedulerLike {\n return value && isFunction(value.schedule);\n}\n", "import { SchedulerLike } from '../types';\nimport { isFunction } from './isFunction';\nimport { isScheduler } from './isScheduler';\n\nfunction last(arr: T[]): T | undefined {\n return arr[arr.length - 1];\n}\n\nexport function popResultSelector(args: any[]): ((...args: unknown[]) => unknown) | undefined {\n return isFunction(last(args)) ? 
args.pop() : undefined;\n}\n\nexport function popScheduler(args: any[]): SchedulerLike | undefined {\n return isScheduler(last(args)) ? args.pop() : undefined;\n}\n\nexport function popNumber(args: any[], defaultValue: number): number {\n return typeof last(args) === 'number' ? args.pop()! : defaultValue;\n}\n", "export const isArrayLike = ((x: any): x is ArrayLike => x && typeof x.length === 'number' && typeof x !== 'function');", "import { isFunction } from \"./isFunction\";\n\n/**\n * Tests to see if the object is \"thennable\".\n * @param value the object to test\n */\nexport function isPromise(value: any): value is PromiseLike {\n return isFunction(value?.then);\n}\n", "import { InteropObservable } from '../types';\nimport { observable as Symbol_observable } from '../symbol/observable';\nimport { isFunction } from './isFunction';\n\n/** Identifies an input as being Observable (but not necessary an Rx Observable) */\nexport function isInteropObservable(input: any): input is InteropObservable {\n return isFunction(input[Symbol_observable]);\n}\n", "import { isFunction } from './isFunction';\n\nexport function isAsyncIterable(obj: any): obj is AsyncIterable {\n return Symbol.asyncIterator && isFunction(obj?.[Symbol.asyncIterator]);\n}\n", "/**\n * Creates the TypeError to throw if an invalid object is passed to `from` or `scheduled`.\n * @param input The object that was passed.\n */\nexport function createInvalidObservableTypeError(input: any) {\n // TODO: We should create error codes that can be looked up, so this can be less verbose.\n return new TypeError(\n `You provided ${\n input !== null && typeof input === 'object' ? 'an invalid object' : `'${input}'`\n } where a stream was expected. You can provide an Observable, Promise, ReadableStream, Array, AsyncIterable, or Iterable.`\n );\n}\n", "export function getSymbolIterator(): symbol {\n if (typeof Symbol !== 'function' || !Symbol.iterator) {\n return '@@iterator' as any;\n }\n\n return Symbol.iterator;\n}\n\nexport const iterator = getSymbolIterator();\n", "import { iterator as Symbol_iterator } from '../symbol/iterator';\nimport { isFunction } from './isFunction';\n\n/** Identifies an input as being an Iterable */\nexport function isIterable(input: any): input is Iterable {\n return isFunction(input?.[Symbol_iterator]);\n}\n", "import { ReadableStreamLike } from '../types';\nimport { isFunction } from './isFunction';\n\nexport async function* readableStreamLikeToAsyncGenerator(readableStream: ReadableStreamLike): AsyncGenerator {\n const reader = readableStream.getReader();\n try {\n while (true) {\n const { value, done } = await reader.read();\n if (done) {\n return;\n }\n yield value!;\n }\n } finally {\n reader.releaseLock();\n }\n}\n\nexport function isReadableStreamLike(obj: any): obj is ReadableStreamLike {\n // We don't want to use instanceof checks because they would return\n // false for instances from another Realm, like an + + +
+
+Computational Photography +

The Computational Photography topic deals with digital image capture based on optical hardware such as cameras. Common examples of emerging Computational Photography are smartphone applications such as shooting in the dark or capturing selfies. Today, we all use products of Computational Photography to capture glimpses of our daily lives and store them as memories.

+
    +
  • Noticeable profiles. Like in any field, there are notable people in this field whose profiles you may want to review to get a sense of who they are, what they achieved, or what they built for the development of Computational Photography. Here are some examples of such people; I would encourage you to explore their websites: Diego Gutierrez and Jinwei Gu.
  • +
  • Successful products. Here are a few examples of successful outcomes from the field of Computational Photography: Google's Night Sight and Samsung Camera modes.
  • +
  • Want to learn more? Although we will cover relevant information for Computational Photography in this course, you may want to dig deeper with a dedicated course, which you can follow online: +
  • +
+
+
+Computational Imaging and Sensing +

The Computational Imaging and Sensing topic deals with imaging and sensing certain qualities of a scene. Common examples of Computational Imaging and Sensing can be found in two other domains of Computational Light: Computational Astronomy and Computational Microscopy. Today, medical diagnoses of biological samples in hospitals, imaging stars and beyond, and sensing vital signals are all products of Computational Imaging and Sensing.

+
    +
  • Noticeable profiles. Like in any field, there are notable people in this field whose profiles you may want to review to get a sense of who they are, what they achieved, or what they built for the development of Computational Imaging and Sensing. Here are some examples of such people; I would encourage you to explore their websites: Laura Waller and Nick Antipa.
  • +
  • Successful products. Here are a few examples of successful outcomes from the field of Computational Imaging and Sensing: Zeiss Microscopes and Heart rate sensors on Apple's Smartwatch.
  • +
  • Did you know? The lecturer of the Computational Light Course, Kaan Akşit, is actively researching topics of Computational Imaging and Displays (e.g., Unrolled Primal-Dual Networks for Lensless Cameras 7).
  • +
  • Want to learn more? Although we will cover a great deal of Computational Imaging and Sensing in this course, you may want to dig deeper with a dedicated course, which you can follow online: +
  • +
+
+
+Computational Optics and Fabrication +

The Computational Optics and Fabrication topic deals with designing and fabricating optical components such as lenses, mirrors, diffraction gratings, holographic optical elements, and metasurfaces. + There is a little bit of Computational Optics and Fabrication in every sector of Computational Light, especially when there is a need for custom optical design.

+
    +
  • Noticeable profiles. Like in any field, there are notable people in this field whose profiles you may want to review to get a sense of who they are, what they achieved, or what they built for the development of Computational Optics and Fabrication. Here are some examples of such people; I would encourage you to explore their websites: Jannick Rolland and Mark Pauly.
  • +
  • Did you know? The lecturer of the Computational Light Course, Kaan Akşit, is actively researching topics of Computational Optics and Fabrication (e.g., Manufacturing application-driven foveated near-eye displays 8).
  • +
  • Want to learn more? Although we will cover a great deal of Computational Optics and Fabrication in this course, you may want to dig deeper with a dedicated course, which you can follow online:
  • +
+
+
+Optical Communication +

Optical Communication deals with using light as a medium for telecommunication applications. Common examples of Optical Communication are the fiber cables and the satellites equipped with optical links in space that run our Internet. In today's world, Optical Communication underpins our entire modern life by making the Internet a reality.

+
    +
  • Noticeable profiles. Like in any field, there are notable people in this field whose profiles you may want to review to get a sense of who they are, what they achieved, or what they built for the development of modern Optical Communication. Here are some people whose websites I would encourage you to explore: Harald Haas and Anna Maria Vegni.
  • +
  • Did you know? The lecturer of the Computational Light Course, Kaan Akşit, was researching topics of Optical Communication (e.g., From sound to sight: Using audio processing to enable visible light communication 9).
  • +
  • Want to learn more? Although we will cover relevant information for Optical Communication in this course, you may want to dig deeper and could start with this online video: +
  • +
+
+
+All-optical Machine Learning +

All-optical Machine Learning deals with building neural networks and computers that run solely on light. As this is an emerging field, there are not yet products in this field that we use in our daily lives. But this also means there are opportunities for newcomers and investors in this space.

+
    +
  • Noticeable profiles. Like in any field, there are notable people in this field whose profiles you may want to review to get a sense of who they are, what they achieved, or what they built for the development of All-optical Machine Learning. Here are some people whose websites I would encourage you to explore: Aydogan Ozcan and Ugur Tegin.
  • +
  • Want to learn more? Although we will cover a great deal of All-optical Machine Learning in this course, you may want to dig deeper with a dedicated course, which you can follow online: +
  • +
+
+
+Lab work: What are the other fields and interesting profiles out there? +

Please explore other fields relevant to Computational Light, and explore the interesting profiles out there. Make a list of relevant fields and interesting profiles, and report your top three.

+
+

Indeed, there are more topics related to computational light than the ones highlighted here. If you are up for a challenge in the next phase of your life, you could help the field identify new opportunities in light-based sciences. There are also more notable profiles, successful product examples, and dedicated courses for every one of these topics; the examples are not limited to the ones I have provided above. Your favorite search engine is your friend for finding out more.

+
+Lab work: Where do we find good resources? +

Please explore software projects on GitHub and papers on Google Scholar to find works that are relevant to the theme of Computational Light. Make a list of these projects and report the three that you find most exciting and interesting.

+
+
+

Reminder

+

We host a Slack group with more than 300 members. This Slack group focuses on the topics of rendering, perception, displays, and cameras. The group is open to the public, and you can become a member by following this link. Readers can get in touch with the wider community using this public group.

+
+
+
+
    +
  1. David Keun Cheng and others. Fundamentals of engineering electromagnetics. Addison-Wesley Reading, MA, 1993.
  2. Chandra Roychoudhuri, Al F Kracklauer, and Kathy Creath. The nature of light: What is a photon? CRC Press, 2017.
  3. Eugene Hecht. Optics. Pearson Education India, 2012.
  4. David R Walton, Rafael Kuffner Dos Anjos, Sebastian Friston, David Swapp, Kaan Akşit, Anthony Steed, and Tobias Ritschel. Beyond blur: real-time ventral metamers for foveated rendering. ACM Transactions on Graphics, 40(4):1–14, 2021.
  5. Kaan Akşit, Ward Lopes, Jonghyun Kim, Peter Shirley, and David Luebke. Near-eye varifocal augmented reality display using see-through screens. ACM Transactions on Graphics (TOG), 36(6):1–13, 2017.
  6. Koray Kavakli, David Robert Walton, Nick Antipa, Rafał Mantiuk, Douglas Lanman, and Kaan Akşit. Optimizing vision and visuals: lectures on cameras, displays and perception. In ACM SIGGRAPH 2022 Courses, pages 1–66. 2022.
  7. Oliver Kingshott, Nick Antipa, Emrah Bostan, and Kaan Akşit. Unrolled primal-dual networks for lensless cameras. Optics Express, 30(26):46324–46335, 2022.
  8. Kaan Akşit, Praneeth Chakravarthula, Kishore Rathinavel, Youngmo Jeong, Rachel Albert, Henry Fuchs, and David Luebke. Manufacturing application-driven foveated near-eye displays. IEEE Transactions on Visualization and Computer Graphics, 25(5):1928–1939, 2019.
  9. Stefan Schmid, Daniel Schwyn, Kaan Akşit, Giorgio Corbellini, Thomas R Gross, and Stefan Mangold. From sound to sight: using audio processing to enable visible light communication. In 2014 IEEE Globecom Workshops (GC Wkshps), 518–523. IEEE, 2014.
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/course/computer_generated_holography/index.html b/course/computer_generated_holography/index.html new file mode 100644 index 00000000..b5503d11 --- /dev/null +++ b/course/computer_generated_holography/index.html @@ -0,0 +1,4240 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Computer-Generated Holography - Odak + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + +
+ + + + + + +
+ + + + +
+ +
+ + + + +
+
+ + + +
+
+
+ + + + + +
+
+
+ + + + + + + +
+
+ + + + + + + +
+Narrate section +

+
+

Computer-Generated Holography

+

In this section, we introduce Computer-Generated Holography (CGH) 12 as another emerging method to simulate light. CGH offers a more capable but more computationally expensive way to simulate light than the raytracing method described in the previous section. This section dives deep into CGH and explains how CGH differs from raytracing as we go.

+

What is holography?

+

Informative

+

Holography is a method in the optical sciences for representing light distributions using the amplitude and phase of light. In much simpler terms, holography describes the light distribution emitted from an object, scene, or illumination source over a surface by treating the light as a wave. The primary difference between holography and raytracing is that holography accounts not only for the amplitude or intensity of light but also for its phase. Unlike classical raytracing, holography also includes diffraction and interference phenomena. In raytracing, the smallest building block that defines light is a ray, whereas in holography the building block is a light distribution over surfaces. In other words, while raytracing traces rays, holography deals with surface-to-surface light transfer.

+
+Did you know this source? +

There is an active repository on GitHub where the latest CGH papers relevant to display technologies are listed. Visit GitHub:bchao1/awesome-holography for more.

+
+

What is a hologram?

+

Informative

+

A hologram is either a surface or a volume that modifies the distribution of incoming light in terms of phase and amplitude. Diffraction gratings, Holographic Optical Elements, and Metasurfaces are good examples of holograms. Within this section, we also use the term hologram to describe a lightfield or a slice of a lightfield.

+

What is Computer-Generated Holography?

+

Informative

+

It is the computerized version (discrete sampling) of holography. In other words, whenever you can program the phase or amplitude of light, you arrive at Computer-Generated Holography.

+
+Where can I find an extensive summary on CGH? +

You may be wondering about the greater physical details of CGH. +In this case, we suggest our readers watch the video below. +Please watch this video for an extensive summary on CGH 3. +

+
+

Defining a slice of a lightfield

+

Informative · + Practical

+

CGH deals with generating optical fields that capture light from various scenes. +CGH often describes these optical fields (a.k.a. lightfields, holograms) as planes. +So in CGH, light travels from plane to plane, as depicted below. +Roughly, CGH deals with plane to plane interaction of light, whereas raytracing is a ray or beam oriented description of light.

+
+

Image title +

+
A rendering showing how a slice (a.k.a. lightfield, optical field, hologram) propagates from one plane to another plane.
+
+

In other words, in CGH, you define everything as a "lightfield," including light sources, materials, and objects. Thus, we must first determine how to describe such a lightfield in a computer so that we can run CGH simulations effectively.

+

A lightfield is a planar slice in the context of CGH, as depicted in the above figure. This planar field is a pixelated 2D surface (it could be represented as a matrix). Each pixel in this 2D slice holds values for the amplitude of light, \(A\), and the phase of light, \(\phi\). In classical raytracing, by contrast, a ray only holds the amplitude or intensity of light. With a caveat, though: raytracing could also be made to account for the phase of light. Still, it would then arrive with all the complications of raytracing, like sampling enough rays or describing scenes accurately.

+

Each pixel in this planar lightfield slice encapsulates \(A\) and \(\phi\) as \(A cos(wt + \phi)\). If you recall our description of light, we explained that light is an electromagnetic phenomenon. Here, we model the oscillating electric field of light with the \(A cos(wt + \phi)\) term shown in our previous light description. Note that if we stick to \(A cos(wt + \phi)\), each time two fields intersect, we have to deal with trigonometric conversion complexities, as sampled in this example:

+
\[ +A_0 cos(wt + \phi_0) + A_1 cos(wt + \phi_1), +\]
+

where the indices zero and one indicate the first and second fields, and we have to identify the right trigonometric conversion to deal with this sum.
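For reference, the conversion in question is the standard phasor-addition identity from trigonometry; we state it here only to show what the complex-number proxy below spares us from deriving each time:
\[
A_0 cos(wt + \phi_0) + A_1 cos(wt + \phi_1) = A cos(wt + \phi),
\]
where
\[
A = \sqrt{A_0^2 + A_1^2 + 2 A_0 A_1 cos(\phi_0 - \phi_1)} \quad \text{and} \quad tan(\phi) = \frac{A_0 sin(\phi_0) + A_1 sin(\phi_1)}{A_0 cos(\phi_0) + A_1 cos(\phi_1)}.
\]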

+

Instead of complicated trigonometric conversions, what people do in CGH is rely on complex numbers as a proxy for these trigonometric conversions. In its proxy form, a pixel value in a field is converted into \(A e^{-j \phi}\), where \(j\) represents the imaginary unit (\(\sqrt{-1}\)). Thus, with this new proxy representation, the same intersection problem we dealt with using sophisticated trigonometry before reduces to a simple addition, \(A_0 e^{-j\phi_0} + A_1 e^{-j\phi_1}\).
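Here is a minimal numerical sketch of this proxy representation in plain PyTorch (an illustration, not part of odak): two unit-amplitude fields are added as complex phasors, and the amplitude of the result doubles when their phases match and vanishes when they differ by \(\pi\):

import torch

A_0, phi_0 = 1., torch.tensor(0.)
A_1, phi_1 = 1., torch.tensor(0.) # set phi_1 = torch.tensor(torch.pi) for destructive interference
field_0 = A_0 * torch.exp(-1j * phi_0) # proxy form A e^{-j phi}
field_1 = A_1 * torch.exp(-1j * phi_1)
combined = field_0 + field_1 # interference is a plain addition in this representation
print(torch.abs(combined)) # 2.0 here; 0.0 when the phases differ by pi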

+

In the above summation of two fields, the resulting field follows the exact sum of the two collided fields. In raytracing, on the other hand, when a ray intersects another ray, it is often left unchanged and continues on its path. In the case of lightfields, however, the two fields form a new field. This feature is called interference of light, a feature that raytracing typically omits. As you can also tell from the summation, two fields could enhance the resulting field (constructive interference) by converging to a brighter intensity, or the two fields could cancel each other out (destructive interference) and lead to the absence of light --total darkness--.

+

There are various examples of interference in nature. For example, the blue color of a butterfly wing results from interference, as biology typically does not produce blue-colored pigments. More examples of light interference from daily life are provided in the figure below.

+
+

Image title +

+
Two photographs showing some examples of light interference: (left) a thin oil film creates rainbow interference patterns (CC BY-SA 2.5 by Wikipedia user John) and (right) a soap bubble interferes with light and creates vivid reflections (CC BY-SA 3.0 by Wikipedia user Brocken Inaglory).
+
+

We have established an easy way to describe a field with a proxy complex number form. +This way, we avoided complicated trigonometric conversions. +Let us look into how we use that in an actual simulation. +Firstly, we can define two separate matrices to represent a field using real numbers:

+
import torch
+
+amplitude = torch.zeros(100, 100, dtype = torch.float64)
+phase = torch.zeros(100, 100, dtype = torch.float64)
+
+

In the above example, we define two matrices with 100 x 100 dimensions. Each matrix holds real-valued floating point numbers. To convert the amplitude and phase into a field, we must define the field as suggested in our previous description. Instead of going through the same mathematical process in every piece of our future code, we can rely on a utility function in odak to create fields consistently and coherently across all our future developments. The utility function we will review is odak.learn.wave.generate_complex_field():

+


+
+
+
+ + +
+ + + + +
+ +

Definition to generate a complex field with a given amplitude and phase.

+ + +

Parameters:

  • amplitude – Amplitude of the field. The expected size is [m x n] or [1 x m x n].
  • phase – Phase of the field. The expected size is [m x n] or [1 x m x n].

Returns:

  • field (ndarray) – Complex field. Depending on the input, the expected size is [m x n] or [1 x m x n].
+ +
+ Source code in odak/learn/wave/util.py +
def generate_complex_field(amplitude, phase):
+    """
+    Definition to generate a complex field with a given amplitude and phase.
+
+    Parameters
+    ----------
+    amplitude         : torch.tensor
+                        Amplitude of the field.
+                        The expected size is [m x n] or [1 x m x n].
+    phase             : torch.tensor
+                        Phase of the field.
+                        The expected size is [m x n] or [1 x m x n].
+
+    Returns
+    -------
+    field             : ndarray
+                        Complex field.
+                        Depending on the input, the expected size is [m x n] or [1 x m x n].
+    """
+    field = amplitude * torch.cos(phase) + 1j * amplitude * torch.sin(phase)
+    return field
+
+
+
+ +
+
+
+

Let us use this utility function to expand our previous code snippet and show how we can generate a complex field using that:

+
import torch
+import odak # (1)
+
+amplitude = torch.zeros(100, 100, dtype = torch.float64)
+phase = torch.zeros(100, 100, dtype = torch.float64)
+field = odak.learn.wave.generate_complex_field(amplitude, phase) # (2)
+
+
    +
  1. Adding odak to our imports.
  2. Generating a field using odak.learn.wave.generate_complex_field.
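As a quick sanity check -- a short sketch rather than part of the course material -- you can recover the amplitude and phase from the generated field. odak.learn.wave.calculate_amplitude is also used later in this section, and torch.angle is plain PyTorch:

import torch
import odak

amplitude = torch.ones(100, 100, dtype = torch.float64)
phase = torch.rand(100, 100, dtype = torch.float64) * 2 * odak.pi
field = odak.learn.wave.generate_complex_field(amplitude, phase)
recovered_amplitude = odak.learn.wave.calculate_amplitude(field) # matches `amplitude`
recovered_phase = torch.angle(field) # matches `phase` up to 2 pi wrapping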
+

Propagating a field in free space

+

Informative · + Practical

+

The next question we have to ask is related to the field we generated in our previous example. +In raytracing, we propagate rays in space, whereas in CGH, we propagate a field described over a surface onto another target surface. +So we need a transfer function that projects our field on another target surface. +That is the point where free space beam propagation comes into play. +As the name implies, free space beam propagation deals with propagating light in free space from one surface to another. +This entire process of propagation is also referred to as light transport in the domains of Computer Graphics. +In the rest of this section, we will explore means to simulate beam propagation on a computer.

+
+A good news for Matlab fans! +

We will indeed use odak to explore beam propagation. +However, there is also a book in the literature, [Numerical simulation of optical wave propagation: With examples in MATLAB by Jason D. Schmidt](https://www.spiedigitallibrary.org/ebooks/PM/Numerical-Simulation-of-Optical-Wave-Propagation-with-Examples-in-MATLAB/eISBN-9780819483270/10.1117/3.866274?SSO=1)4, that provides a crash course on beam propagation using MATLAB.

+
+

As we revisit the field we generated in the previous subsection, we remember that our field is a pixelated 2D surface. Each pixel in our fields, either at a hologram or an image plane, typically has a small size of a few micrometers (e.g., \(8 \mu m\)). How light travels from each one of these pixels on one surface to pixels on another is conceptually depicted in the figure at the beginning of this section (the green wolf image with two planes). We will name that figure's first plane on the left the hologram plane and the second the image plane. In a nutshell, the contribution of a pixel on the hologram plane can be calculated by drawing rays to every pixel on the image plane. We draw rays from a point to a plane because in wave theory --which CGH follows--, light can diffract (a small aperture creating spherical waves, as Huygens suggested). Each ray travels a different distance, thus causing a different delay in phase, \(\phi\). As long as the distance between the planes is large enough, each ray will maintain an electric field that points in the same direction as the others (same polarization) and will thus be able to interfere with rays emerging from other pixels on the hologram plane. This description is, of course, a strong oversimplification of solving Maxwell's equations in electromagnetics.

+

A simplified result of solving Maxwell's equations is commonly described using Rayleigh-Sommerfeld diffraction integrals. For more on Rayleigh-Sommerfeld, consult Heurtley, J. C. (1973). Scalar Rayleigh–Sommerfeld and Kirchhoff diffraction integrals: a comparison of exact evaluations for axial points. JOSA, 63(8), 1003-1008. 5. The first solution of the Rayleigh-Sommerfeld integral, also known as the Huygens-Fresnel principle, is expressed as follows:

+
\[ +u(x,y)=\frac{1}{j\lambda} \int\!\!\!\!\int u_0(x,y)\frac{e^{jkr}}{r}cos(\theta)dxdy, +\]
+

where the field at a target image plane, \(u(x,y)\), is calculated by integrating over every point of the hologram's area, \(u_0(x,y)\). Note that, in the above equation, \(r\) represents the optical path between a selected point on the hologram and a selected point on the image plane, \(\theta\) represents the angle between these two points, \(k\) represents the wavenumber (\(\frac{2\pi}{\lambda}\)), and \(\lambda\) represents the wavelength of light. In this light transport model, the optical fields \(u_0(x,y)\) and \(u(x,y)\) are represented with complex values,

+
\[ +u_0(x,y)=A(x,y)e^{j\phi(x,y)}, +\]
+

where \(A\) represents the spatial distribution of amplitude and \(\phi\) represents the spatial distribution of phase across a hologram plane. +The described holographic light transport model is often simplified into a single convolution with a fixed spatially invariant complex kernel, \(h(x,y)\) 6.

+
\[ +u(x,y)=u_0(x,y) * h(x,y) =\mathcal{F}^{-1}(\mathcal{F}(u_0(x,y)) \mathcal{F}(h(x,y))). +\]
+
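In code, this simplification boils down to one element-wise multiplication in the Fourier domain. Below is a minimal sketch of the idea in plain PyTorch, assuming u0 and a spatial kernel h are already available as complex tensors of matching size; practical implementations, such as odak's propagate_beam reviewed later in this section, additionally handle zero padding and related bookkeeping:

import torch

def convolve_with_kernel(u0, h):
    # u(x, y) = F^-1( F(u0(x, y)) F(h(x, y)) ): a single multiplication in the
    # Fourier domain replaces the spatial-domain convolution with the kernel h.
    return torch.fft.ifft2(torch.fft.fft2(u0) * torch.fft.fft2(h))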

There are multiple variants of this simplified approach:

+ +

In many cases, people choose to use the most common form of \(h(x, y)\) described as

+
\[ +h(x,y)=\frac{e^{jkz}}{j\lambda z} e^{\frac{jk}{2z} (x^2+y^2)}, +\]
+

where \(z\) represents the distance between a hologram plane and a target image plane. Before we introduce how to use the existing beam propagation functions in our library, let us dive deep into writing a beam propagation code that follows the Rayleigh-Sommerfeld integral, also known as the Huygens-Fresnel principle. In the rest of this section, I will walk you through the code below:

+
+
+
+
import sys
+import odak # (1)
+import torch
+from tqdm import tqdm
+
+
+def main(): # (2)
+    length = [7e-6, 7e-6] # (3)
+    for fresnel_id, fresnel_number in enumerate(range(99)): # (4)
+        fresnel_number += 1
+        propagate(
+                  fresnel_number = fresnel_number,
+                  length = [length[0] + 1. / fresnel_number * 8e-6, length[1] + 1. / fresnel_number * 8e-6]
+                 )
+
+
+def propagate(
+              wavelength = 532e-9, # (6)
+              pixel_pitch = 3.74e-6, # (7)
+              length = [15e-6, 15e-6],
+              image_samples = [2, 2], # Replace it with 1000 by 1000 (8)
+              aperture_samples = [2, 2], # Replace it with 1000 by 1000 (9)
+              device = torch.device('cpu'),
+              output_directory = 'test_output', 
+              fresnel_number = 4,
+              save_flag = False
+             ): # (5)
+    distance = pixel_pitch ** 2 / wavelength / fresnel_number
+    distance = torch.as_tensor(distance, device = device)
+    k = odak.learn.wave.wavenumber(wavelength)
+    x = torch.linspace(- length[0] / 2, length[0] / 2, image_samples[0], device = device)
+    y = torch.linspace(- length[1] / 2, length[1] / 2, image_samples[1], device = device)
+    X, Y = torch.meshgrid(x, y, indexing = 'ij') # (10)
+    wxs = torch.linspace(- pixel_pitch / 2., pixel_pitch / 2., aperture_samples[0], device = device)
+    wys = torch.linspace(- pixel_pitch / 2., pixel_pitch / 2., aperture_samples[1], device = device) # (11)
+    h  = torch.zeros(image_samples[0], image_samples[1], dtype = torch.complex64, device = device)
+    for wx in tqdm(wxs):
+        for wy in wys:
+            h += huygens_fresnel_principle(wx, wy, X, Y, distance, k, wavelength) # (12)
+    h = h * pixel_pitch ** 2 / aperture_samples[0] / aperture_samples[1] # (13) 
+
+    if save_flag:
+        save_results(h, output_directory, fresnel_number, length, pixel_pitch, distance, image_samples, device) # (14)
+    return True
+
+
+def huygens_fresnel_principle(x, y, X, Y, z, k, wavelength): # (12)
+    r = torch.sqrt((X - x) ** 2 + (Y - y) ** 2 + z ** 2)
+    h = torch.exp(1j * k * r) * z / r ** 2 * (1. / (2 * odak.pi * r) + 1. / (1j * wavelength))
+    return h
+
+
+def save_results(h, output_directory, fresnel_number, length, pixel_pitch, distance, image_samples, device):
+    from matplotlib import pyplot as plt
+    odak.tools.check_directory(output_directory)
+    output_intensity = odak.learn.wave.calculate_amplitude(h) ** 2
+    odak.learn.tools.save_image(
+                                '{}/diffraction_output_intensity_fresnel_number_{:02d}.png'.format(output_directory, int(fresnel_number)),
+                                output_intensity,
+                                cmin = 0.,
+                                cmax = output_intensity.max()
+                               )
+    cross_section_1d = output_intensity[output_intensity.shape[0] // 2]
+    lengths = torch.linspace(- length[0] * 10 ** 6 / 2., length[0] * 10 ** 6 / 2., image_samples[0], device = device)
+    plt.figure()
+    plt.plot(lengths.detach().cpu().numpy(), cross_section_1d.detach().cpu().numpy())
+    plt.xlabel('length (um)')
+    plt.figtext(
+                0.14,
+                0.9, 
+                r'Fresnel Number: {:02d}, Pixel pitch: {:.2f} um, Distance: {:.2f} um'.format(fresnel_number, pixel_pitch * 10 ** 6, distance * 10 ** 6),
+                fontsize = 11
+               )
+    plt.savefig('{}/diffraction_1d_output_intensity_fresnel_number_{:02d}.png'.format(output_directory, int(fresnel_number)))
+    plt.cla()
+    plt.clf()
+    plt.close()
+
+
+if __name__ == '__main__':
+    sys.exit(main())
+
+
    +
  1. Importing relevant libraries.
  2. This is our main routine.
  3. Length of the final image plane along X and Y axes.
  4. The Fresnel number is an arbitrary number that helps us get a sense of whether the optical configuration could be considered to be in the Fresnel (near field) or Fraunhofer region.
  5. Propagating light with the given configuration.
  6. Wavelength of light.
  7. Square aperture length of a single pixel in the simulation. This is where light diffracts from.
  8. Number of pixels in the image plane along X and Y axes.
  9. Number of point light sources used to represent a single pixel's square aperture.
  10. Sample point locations along X and Y axes at the image plane.
  11. Sample point locations along X and Y axes at the aperture plane.
  12. For each virtual point light source defined inside the aperture, we simulate the light as a diverging point light source.
  13. Normalize with the number of samples (trapezoid integration).
  14. The rest of this code handles the logistics of saving images.
+
+
+
+

We start the implementation by importing the necessary libraries, such as odak and torch. The first function, def main, sets the length of our image plane, where we will observe the diffraction pattern. As we set the size of our image plane, we also set an arbitrary number called the Fresnel number,

+
\[ +n_F = \frac{w^2}{\lambda z}, +\]
+

where \(z\) is the propagation distance, \(w\) is the side length of an aperture diffracting light, such as a pixel's square aperture -- this is often the pixel pitch -- and \(\lambda\) is the wavelength of the light. This number helps us get an idea of whether the given optical configuration falls under a certain regime, like Fresnel or Fraunhofer. The Fresnel number also provides a practical way of comparing solutions: regardless of the optical configuration, results with the same Fresnel number follow a similar pattern, thus providing a way to verify your solutions. In the next step, we call the light propagation function, def propagate. At the beginning of this function, we set the optical configuration. For instance, we set pixel_pitch, the side length of the square aperture that the light will diffract from. Inside the def propagate function, we reset the distance such that it follows the input Fresnel number and wavelength. We define the locations of the samples across the X and Y axes that will represent the points to calculate on the image plane, x and y. Then, we define the locations of the samples across the X and Y axes that will represent the point light source locations inside the aperture, wxs and wys; in this case, the aperture is a square aperture that represents a single pixel, and its side length is provided by pixel_pitch. The nested for loop goes over wxs and wys. Each time we choose a point from the aperture, we propagate a spherical wave from that point using def huygens_fresnel_principle. Note that we accumulate the effect of each spherical wave in a variable called h. This is the diffraction pattern in complex form from our square aperture, and we also normalize it using pixel_pitch and aperture_samples. Here, we provide visual results from this piece of code below:

+
+

Image title
+

+
Saved 1D intensities on image plane representing diffraction patterns for various Fresnel numbers. These patterns are generated by using "test/test_diffraction_integral.py".
+
+
+

Image title
+

+
Saved 2D intensities on image plane representing diffraction patterns for various Fresnel numbers. These patterns are generated by using "test/test_diffraction_integral.py".
+
+
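As an aside, the propagation distance behind each of these Fresnel numbers follows directly from the definition above. A quick back-of-the-envelope check in plain Python, using the default values from the code:

wavelength = 532e-9 # wavelength of light in meters
pixel_pitch = 3.74e-6 # square aperture side length w in meters
fresnel_number = 4. # target Fresnel number
distance = pixel_pitch ** 2 / wavelength / fresnel_number # z = w^2 / (lambda n_F)
print('{:0.2f} um'.format(distance * 1e6)) # prints roughly 6.57 um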

Note that beam propagation can also be learned for physical setups to avoid imperfections in a setup and to improve the image quality at an image plane:

+ +

The above descriptions establish a mathematical understanding of beam propagation. +Let us examine the implementation of a beam propagation method called Bandlimited Angular Spectrum by reviewing these two utility functions from odak:

+
+
+
+ + +
+ + + + +
+ +

Helper function for odak.learn.wave.band_limited_angular_spectrum.

+ + +

Parameters:

  • nu – Resolution at X axis in pixels.
  • nv – Resolution at Y axis in pixels.
  • dx – Pixel pitch in meters.
  • wavelength – Wavelength in meters.
  • distance – Distance in meters.
  • device – Device, for more see torch.device().

Returns:

  • H (complex64) – Complex kernel in Fourier domain.
+ +
+ Source code in odak/learn/wave/classical.py +
def get_band_limited_angular_spectrum_kernel(
+                                             nu,
+                                             nv,
+                                             dx = 8e-6,
+                                             wavelength = 515e-9,
+                                             distance = 0.,
+                                             device = torch.device('cpu')
+                                            ):
+    """
+    Helper function for odak.learn.wave.band_limited_angular_spectrum.
+
+    Parameters
+    ----------
+    nu                 : int
+                         Resolution at X axis in pixels.
+    nv                 : int
+                         Resolution at Y axis in pixels.
+    dx                 : float
+                         Pixel pitch in meters.
+    wavelength         : float
+                         Wavelength in meters.
+    distance           : float
+                         Distance in meters.
+    device             : torch.device
+                         Device, for more see torch.device().
+
+
+    Returns
+    -------
+    H                  : torch.complex64
+                         Complex kernel in Fourier domain.
+    """
+    x = dx * float(nu)
+    y = dx * float(nv)
+    fx = torch.linspace(
+                        -1 / (2 * dx) + 0.5 / (2 * x),
+                         1 / (2 * dx) - 0.5 / (2 * x),
+                         nu,
+                         dtype = torch.float32,
+                         device = device
+                        )
+    fy = torch.linspace(
+                        -1 / (2 * dx) + 0.5 / (2 * y),
+                        1 / (2 * dx) - 0.5 / (2 * y),
+                        nv,
+                        dtype = torch.float32,
+                        device = device
+                       )
+    FY, FX = torch.meshgrid(fx, fy, indexing='ij')
+    HH_exp = 2 * torch.pi * torch.sqrt(1 / wavelength ** 2 - (FX ** 2 + FY ** 2))
+    distance = torch.tensor([distance], device = device)
+    H_exp = torch.mul(HH_exp, distance)
+    fx_max = 1 / torch.sqrt((2 * distance * (1 / x))**2 + 1) / wavelength
+    fy_max = 1 / torch.sqrt((2 * distance * (1 / y))**2 + 1) / wavelength
+    H_filter = ((torch.abs(FX) < fx_max) & (torch.abs(FY) < fy_max)).clone().detach()
+    H = generate_complex_field(H_filter, H_exp)
+    return H
+
+
+
+ +
+
+ + +
+ + + + +
+ +

A definition to calculate bandlimited angular spectrum based beam propagation. For more +Matsushima, Kyoji, and Tomoyoshi Shimobaba. "Band-limited angular spectrum method for numerical simulation of free-space propagation in far and near fields." Optics express 17.22 (2009): 19662-19673.

+ + +

Parameters:

  • field – A complex field. The expected size is [m x n].
  • k – Wave number of a wave, see odak.wave.wavenumber for more.
  • distance – Propagation distance.
  • dx – Size of one single pixel in the field grid (in meters).
  • wavelength – Wavelength of the electric field.
  • zero_padding – Zero pad in Fourier domain.
  • aperture – Fourier domain aperture (e.g., pinhole in a typical holographic display). The default is one, but an aperture could be as large as input field [m x n].

Returns:

  • result (complex) – Final complex field [m x n].
+ +
+ Source code in odak/learn/wave/classical.py +
def band_limited_angular_spectrum(
+                                  field,
+                                  k,
+                                  distance,
+                                  dx,
+                                  wavelength,
+                                  zero_padding = False,
+                                  aperture = 1.
+                                 ):
+    """
+    A definition to calculate bandlimited angular spectrum based beam propagation. For more 
+    `Matsushima, Kyoji, and Tomoyoshi Shimobaba. "Band-limited angular spectrum method for numerical simulation of free-space propagation in far and near fields." Optics express 17.22 (2009): 19662-19673`.
+
+    Parameters
+    ----------
+    field            : torch.complex
+                       A complex field.
+                       The expected size is [m x n].
+    k                : odak.wave.wavenumber
+                       Wave number of a wave, see odak.wave.wavenumber for more.
+    distance         : float
+                       Propagation distance.
+    dx               : float
+                       Size of one single pixel in the field grid (in meters).
+    wavelength       : float
+                       Wavelength of the electric field.
+    zero_padding     : bool
+                       Zero pad in Fourier domain.
+    aperture         : torch.tensor
+                       Fourier domain aperture (e.g., pinhole in a typical holographic display).
+                       The default is one, but an aperture could be as large as input field [m x n].
+
+
+    Returns
+    -------
+    result           : torch.complex
+                       Final complex field [m x n].
+    """
+    H = get_propagation_kernel(
+                               nu = field.shape[-2], 
+                               nv = field.shape[-1], 
+                               dx = dx, 
+                               wavelength = wavelength, 
+                               distance = distance, 
+                               propagation_type = 'Bandlimited Angular Spectrum',
+                               device = field.device
+                              )
+    result = custom(field, H, zero_padding = zero_padding, aperture = aperture)
+    return result
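Before moving on, here is a minimal usage sketch of this function on its own. It assumes that the function is exposed as odak.learn.wave.band_limited_angular_spectrum (as its helpers generate_complex_field and wavenumber are) and uses illustrative values for the wavelength, pixel pitch, and distance:

import torch
import odak

wavelength = 532e-9 # Green light in meters (illustrative).
pixel_pitch = 8e-6 # Pixel size in meters (illustrative).
distance = 1e-2 # Propagation distance in meters (illustrative).
k = odak.learn.wave.wavenumber(wavelength)
field = odak.learn.wave.generate_complex_field(
                                               torch.ones(500, 500),
                                               torch.zeros(500, 500)
                                              ) # A flat complex field as a toy input.
result = odak.learn.wave.band_limited_angular_spectrum(
                                                       field,
                                                       k,
                                                       distance,
                                                       pixel_pitch,
                                                       wavelength,
                                                       zero_padding = False,
                                                       aperture = 1.
                                                      )

In practice, most users call propagate_beam below, which wraps this function and its siblings behind a single propagation_type switch.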
+
+
+
+ +
+
+ + +
+ + + + +
+ +

Definitions for various beam propagation methods mostly in accordance with "Computational Fourier Optics" by David Voelz.

+ + +

Parameters:

+
    +
  • field – Complex field [m x n].
  • k – Wave number of a wave, see odak.wave.wavenumber for more.
  • distance – Propagation distance.
  • dx – Size of one single pixel in the field grid (in meters).
  • wavelength – Wavelength of the electric field.
  • propagation_type ( str, default: 'Bandlimited Angular Spectrum' ) – Type of the propagation. The options are Impulse Response Fresnel, Transfer Function Fresnel, Angular Spectrum, Bandlimited Angular Spectrum, Fraunhofer.
  • kernel – Custom complex kernel.
  • zero_padding – Zero pads the input field if the first item in the list is set to True. Zero pads in the Fourier domain if the second item in the list is set to True. Crops the result back to half of the padded resolution if the third item in the list is set to True. Note that in Fraunhofer propagation, setting the second item to True or False has no effect.
  • aperture – Aperture at the Fourier domain, default: [2m x 2n], otherwise depends on zero_padding. If provided as a floating point 1, there will be no aperture in the Fourier domain.
  • scale – Resolution factor to scale the generated kernel.
  • samples – When using Impulse Response Fresnel propagation, these sample counts along X and Y will be used to represent a rectangular aperture. The first two are for a hologram pixel and the last two are for an image plane pixel.
+ + +

Returns:

+
    +
  • result ( complex ) – Final complex field [m x n].
+ +
+ Source code in odak/learn/wave/classical.py +
def propagate_beam(
+                   field,
+                   k,
+                   distance,
+                   dx,
+                   wavelength,
+                   propagation_type='Bandlimited Angular Spectrum',
+                   kernel = None,
+                   zero_padding = [True, False, True],
+                   aperture = 1.,
+                   scale = 1,
+                   samples = [20, 20, 5, 5]
+                  ):
+    """
+    Definitions for various beam propagation methods mostly in accordance with "Computational Fourier Optics" by David Voelz.
+
+    Parameters
+    ----------
+    field            : torch.complex
+                       Complex field [m x n].
+    k                : odak.wave.wavenumber
+                       Wave number of a wave, see odak.wave.wavenumber for more.
+    distance         : float
+                       Propagation distance.
+    dx               : float
+                       Size of one single pixel in the field grid (in meters).
+    wavelength       : float
+                       Wavelength of the electric field.
+    propagation_type : str
+                       Type of the propagation.
+                       The options are Impulse Response Fresnel, Transfer Function Fresnel, Angular Spectrum, Bandlimited Angular Spectrum, Fraunhofer.
+    kernel           : torch.complex
+                       Custom complex kernel.
+    zero_padding     : list
+                       Zero padding the input field if the first item in the list set True.
+                       Zero padding in the Fourier domain if the second item in the list set to True.
+                       Cropping the result with half resolution if the third item in the list is set to true.
+                       Note that in Fraunhofer propagation, setting the second item True or False will have no effect.
+    aperture         : torch.tensor
+                       Aperture at Fourier domain default:[2m x 2n], otherwise depends on `zero_padding`.
+                       If provided as a floating point 1, there will be no aperture in Fourier domain.
+    scale            : int
+                       Resolution factor to scale generated kernel.
+    samples          : list
+                       When using `Impulse Response Fresnel` propagation, these sample counts along X and Y will be used to represent a rectangular aperture. First two is for a hologram pixel and second two is for an image plane pixel.
+
+    Returns
+    -------
+    result           : torch.complex
+                       Final complex field [m x n].
+    """
+    if zero_padding[0]:
+        field = zero_pad(field)
+    if propagation_type == 'Angular Spectrum':
+        result = angular_spectrum(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture)
+    elif propagation_type == 'Bandlimited Angular Spectrum':
+        result = band_limited_angular_spectrum(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture)
+    elif propagation_type == 'Impulse Response Fresnel':
+        result = impulse_response_fresnel(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture, scale = scale, samples = samples)
+    elif propagation_type == 'Seperable Impulse Response Fresnel':
+        result = seperable_impulse_response_fresnel(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture, scale = scale, samples = samples)
+    elif propagation_type == 'Transfer Function Fresnel':
+        result = transfer_function_fresnel(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture)
+    elif propagation_type == 'custom':
+        result = custom(field, kernel, zero_padding[1], aperture = aperture)
+    elif propagation_type == 'Fraunhofer':
+        result = fraunhofer(field, k, distance, dx, wavelength)
+    elif propagation_type == 'Incoherent Angular Spectrum':
+        result = incoherent_angular_spectrum(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture)
+    else:
+        logging.warning('Propagation type not recognized')
+        assert True == False
+    if zero_padding[2]:
+        result = crop_center(result)
+    return result
+
+
+
+ +
+
+ + +
+ + + + +
+ +

Definition for calculating the wavenumber of a plane wave.

+ + +

Parameters:

+
    +
  • wavelength – Wavelength of a wave in mm.
+ + +

Returns:

+
    +
  • k ( float ) – Wave number for a given wavelength.
+ +
+ Source code in odak/learn/wave/util.py +
def wavenumber(wavelength):
+    """
+    Definition for calculating the wavenumber of a plane wave.
+
+    Parameters
+    ----------
+    wavelength   : float
+                   Wavelength of a wave in mm.
+
+    Returns
+    -------
+    k            : float
+                   Wave number for a given wavelength.
+    """
+    k = 2 * np.pi / wavelength
+    return k
+
+
+
+ +
+
+
+
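As a quick sanity check (illustrative values only), a wavelength of 532 nm expressed in meters gives a wavenumber of roughly 1.181 x 10^7 radians per meter:

import odak

wavelength = 532e-9 # Green light in meters.
k = odak.learn.wave.wavenumber(wavelength) # k = 2 * pi / wavelength.
print(k) # Roughly 1.181e7 radians per meter.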

Let us see how we can use the given beam propagation function with an example:

+
+
+
+
import sys
+import os
+import odak
+import numpy as np
+import torch
+
+
+def test(output_directory = 'test_output'):
+    odak.tools.check_directory(output_directory)
+    wavelength = 532e-9 # (1)
+    pixel_pitch = 8e-6 # (2)
+    distance = 0.5e-2 # (3)
+    propagation_types = ['Angular Spectrum', 'Bandlimited Angular Spectrum', 'Transfer Function Fresnel'] # (4)
+    k = odak.learn.wave.wavenumber(wavelength) # (5)
+
+
+    amplitude = torch.zeros(500, 500)
+    amplitude[200:300, 200:300] = 1. # (6)
+    phase = torch.randn_like(amplitude) * 2 * odak.pi
+    hologram = odak.learn.wave.generate_complex_field(amplitude, phase) # (7)
+
+    for propagation_type in propagation_types:
+        image_plane = odak.learn.wave.propagate_beam(
+                                                     hologram,
+                                                     k,
+                                                     distance,
+                                                     pixel_pitch,
+                                                     wavelength,
+                                                     propagation_type,
+                                                     zero_padding = [True, False, True] # (8)
+                                                    ) # (9)
+
+        image_intensity = odak.learn.wave.calculate_amplitude(image_plane) ** 2 # (10)
+        hologram_intensity = amplitude ** 2
+
+        odak.learn.tools.save_image(
+                                    '{}/image_intensity_{}.png'.format(output_directory, propagation_type.replace(' ', '_')), 
+                                    image_intensity, 
+                                    cmin = 0., 
+                                    cmax = image_intensity.max()
+                                ) # (11)
+        odak.learn.tools.save_image(
+                                    '{}/hologram_intensity_{}.png'.format(output_directory, propagation_type.replace(' ', '_')), 
+                                    hologram_intensity, 
+                                    cmin = 0., 
+                                    cmax = 1.
+                                ) # (12)
+    assert True == True
+
+
+if __name__ == '__main__':
+    sys.exit(test()) 
+
+
    +
  1. Setting the wavelength of light in meters. We use 532 nm (green light) in this example.
  2. Setting the physical size of a single pixel in our simulation. We use \(8 \mu m\) pixel size (width and height are both \(8 \mu m\)).
  3. Setting the distance between the two planes, hologram and image plane. We set it as half a centimeter here.
  4. Setting the propagation types we want to compare: Angular Spectrum, Bandlimited Angular Spectrum, and Transfer Function Fresnel.
  5. Here, we calculate a value named wavenumber, which we introduced while we were talking about the beam propagation functions.
  6. Here, we assume that there is a rectangular light at the center of our hologram.
  7. Here, we generate the field by combining amplitude and phase.
  8. Here, we zeropad and crop our field before and after the beam propagation to make sure that there is no aliasing in our results (see Nyquist criterion).
  9. We propagate the beam using the values and field provided.
  10. We calculate the final intensity on our image plane. Remember that human eyes can see intensity but not amplitude or phase of light. Intensity of light is the square of its amplitude.
  11. We save the image plane intensity to an image file.
  12. For comparison, we also save the hologram intensity to an image file so that we can observe how our light transformed from one plane to another.
+
+
+
+

Let us also take a look at the saved images as a result of the above sample code:

+
+

Image title +

+
Saved intensities before (left) and after (right) beam propagation (hologram and image plane intensities). This result is generated using "test/test_learn_beam_propagation.py".
+
+
+Challenge: Light transport on Arbitrary Surfaces +

We know that we can propagate a hologram to an image plane at any distance. This propagation is a plane-to-plane interaction. However, there may be cases where a simulation involves finding the light distribution over an arbitrary surface. Conventionally, this could be achieved by propagating the hologram to multiple planes and sampling the result of each plane where it intersects that arbitrary surface (see the sketch below). We challenge our readers to code the mentioned baseline (multiple planes for arbitrary surfaces) and to develop a beam propagation that is less computationally expensive and works for arbitrary surfaces (e.g., tilted planes or arbitrary shapes). This development could either rely on classical approaches or involve learning-based methods. The resultant method could be part of the odak.learn.wave submodule as a new class, odak.learn.wave.propagate_arbitrary. In addition, a unit test test/test_learn_propagate_arbitrary.py has to adopt this new class. To add these to odak, you can rely on the pull request feature on GitHub. You can also create a new engineering note for arbitrary surfaces in docs/notes/beam_propagation_arbitrary_surfaces.md.
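As a hedged starting point, the sketch below implements the multi-plane baseline described above. It assumes a hypothetical per-pixel depth map named surface_depths (in meters, same resolution as the hologram) and reuses odak.learn.wave.propagate_beam; it propagates the hologram to a set of candidate planes and, for every pixel, keeps the reconstruction from the plane closest to the surface:

import torch
import odak


def propagate_to_surface(hologram, k, surface_depths, pixel_pitch, wavelength, number_of_planes = 10):
    # Candidate planes spanning the depth range of the arbitrary surface.
    candidate_distances = torch.linspace(
                                         float(surface_depths.min()),
                                         float(surface_depths.max()),
                                         number_of_planes
                                        )
    # For every pixel, remember which candidate plane is nearest to the surface.
    nearest_plane = torch.argmin(
                                 torch.abs(surface_depths.unsqueeze(0) - candidate_distances.view(-1, 1, 1)),
                                 dim = 0
                                )
    result = torch.zeros_like(hologram)
    for plane_id, distance in enumerate(candidate_distances):
        reconstruction = odak.learn.wave.propagate_beam(
                                                        hologram,
                                                        k,
                                                        float(distance),
                                                        pixel_pitch,
                                                        wavelength,
                                                        'Bandlimited Angular Spectrum',
                                                        zero_padding = [True, False, True]
                                                       )
        mask = (nearest_plane == plane_id)
        result[mask] = reconstruction[mask]
    return result

This brute-force baseline scales linearly with the number of planes, which is exactly why the challenge asks for something less computationally expensive.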

+
+

Optimizing holograms

+

Informative · + Practical

+

In the previous subsection, we propagated an input field (a.k.a. lightfield, hologram) to another plane called the image plane. We can store any scene or object as a field on such planes. Thus, we have learned that we can use a plane (hologram) to capture or display a slice of a lightfield for any given scene or object. After all this introduction, it is also safe to say that, regardless of hardware, holograms are a natural way to represent three-dimensional scenes, objects, and data!

+

Holograms come in many forms. We can broadly classify holograms as analog and digital. Analog holograms are physically tailored structures. They are typically a result of manufacturing engineered surfaces (micron or nanoscale structures). Some examples of analog holograms include diffractive optical elements [13], holographic optical elements [14], and metasurfaces [15]. Here, we show an example of an analog hologram that gives us a slice of a lightfield, and we can observe the scene this way from various perspectives:

+
+

Image title +

+
A video showing an analog hologram example from Zebra Imaging -- ZScape.
+
+

Digital holograms are dynamic and generated using programmable versions of analog holograms. Typically, the tiniest building block of a digital hologram is a pixel that manipulates either the phase or the amplitude of light. In our laboratory, we build holographic displays [16, 12], programmable devices to display holograms. The components used in such a display are illustrated in the following rendering and contain a Spatial Light Modulator (SLM) that could display programmable holograms. Note that the SLM in this specific hardware can only manipulate the phase of incoming light.

+
+

Image title +

+
A rendering showing a standard holographic display hardware.
+
+

We can display holograms that generate images to fill a three-dimensional volume using the above hardware. We know that they are three-dimensional from the fact that we can focus on different parts of the images by changing the focus of our camera (closely observe the camera's location in the above figure). Let us look into a sample result to see what these three-dimensional images look like as we focus on different scene parts.

+
+

Image title +

+
A series of photographs at various focuses capturing images from our computer-generated holograms.
+
+

Let us look into how we can optimize a hologram for our holographic display by visiting the below example:

+
+
+
+
import sys
+import odak
+import torch
+
+
+def test(output_directory = 'test_output'):
+    odak.tools.check_directory(output_directory)
+    device = torch.device('cpu') # (1)
+    target = odak.learn.tools.load_image('./test/data/usaf1951.png', normalizeby = 255., torch_style = True)[1] # (4)
+    hologram, reconstruction = odak.learn.wave.stochastic_gradient_descent(
+                                                                           target,
+                                                                           wavelength = 532e-9,
+                                                                           distance = 20e-2,
+                                                                           pixel_pitch = 8e-6,
+                                                                           propagation_type = 'Bandlimited Angular Spectrum',
+                                                                           n_iteration = 50,
+                                                                           learning_rate = 0.1
+                                                                          ) # (2)
+    odak.learn.tools.save_image(
+                                '{}/phase.png'.format(output_directory), 
+                                odak.learn.wave.calculate_phase(hologram) % (2 * odak.pi), 
+                                cmin = 0., 
+                                cmax = 2 * odak.pi
+                               ) # (3)
+    odak.learn.tools.save_image('{}/sgd_target.png'.format(output_directory), target, cmin = 0., cmax = 1.)
+    odak.learn.tools.save_image(
+                                '{}/sgd_reconstruction.png'.format(output_directory), 
+                                odak.learn.wave.calculate_amplitude(reconstruction) ** 2, 
+                                cmin = 0., 
+                                cmax = 1.
+                               )
+    assert True == True
+
+
+if __name__ == '__main__':
+    sys.exit(test())
+
+
    +
  1. Replace cpu with cuda if you have a NVIDIA GPU with enough memory, or an AMD GPU with enough memory and ROCm support.
  2. We will provide the details of this optimization function in the next part.
  3. Saving the phase-only hologram. Note that a phase-only hologram has values between zero and two pi.
  4. Loading an image from a file with 1920 by 1080 resolution and using the green channel.
+
+
+
+

The above sample optimization script uses a function called odak.learn.wave.stochastic_gradient_descent. +This function sits at the center of this optimization, and we have to understand what it entails by closely observing its inputs, outputs, and source code. +Let us review the function.

+
+
+
+ + +
+ + + + +
+ +

Definition to generate phase and reconstruction from target image via stochastic gradient descent.

+ + +

Parameters:

+
    +
  • target – Target field amplitude [m x n]. Keep the target values between zero and one.
  • wavelength – Wavelength of the light.
  • distance – Hologram plane distance wrt SLM plane.
  • pixel_pitch – SLM pixel pitch in meters.
  • propagation_type – Type of the propagation (see odak.learn.wave.propagate_beam()).
  • n_iteration – Number of iterations.
  • loss_function – If None, it is set to be the l2 loss.
  • learning_rate – Learning rate.
+ + +

Returns:

+
    +
  • hologram ( Tensor ) – Phase-only hologram as a torch array.
  • reconstruction_intensity ( Tensor ) – Reconstruction as a torch array.
+ +
+ Source code in odak/learn/wave/classical.py +
def stochastic_gradient_descent(target, wavelength, distance, pixel_pitch, propagation_type = 'Bandlimited Angular Spectrum', n_iteration = 100, loss_function = None, learning_rate = 0.1):
+    """
+    Definition to generate phase and reconstruction from target image via stochastic gradient descent.
+
+    Parameters
+    ----------
+    target                    : torch.Tensor
+                                Target field amplitude [m x n].
+                                Keep the target values between zero and one.
+    wavelength                : double
+                                Wavelength of the light.
+    distance                  : double
+                                Hologram plane distance wrt SLM plane.
+    pixel_pitch               : float
+                                SLM pixel pitch in meters.
+    propagation_type          : str
+                                Type of the propagation (see odak.learn.wave.propagate_beam()).
+    n_iteration               : int
+                                Number of iterations.
+    loss_function             : function
+                                If None, it is set to be the l2 loss.
+    learning_rate             : float
+                                Learning rate.
+
+    Returns
+    -------
+    hologram                  : torch.Tensor
+                                Phase only hologram as torch array
+
+    reconstruction_intensity  : torch.Tensor
+                                Reconstruction as torch array
+
+    """
+    phase = torch.randn_like(target, requires_grad = True)
+    k = wavenumber(wavelength)
+    optimizer = torch.optim.Adam([phase], lr = learning_rate)
+    if type(loss_function) == type(None):
+        loss_function = torch.nn.MSELoss()
+    t = tqdm(range(n_iteration), leave = False, dynamic_ncols = True)
+    for i in t:
+        optimizer.zero_grad()
+        hologram = generate_complex_field(1., phase)
+        reconstruction = propagate_beam(
+                                        hologram, 
+                                        k, 
+                                        distance, 
+                                        pixel_pitch, 
+                                        wavelength, 
+                                        propagation_type, 
+                                        zero_padding = [True, False, True]
+                                       )
+        reconstruction_intensity = calculate_amplitude(reconstruction) ** 2
+        loss = loss_function(reconstruction_intensity, target)
+        description = "Loss:{:.4f}".format(loss.item())
+        loss.backward(retain_graph = True)
+        optimizer.step()
+        t.set_description(description)
+    logging.warning(description)
+    torch.no_grad()
+    hologram = generate_complex_field(1., phase)
+    reconstruction = propagate_beam(
+                                    hologram, 
+                                    k, 
+                                    distance, 
+                                    pixel_pitch, 
+                                    wavelength, 
+                                    propagation_type, 
+                                    zero_padding = [True, False, True]
+                                   )
+    return hologram, reconstruction
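For instance, the loss_function argument lets you swap the default l2 loss for any callable that compares the reconstructed intensity against the target. A short sketch with illustrative values, reusing the same target image as the earlier example:

import torch
import odak

target = odak.learn.tools.load_image('./test/data/usaf1951.png', normalizeby = 255., torch_style = True)[1]
l1_loss = torch.nn.L1Loss() # Any callable taking (reconstruction_intensity, target) works here.
hologram, reconstruction = odak.learn.wave.stochastic_gradient_descent(
                                                                       target,
                                                                       wavelength = 532e-9,
                                                                       distance = 20e-2,
                                                                       pixel_pitch = 8e-6,
                                                                       propagation_type = 'Bandlimited Angular Spectrum',
                                                                       n_iteration = 50,
                                                                       loss_function = l1_loss,
                                                                       learning_rate = 0.1
                                                                      )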
+
+
+
+ +
+
+
+

Let us also examine the optimized hologram and the image that the hologram reconstructed at the image plane.

+
+

Image title +

+
Optimized phase-only hologram. Generated using "test/test_learn_wave_stochastic_gradient_descent.py".
+
+
+

Image title +

+
Optimized phase-only hologram reconstructed at the image plane, generated using "test/test_learn_wave_stochastic_gradient_descent.py".
+
+
+Challenge: Non-iterative Learned Hologram Calculation +

We provided an overview of optimizing holograms using iterative methods. Iterative methods are computationally expensive and unsuitable for real-time hologram generation. We challenge our readers to derive a learned hologram generation method for multiplane images (not single-plane like in our example). This development could either rely on classical convolutional neural networks or blend with the physical priors explained in this section. The resultant method could be part of the odak.learn.wave submodule as a new class, odak.learn.wave.learned_hologram. In addition, a unit test test/test_learn_hologram.py has to adopt this new class. To add these to odak, you can rely on the pull request feature on GitHub. You can also create a new engineering note for learned hologram generation in docs/notes/learned_hologram_generation.md.
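As one possible, loudly hypothetical starting point (this is a toy sketch, not the requested odak.learn.wave.learned_hologram class), a small convolutional network could map a stack of multiplane target intensities to a single phase map, which is then wrapped into a complex field using the functions introduced earlier. Training it would involve reconstructing that field at multiple depths with propagate_beam and penalizing the difference to the targets:

import torch
import odak


class toy_hologram_estimator(torch.nn.Module):
    def __init__(self, number_of_planes = 3):
        super().__init__()
        self.network = torch.nn.Sequential(
                                           torch.nn.Conv2d(number_of_planes, 32, kernel_size = 3, padding = 1),
                                           torch.nn.ReLU(),
                                           torch.nn.Conv2d(32, 32, kernel_size = 3, padding = 1),
                                           torch.nn.ReLU(),
                                           torch.nn.Conv2d(32, 1, kernel_size = 3, padding = 1)
                                          )

    def forward(self, targets): # targets: [1 x planes x m x n] multiplane intensities.
        phase = self.network(targets) * 2 * odak.pi # Predicted phase in radians.
        hologram = odak.learn.wave.generate_complex_field(torch.ones_like(phase), phase)
        return hologram


model = toy_hologram_estimator(number_of_planes = 3)
targets = torch.rand(1, 3, 100, 100) # Random stand-in for multiplane target intensities.
hologram = model(targets)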

+
+

Simulating a standard holographic display

+

Informative · + Practical

+

We optimized holograms for a holographic display in the previous section. However, the beam propagation distance we used in our optimization example was large. If we were to run the same optimization for a shorter propagation distance, say millimeters instead of centimeters, we would not get a decent solution. This is because, in an actual holographic display, there is an aperture that helps filter out some of the light. The previous section contained an optical layout rendering of a holographic display where this aperture is also depicted; it is located between a two-lens system, which is also known as a 4F imaging system.

+
+Did you know? +

A 4F imaging system can take the Fourier transform of an input field optically, using physics rather than computers. For more details, please review these course notes from MIT.

+
+
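To get a feel for what such an aperture does in simulation, one can pass a binary mask as the aperture argument of odak.learn.wave.propagate_beam introduced earlier; the mask multiplies the field in the Fourier domain and therefore behaves like a low-pass filter. A rough sketch with illustrative values (note that with zero padding enabled, the aperture is expected to match the padded size, here twice the field resolution, and the pinhole radius of 200 frequency bins is arbitrary):

import torch
import odak

wavelength = 532e-9
pixel_pitch = 8e-6
distance = 0.5e-2
k = odak.learn.wave.wavenumber(wavelength)
amplitude = torch.zeros(500, 500)
amplitude[200:300, 200:300] = 1.
hologram = odak.learn.wave.generate_complex_field(amplitude, torch.rand_like(amplitude) * 2 * odak.pi)

aperture = torch.zeros(1000, 1000) # Twice the field size because of the zero padding below.
fx = torch.arange(-500, 500).view(1, -1)
fy = torch.arange(-500, 500).view(-1, 1)
aperture[(fx ** 2 + fy ** 2) < 200 ** 2] = 1. # A circular pinhole centered in the Fourier domain.

image_plane = odak.learn.wave.propagate_beam(
                                             hologram,
                                             k,
                                             distance,
                                             pixel_pitch,
                                             wavelength,
                                             'Bandlimited Angular Spectrum',
                                             zero_padding = [True, False, True],
                                             aperture = aperture
                                            )

Shrinking the pinhole radius removes more of the high spatial frequencies and visibly blurs the reconstruction, which mimics the filtering role of the physical aperture in the 4F system.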

Let us review the class dedicated to accurately simulating a holographic display and its functions:

+
+
+
+ + +
+ + + + +
+ +

Internal function to reconstruct a given hologram.

+ + +

Parameters:

+
    +
  • hologram_phases – Hologram phases [ch x m x n].
  • amplitude – Amplitude profiles for each color primary [ch x m x n].
  • no_grad – If set True, uses torch.no_grad in reconstruction.
  • get_complex – If set True, reconstructor returns the complex field but not the intensities.
+ + +

Returns:

+
    +
  • reconstructions ( tensor ) – Reconstructed frames.
+ +
+ Source code in odak/learn/wave/propagators.py +
def reconstruct(self, hologram_phases, amplitude = None, no_grad = True, get_complex = False):
+    """
+    Internal function to reconstruct a given hologram.
+
+
+    Parameters
+    ----------
+    hologram_phases            : torch.tensor
+                                 Hologram phases [ch x m x n].
+    amplitude                  : torch.tensor
+                                 Amplitude profiles for each color primary [ch x m x n]
+    no_grad                    : bool
+                                 If set True, uses torch.no_grad in reconstruction.
+    get_complex                : bool
+                                 If set True, reconstructor returns the complex field but not the intensities.
+
+    Returns
+    -------
+    reconstructions            : torch.tensor
+                                 Reconstructed frames.
+    """
+    if no_grad:
+        torch.no_grad()
+    if len(hologram_phases.shape) > 3:
+        hologram_phases = hologram_phases.squeeze(0)
+    if get_complex == True:
+        reconstruction_type = torch.complex64
+    else:
+        reconstruction_type = torch.float32
+    reconstructions = torch.zeros(
+                                  self.number_of_frames,
+                                  self.number_of_depth_layers,
+                                  self.number_of_channels,
+                                  self.resolution[0] * self.resolution_factor,
+                                  self.resolution[1] * self.resolution_factor,
+                                  dtype = reconstruction_type,
+                                  device = self.device
+                                 )
+    if isinstance(amplitude, type(None)):
+        amplitude = torch.zeros(
+                                self.number_of_channels,
+                                self.resolution[0] * self.resolution_factor,
+                                self.resolution[1] * self.resolution_factor,
+                                device = self.device
+                               )
+        amplitude[:, ::self.resolution_factor, ::self.resolution_factor] = 1.
+    if self.resolution_factor != 1:
+        hologram_phases_scaled = torch.zeros_like(amplitude)
+        hologram_phases_scaled[
+                               :,
+                               ::self.resolution_factor,
+                               ::self.resolution_factor
+                              ] = hologram_phases
+    else:
+        hologram_phases_scaled = hologram_phases
+    for frame_id in range(self.number_of_frames):
+        for depth_id in range(self.number_of_depth_layers):
+            for channel_id in range(self.number_of_channels):
+                laser_power = self.get_laser_powers()[frame_id][channel_id]
+                phase = hologram_phases_scaled[frame_id]
+                hologram = generate_complex_field(
+                                                  laser_power * amplitude[channel_id],
+                                                  phase * self.phase_scale[channel_id]
+                                                 )
+                reconstruction_field = self.__call__(hologram, channel_id, depth_id)
+                if get_complex == True:
+                    result = reconstruction_field
+                else:
+                    result = calculate_amplitude(reconstruction_field) ** 2
+                reconstructions[
+                                frame_id,
+                                depth_id,
+                                channel_id
+                               ] = result.detach().clone()
+    return reconstructions
+
+
+
+ +
+
+ + +
+ + + + +
+ +

Function that represents the forward model in hologram optimization.

+ + +

Parameters:

+
    +
  • input_field – Input complex field.
  • channel_id – Identifying the color primary to be used.
  • depth_id – Identifying the depth layer to be used.
+ + +

Returns:

+
    +
  • output_field ( tensor ) – Propagated output complex field.
+ +
+ Source code in odak/learn/wave/propagators.py +
def __call__(self, input_field, channel_id, depth_id):
+    """
+    Function that represents the forward model in hologram optimization.
+
+    Parameters
+    ----------
+    input_field         : torch.tensor
+                          Input complex field.
+    channel_id          : int
+                          Identifying the color primary to be used.
+    depth_id            : int
+                          Identifying the depth layer to be used.
+
+    Returns
+    -------
+    output_field        : torch.tensor
+                          Propagated output complex field.
+    """
+    distance = self.distances[depth_id]
+    if not self.generated_kernels[depth_id, channel_id]:
+        if self.propagator_type == 'forward':
+            H = get_propagation_kernel(
+                                       nu = self.resolution[0] * 2,
+                                       nv = self.resolution[1] * 2,
+                                       dx = self.pixel_pitch,
+                                       wavelength = self.wavelengths[channel_id],
+                                       distance = distance,
+                                       device = self.device,
+                                       propagation_type = self.propagation_type,
+                                       samples = self.aperture_samples,
+                                       scale = self.resolution_factor
+                                      )
+        elif self.propagator_type == 'back and forth':
+            H_forward = get_propagation_kernel(
+                                               nu = self.resolution[0] * 2,
+                                               nv = self.resolution[1] * 2,
+                                               dx = self.pixel_pitch,
+                                               wavelength = self.wavelengths[channel_id],
+                                               distance = self.zero_mode_distance,
+                                               device = self.device,
+                                               propagation_type = self.propagation_type,
+                                               samples = self.aperture_samples,
+                                               scale = self.resolution_factor
+                                              )
+            distance_back = -(self.zero_mode_distance + self.image_location_offset - distance)
+            H_back = get_propagation_kernel(
+                                            nu = self.resolution[0] * 2,
+                                            nv = self.resolution[1] * 2,
+                                            dx = self.pixel_pitch,
+                                            wavelength = self.wavelengths[channel_id],
+                                            distance = distance_back,
+                                            device = self.device,
+                                            propagation_type = self.propagation_type,
+                                            samples = self.aperture_samples,
+                                            scale = self.resolution_factor
+                                           )
+            H = H_forward * H_back
+        self.kernels[depth_id, channel_id] = H
+        self.generated_kernels[depth_id, channel_id] = True
+    else:
+        H = self.kernels[depth_id, channel_id].detach().clone()
+    field_scale = input_field
+    field_scale_padded = zero_pad(field_scale)
+    output_field_padded = custom(field_scale_padded, H, aperture = self.aperture)
+    output_field = crop_center(output_field_padded)
+    return output_field
+
+
+
+ +
+
+
+

This sample unit test provides an example use case of the holographic display class.

+
+
+
+

+
+
+
+
+

Let us also examine how the reconstructed images look at the image plane.

+
+

Image title +

+
Reconstructed phase-only hologram at two image planes, generated using "test/test_learn_wave_holographic_display.py".
+
+

You may also be curious about how these holograms would look in an actual holographic display. Here, we provide a sample gallery filled with photographs captured from our holographic display:

+
+

+ + + + + + + +

+
Photographs of holograms captured using the holographic display in Computational Light Laboratory
+
+

Conclusion

+

Informative

+

Holography offers new frontiers as an emerging method for simulating light in various applications, including displays and cameras. We provide a basic introduction to Computer-Generated Holography and a simple understanding of holographic methods. A motivated reader could scale up from this knowledge to advanced concepts in displays, cameras, visual perception, optical computing, and many other light-based applications.

+
+

Reminder

+

We host a Slack group with more than 300 members. This Slack group focuses on the topics of rendering, perception, displays, and cameras. The group is open to the public, and you can become a member by following this link. Readers can get in touch with the wider community using this public group.

+
+
+
+
    +
  1. Max Born and Emil Wolf. Principles of optics: electromagnetic theory of propagation, interference and diffraction of light. Elsevier, 2013.
  2. Joseph W. Goodman. Introduction to Fourier optics. Roberts and Company Publishers, 2005.
  3. Koray Kavaklı, David Robert Walton, Nick Antipa, Rafał Mantiuk, Douglas Lanman, and Kaan Akşit. Optimizing vision and visuals: lectures on cameras, displays and perception. In ACM SIGGRAPH 2022 Courses, pages 1–66. 2022.
  4. Jason D. Schmidt. Numerical simulation of optical wave propagation with examples in MATLAB. SPIE Press, 2010.
  5. John C. Heurtley. Scalar Rayleigh–Sommerfeld and Kirchhoff diffraction integrals: a comparison of exact evaluations for axial points. JOSA, 63(8):1003–1008, 1973.
  6. Maciej Sypek. Light propagation in the Fresnel region. New numerical approach. Optics Communications, 116(1-3):43–48, 1995.
  7. Kyoji Matsushima and Tomoyoshi Shimobaba. Band-limited angular spectrum method for numerical simulation of free-space propagation in far and near fields. Optics Express, 17(22):19662–19673, 2009.
  8. Wenhui Zhang, Hao Zhang, and Guofan Jin. Band-extended angular spectrum method for accurate diffraction calculation in a wide propagation range. Optics Letters, 45(6):1543–1546, 2020.
  9. Wenhui Zhang, Hao Zhang, and Guofan Jin. Adaptive-sampling angular spectrum method with full utilization of space-bandwidth product. Optics Letters, 45(16):4416–4419, 2020.
  10. Yifan Peng, Suyeon Choi, Nitish Padmanaban, and Gordon Wetzstein. Neural holography with camera-in-the-loop training. ACM Transactions on Graphics (TOG), 39(6):1–14, 2020.
  11. Praneeth Chakravarthula, Ethan Tseng, Tarun Srivastava, Henry Fuchs, and Felix Heide. Learned hardware-in-the-loop phase retrieval for holographic near-eye displays. ACM Transactions on Graphics (TOG), 39(6):1–18, 2020.
  12. Koray Kavaklı, Hakan Urey, and Kaan Akşit. Learned holographic light transport. Applied Optics, 61(5):B50–B55, 2022.
  13. Gary J. Swanson. Binary optics technology: the theory and design of multi-level diffractive optical elements. Technical Report, MIT Lincoln Laboratory, 1989.
  14. Herwig Kogelnik. Coupled wave theory for thick hologram gratings. Bell System Technical Journal, 48(9):2909–2947, 1969.
  15. Lingling Huang, Shuang Zhang, and Thomas Zentgraf. Metasurface holography: from fundamentals to applications. Nanophotonics, 7(6):1169–1190, 2018.
  16. Koray Kavaklı, Yuta Itoh, Hakan Urey, and Kaan Akşit. Realistic defocus blur for multiplane computer-generated holography. In 2023 IEEE Conference on Virtual Reality and 3D User Interfaces (VR), pages 418–426. IEEE, 2023.
+
\ No newline at end of file
diff --git a/course/fundamentals/index.html b/course/fundamentals/index.html
new file mode 100644
index 00000000..7b4b5228
--- /dev/null
+++ b/course/fundamentals/index.html
@@ -0,0 +1,3652 @@
Fundamentals in optimizing and learning light - Odak

Fundamentals and Standards

+

This chapter will reveal some important basic information that you will use in the rest of this course. In addition, we will also introduce you to a structure where we establish some standards to decrease the chances of producing buggy or incompatible code.

+

Required Production Environment

+

Informative · + Practical

+

We have provided some information in prerequisites. +This information includes programming language requirements, required libraries, text editors, build environments, and operating system requirements. +For installing our library, odak, we strongly advise using the version in the source repository. +You can install odak from the source repository using your favorite terminal and operating system:

+
pip3 install git+https://github.com/kaanaksit/odak
+
+

Note that your production environment, meaning your computer and the required software for this course, is important. To avoid wasting time in the next chapters and to get the most from this lecture, please ensure that you have dedicated enough time to set everything up properly.

+

Production Standards

+

Informative

+

In this course, you will be asked to code and implement simulations related to the physics of light. +Your work, meaning your production, should strictly follow certain habits to help build better tools and developments.

+

Subversion and Revision Control

+

Informative · + Practical

+

As you develop your code for your future homework and projects, you will discover that many things could go wrong. For example, the hard drive that contains the only copy of your code could be damaged, or your so-called most trusted friend could claim that she did most of the work and take the credit for it, although that is not the case. These are just a few potential cases that may happen to you. On the other hand, in business life, poor code control can cause companies to lose money by releasing incorrect code, or cause researchers to lose their reputations when their work is difficult to replicate. How do you claim in that case that you did your part? What is the proper method to avoid losing data, time, effort, and motivation? In short, what is the way to stay out of trouble?

+

This is where subversion, authoring, and revision control systems come into play, especially for the example cases discussed in the previous paragraph. In today's world, Git is a widespread version control system adopted by major websites such as GitHub or GitLab. We will not dive deep into how to use Git and all its features, but I will try to highlight the parts that are essential for your workflow. I encourage you to use Git to create a repository for every one of your tasks in the future. You can either keep this repository on your local machine and constantly back it up somewhere else (suggested for people who know what they are doing) or use online services such as GitHub or GitLab. I encourage you to use the online services if you are a beginner.

+

Each operating system has its own process for installing Git, but on an Ubuntu operating system, it is as easy as typing the following command in your terminal:

+
sudo apt install git
+
+

Let us imagine that you want to start a repository on GitHub. +Make sure to create a private repository, and please only go public with any repository once you feel it is at a state where it can be shared with others. +Once you have created your repository on GitHub, you can clone the repository using the following command in a terminal:

+
git clone REPLACEWITHLOCATIONOFREPO
+
+

You can find out about the repository's location by visiting the repository's website that you have created. +The location is typically revealed by clicking the code button, as depicted in the below screenshot.

+
+

Image title +

+
A screenshot showing how you can acquire the link for cloning a repository from GitHub.
+
+

For example, in the above case, the command should be updated with the following:

+
git clone https://github.com/kaanaksit/odak.git
+
+

If you want to share your private repository with someone, you can go into the settings of your repository on its webpage and navigate to the collaborators section. This way, you can assign roles to your collaborators that best suit your scenario.

+
+

Secure your account

+

If you are using GitHub for your development, I highly encourage you to consider using two-factor authentication.

+
+

Git Basics

+

Informative · + Practical

+

If you want to add new files to your subversion control system, use the following in a terminal:

+
git add YOURFILE.jpeg
+
+

You may want to track the status of the files (whether they are added, deleted, etc.):

git status
+
+And later, you can update the online copy (remote server or source) using the following:

+
git commit -am "Explain what you add in a short comment."
+git push
+
+

In some cases, you may want to include large binary files in your project, such as a paper, video, or any other media you want to archive within your project repository. For those cases, using just Git may not be the best option. Git works by creating a history of your files and how they changed at each commit, so with large binaries this history will likely become bulky and oversized. Thus, cloning a repository could be slow when large binary files and Git come together. Assuming you are on an Ubuntu operating system, you can install the Large File Storage (LFS) support for Git by typing this command in your terminal:

+
sudo apt install git-lfs
+
+

Once you have the LFS installed in your operating system, you can then go into your repository and enable LFS:

+
cd YOURREPOSITORY
+git lfs install
+
+

Now is the time to let your LFS track specific files to avoid overcrowding your Git history. +For example, you can track the *.pdf extension, meaning all the PDF files in your repository by typing the following command in your terminal:

+
git lfs track *.pdf
+
+

Finally, ensure the tracking information and LFS are copied to your remote/source repository. +You can do that using the following commands in your terminal:

+
git add .gitattributes
+git commit -am "Enabling large file support."
+git push
+
+

When projects expand in size, it's quite feasible for hundreds of individuals to collaborate within the same repository. +This is particularly prevalent in sizable software development initiatives or open-source projects with a substantial contributor base. +The branching system is frequently employed in these circumstances.

+

Suppose you are on a software development team and you want to introduce new features or changes to a project without affecting the main or "master" branch. You first need to create a new branch using the following command, which creates a new branch named BRANCHNAME but does not switch to it. This new branch has the same contents as the current branch (a copy of the current branch).

+
git branch BRANCHNAME
+
+

Then you can switch to the new branch by using the command:

+
git checkout BRANCHNAME
+
+

Or use this command to create and switch to a new branch immediately:

+
git checkout -b BRANCHNAME
+
+

After editing the new branch, you may want to merge the changes back into the master or main branch. The following command merges the branch named BRANCHNAME into the current branch. You must resolve any conflicts to complete the merge.

+
git merge BRANCHNAME
+
+

We recommend an interactive, visual method for learning Git commands and branching online: learngitbranching. More information can be found in the official Git documentation: Git docs.

+

Coding Standards

+

Informative · + Practical

+

I encourage our readers to follow the methods of coding highlighted here. +Following the methods that I am going to explain is not only crucial for developing replicable projects, but it is also vital for allowing other people to read your code with the least amount of hassle.

+
+Where do I find out more about Python coding standards? +

Python Enhancement Proposals documentation provides a great deal of information on modern ways to code in Python.

+
+

Avoid using long lines.

+

Please avoid having too many characters in one line. +Let us start with a bad example:

+
def light_transport(wavelength, distances, resolution, propagation_type, polarization, input_field, output_field, angles):
+      pass
+      return results
+
+

As you can observe, the above function requires multiple inputs to be provided. Try making the inputs more readable by breaking them across lines; in some cases, you can also provide the requested type for an input and a default value to guide your users:

+
def light_transport(
+                    wavelength,
+                    distances,
+                    resolution,
+                    propagation_type : str, 
+                    polarization = 'vertical',
+                    input_field = torch.rand(1, 1, 100, 100),
+                    output_field = torch.zeros(1, 1, 100, 100),
+                    angles = [0., 0., 0.]
+                   ):
+    pass
+    return results
+
+

Leave spaces between commands, variables, and functions

+

Please avoid writing code like a train of characters. +Here is a terrible coding example:

+
def addition(x,y,z):
+    result=2*y+z+x**2*3
+    return result
+
+

Please leave spaces after each comma (,) and around each mathematical operator. So now, we can correct the above example as below:

+
def addition(x, y, z):
+    result = 2 * y + z + x ** 2 * 3
+    return result
+
+

Please also leave two lines of space between the two functions. +Here is a bad example again:

+
def add(x, y):
+    return x + y
+def multiply(x, y):
+    return x * y
+
+

Instead, it should be:

+
def add(x, y):
+    return x + y
+
+
+def multiply(x, y):
+    return x * y
+
+

Add documentation

+

For your code, please make sure to add the necessary documentation. +Here is a good example of doing that:

+
def add(x, y):
+    """
+    A function to add two values together.
+
+    Parameters
+    ==========
+    x         : float
+                First input value.
+    y         : float
+                Second input value.
+
+    Returns
+    =======
+    result    : float
+                Result of the addition.
+    """
+    result = x + y
+    return result
+
+

Use a code-style checker and validator

+

There are also code-style checkers and code validators that you can adapt to your workflows when coding. +One of these code-style checkers and validators I use in my projects is pyflakes. +On an Ubuntu operating system, you can install pyflakes easily by typing these commands into your terminal:

+
sudo apt install python3-pyflakes
+
+

It could tell you about missing imports or undefined or unused variables. +You can use it on any Python script very easily:

+
pyflakes3 sample.py
+
+

In addition, I use flake8 and autopep8 for standard code violations. +To learn more about these, please read the code section of the contribution guide.

+

Naming variables

+

When naming variables, use lowercase letters and make sure that the variables are named in an explanatory manner. Please also always use underscores as a replacement for spaces. For example, if you are going to create a variable for storing a reconstructed image at some image plane, you can name that variable reconstructed_image_plane.
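For example, the short sketch below contrasts a cryptic name with an explanatory one:

# Avoid: cryptic abbreviations that force readers to guess.
rip = None

# Better: lowercase words separated by underscores, explaining the content.
reconstructed_image_plane = None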

+

Use fewer imports

+

When it comes to importing libraries in your code, please make sure to use a minimal number of libraries. Using fewer libraries can help you keep your code robust and working across newer generations of those libraries. Please stick to the libraries suggested in this course when coding for this course. If you need access to some other library, please do let us know!

+

Fixing bugs

+

Often, you will encounter bugs in your code. To fix your code in such cases, I would like you to consider using a method called Rubber duck debugging or Rubber ducking. The basic idea is to be able to express your code to a third person or yourself line by line. Explaining it line by line could help you see what is wrong with your code. I am sure there are many recipes for solving bugs in code. I tried introducing you to one that works for me.

+

Have a requirements.txt

+

Please also make sure to have a requirements.txt in the root directory of your repository. +For example, in this course your requirements.txt would look like this:

+
odak>=0.2.4
+torch 
+
+

This way, a future user of your code could install the required libraries by following a simple command in a terminal:

+
pip3 install -r requirements.txt 
+
+

Always use the same function for saving and loading

+

Most issues in every software project come from repetition. Imagine if you want to save and load images inside a code after some processing. If you rely on manually coding a save and load routine in every corner of the same code, it is likely that when you change one of these saving or loading routines, you must modify the others. In other words, do not reimplement what you already have. Instead, turn it into a Lego brick you can use whenever needed. For saving and loading images, please rely on the functions in odak to avoid any issues. For example, if I want to load a sample image called letter.jpeg, I can rely on this example:

+
import odak
+image = odak.learn.tools.load_image(
+                                    'letter.jpeg',
+                                    torch_style = True, # (1)
+                                    normalizeby = 255. # (2)
+                                   )
+
+
    +
  1. If you set this flag to True, the image will be loaded as [ch x m x n], where ch represents the number of color channels (typically three). In the case of False, it will be loaded as [m x n x ch].
  2. If you provide a floating point number here, the image to be loaded will be divided by that number. For example, if you have an 8-bit image (0-255) and you provide normalizeby = 2.0, the maximum value that you can expect is 255 / 2. = 127.5.
+

Odak also provides a standard method for saving your torch tensors as image files:

+
odak.learn.tools.save_image(
+                            'copy.png',
+                            image,
+                            cmin = 0., # (1)
+                            cmax = 1., # (2)
+                            color_depth = 8 # (3)
+                           )
+
+
    +
  1. Minimum expected value for torch tensor image.
  2. Maximum expected value for torch tensor image.
  3. Pixel depth of the image to be saved, default is 8-bit.
+

During development, you may want to try the same code with different settings. In those cases, I create a separate settings folder in the root directory of my projects and add JSON files that I can load for testing different cases. To explain the case better, let us assume we will change the number of light sources in some simulation. We first create a settings file as settings/experiment_000.txt in the root directory and fill it with the following content:

+
{
+  "light source" : {
+                    "count" : 5,
+                    "type"  : "LED"
+                   }
+}
+
+

In the rest of my code, I can read, modify and save JSON files using odak functions:

+
import odak
+settings = odak.tools.load_dictionary('./settings/experiment_000.txt')
+settings['light source']['count'] = 10
+odak.tools.save_dictionary(settings, './settings/experiment_000.txt')
+
+

This way, you do not have to memorize the variables you used for every experiment you conducted with the same piece of code. +You can have a dedicated settings file for each experiment.

+

Create unit tests

+

Suppose your project is a library containing multiple valuable functions for developing other projects. +In that case, I encourage you to create unit tests for your library so that whenever you update it, you can see if your updates break anything in that library. +For this purpose, consider creating a test directory in the root folder of your repository. +In that directory, you can create separate Python scripts for testing out various functions of your library. +Say there is a function called add in your project MOSTAWESOMECODEEVER, so your test script test/test_add.py should look like this:

+
import sys
+
+import MOSTAWESOMECODEEVER
+
+
+def test():
+    ground_truth = 3 + 5
+    result = MOSTAWESOMECODEEVER.add(3, 5)
+    if ground_truth == result:
+        assert True == True
+    else:
+        assert False == True
+
+
+if __name__ == '__main__':
+    sys.exit(test())
+
+

You may accumulate various unit tests in your test directory. +To test them all before pushing them to your repository, you can rely on pytest. +You can install pytest using the following command in your terminal:

+
pip3 install pytest
+
+

Once installed, you can navigate to your repository's root directory and call pytest to test things out:

+
cd MOSTAWESOMECODEEVER
+pytest
+
+

If anything is wrong with your unit tests, which validate your functions, pytest will provide a detailed explanation.
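
pytest also automatically discovers any function whose name starts with test_. +As a side note, here is a minimal sketch (an added illustration, not part of the original example) of how the same check could be written in a more pytest-idiomatic way using parametrization:

+
import pytest
+import MOSTAWESOMECODEEVER # Hypothetical library from the example above.
+
+
+@pytest.mark.parametrize('a, b, expected', [(3, 5, 8), (-1, 1, 0)])
+def test_add(a, b, expected):
+    assert MOSTAWESOMECODEEVER.add(a, b) == expected # Each parameter set becomes a separate test case.
+
+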

+

Set a licence

+

If you want to distribute your code online, consider adding a license to avoid difficulties related to sharing it with others. +In other words, you can add a LICENSE.txt in the root directory of your repository. +To determine which license works best for you, consider visiting this guideline. +When choosing a license for your project, think about whether you are comfortable with people building a product out of your work or its derivatives.
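
If you pick the Mozilla Public License 2.0 (the license shown in the sample files below), its Exhibit A suggests attaching a short notice to each source file. +As an illustration, in a Python file that notice could be added as a comment at the top:

+
# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+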

+
+Lab work: Prepare a project repository +

Please prepare a sample repository on GitHub using the information provided in the above sections. +Here are some sample files that may inspire you and help you structure your project in good order:

+
+
+
+
import odak
+import torch
+import sys
+
+
+def main():
+    print('your codebase')
+
+
+if __name__ == '__main__':
+    sys.exit(main())
+
+
+
+
LICENSE.txt
Mozilla Public License Version 2.0
+==================================
+
+1. Definitions
+--------------
+
+1.1. "Contributor"
+    means each individual or legal entity that creates, contributes to
+    the creation of, or owns Covered Software.
+
+1.2. "Contributor Version"
+    means the combination of the Contributions of others (if any) used
+    by a Contributor and that particular Contributor's Contribution.
+
+1.3. "Contribution"
+    means Covered Software of a particular Contributor.
+
+1.4. "Covered Software"
+    means Source Code Form to which the initial Contributor has attached
+    the notice in Exhibit A, the Executable Form of such Source Code
+    Form, and Modifications of such Source Code Form, in each case
+    including portions thereof.
+
+1.5. "Incompatible With Secondary Licenses"
+    means
+
+    (a) that the initial Contributor has attached the notice described
+        in Exhibit B to the Covered Software; or
+
+    (b) that the Covered Software was made available under the terms of
+        version 1.1 or earlier of the License, but not also under the
+        terms of a Secondary License.
+
+1.6. "Executable Form"
+    means any form of the work other than Source Code Form.
+
+1.7. "Larger Work"
+    means a work that combines Covered Software with other material, in 
+    a separate file or files, that is not Covered Software.
+
+1.8. "License"
+    means this document.
+
+1.9. "Licensable"
+    means having the right to grant, to the maximum extent possible,
+    whether at the time of the initial grant or subsequently, any and
+    all of the rights conveyed by this License.
+
+1.10. "Modifications"
+    means any of the following:
+
+    (a) any file in Source Code Form that results from an addition to,
+        deletion from, or modification of the contents of Covered
+        Software; or
+
+    (b) any new file in Source Code Form that contains any Covered
+        Software.
+
+1.11. "Patent Claims" of a Contributor
+    means any patent claim(s), including without limitation, method,
+    process, and apparatus claims, in any patent Licensable by such
+    Contributor that would be infringed, but for the grant of the
+    License, by the making, using, selling, offering for sale, having
+    made, import, or transfer of either its Contributions or its
+    Contributor Version.
+
+1.12. "Secondary License"
+    means either the GNU General Public License, Version 2.0, the GNU
+    Lesser General Public License, Version 2.1, the GNU Affero General
+    Public License, Version 3.0, or any later versions of those
+    licenses.
+
+1.13. "Source Code Form"
+    means the form of the work preferred for making modifications.
+
+1.14. "You" (or "Your")
+    means an individual or a legal entity exercising rights under this
+    License. For legal entities, "You" includes any entity that
+    controls, is controlled by, or is under common control with You. For
+    purposes of this definition, "control" means (a) the power, direct
+    or indirect, to cause the direction or management of such entity,
+    whether by contract or otherwise, or (b) ownership of more than
+    fifty percent (50%) of the outstanding shares or beneficial
+    ownership of such entity.
+
+2. License Grants and Conditions
+--------------------------------
+
+2.1. Grants
+
+Each Contributor hereby grants You a world-wide, royalty-free,
+non-exclusive license:
+
+(a) under intellectual property rights (other than patent or trademark)
+    Licensable by such Contributor to use, reproduce, make available,
+    modify, display, perform, distribute, and otherwise exploit its
+    Contributions, either on an unmodified basis, with Modifications, or
+    as part of a Larger Work; and
+
+(b) under Patent Claims of such Contributor to make, use, sell, offer
+    for sale, have made, import, and otherwise transfer either its
+    Contributions or its Contributor Version.
+
+2.2. Effective Date
+
+The licenses granted in Section 2.1 with respect to any Contribution
+become effective for each Contribution on the date the Contributor first
+distributes such Contribution.
+
+2.3. Limitations on Grant Scope
+
+The licenses granted in this Section 2 are the only rights granted under
+this License. No additional rights or licenses will be implied from the
+distribution or licensing of Covered Software under this License.
+Notwithstanding Section 2.1(b) above, no patent license is granted by a
+Contributor:
+
+(a) for any code that a Contributor has removed from Covered Software;
+    or
+
+(b) for infringements caused by: (i) Your and any other third party's
+    modifications of Covered Software, or (ii) the combination of its
+    Contributions with other software (except as part of its Contributor
+    Version); or
+
+(c) under Patent Claims infringed by Covered Software in the absence of
+    its Contributions.
+
+This License does not grant any rights in the trademarks, service marks,
+or logos of any Contributor (except as may be necessary to comply with
+the notice requirements in Section 3.4).
+
+2.4. Subsequent Licenses
+
+No Contributor makes additional grants as a result of Your choice to
+distribute the Covered Software under a subsequent version of this
+License (see Section 10.2) or under the terms of a Secondary License (if
+permitted under the terms of Section 3.3).
+
+2.5. Representation
+
+Each Contributor represents that the Contributor believes its
+Contributions are its original creation(s) or it has sufficient rights
+to grant the rights to its Contributions conveyed by this License.
+
+2.6. Fair Use
+
+This License is not intended to limit any rights You have under
+applicable copyright doctrines of fair use, fair dealing, or other
+equivalents.
+
+2.7. Conditions
+
+Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted
+in Section 2.1.
+
+3. Responsibilities
+-------------------
+
+3.1. Distribution of Source Form
+
+All distribution of Covered Software in Source Code Form, including any
+Modifications that You create or to which You contribute, must be under
+the terms of this License. You must inform recipients that the Source
+Code Form of the Covered Software is governed by the terms of this
+License, and how they can obtain a copy of this License. You may not
+attempt to alter or restrict the recipients' rights in the Source Code
+Form.
+
+3.2. Distribution of Executable Form
+
+If You distribute Covered Software in Executable Form then:
+
+(a) such Covered Software must also be made available in Source Code
+    Form, as described in Section 3.1, and You must inform recipients of
+    the Executable Form how they can obtain a copy of such Source Code
+    Form by reasonable means in a timely manner, at a charge no more
+    than the cost of distribution to the recipient; and
+
+(b) You may distribute such Executable Form under the terms of this
+    License, or sublicense it under different terms, provided that the
+    license for the Executable Form does not attempt to limit or alter
+    the recipients' rights in the Source Code Form under this License.
+
+3.3. Distribution of a Larger Work
+
+You may create and distribute a Larger Work under terms of Your choice,
+provided that You also comply with the requirements of this License for
+the Covered Software. If the Larger Work is a combination of Covered
+Software with a work governed by one or more Secondary Licenses, and the
+Covered Software is not Incompatible With Secondary Licenses, this
+License permits You to additionally distribute such Covered Software
+under the terms of such Secondary License(s), so that the recipient of
+the Larger Work may, at their option, further distribute the Covered
+Software under the terms of either this License or such Secondary
+License(s).
+
+3.4. Notices
+
+You may not remove or alter the substance of any license notices
+(including copyright notices, patent notices, disclaimers of warranty,
+or limitations of liability) contained within the Source Code Form of
+the Covered Software, except that You may alter any license notices to
+the extent required to remedy known factual inaccuracies.
+
+3.5. Application of Additional Terms
+
+You may choose to offer, and to charge a fee for, warranty, support,
+indemnity or liability obligations to one or more recipients of Covered
+Software. However, You may do so only on Your own behalf, and not on
+behalf of any Contributor. You must make it absolutely clear that any
+such warranty, support, indemnity, or liability obligation is offered by
+You alone, and You hereby agree to indemnify every Contributor for any
+liability incurred by such Contributor as a result of warranty, support,
+indemnity or liability terms You offer. You may include additional
+disclaimers of warranty and limitations of liability specific to any
+jurisdiction.
+
+4. Inability to Comply Due to Statute or Regulation
+---------------------------------------------------
+
+If it is impossible for You to comply with any of the terms of this
+License with respect to some or all of the Covered Software due to
+statute, judicial order, or regulation then You must: (a) comply with
+the terms of this License to the maximum extent possible; and (b)
+describe the limitations and the code they affect. Such description must
+be placed in a text file included with all distributions of the Covered
+Software under this License. Except to the extent prohibited by statute
+or regulation, such description must be sufficiently detailed for a
+recipient of ordinary skill to be able to understand it.
+
+5. Termination
+--------------
+
+5.1. The rights granted under this License will terminate automatically
+if You fail to comply with any of its terms. However, if You become
+compliant, then the rights granted under this License from a particular
+Contributor are reinstated (a) provisionally, unless and until such
+Contributor explicitly and finally terminates Your grants, and (b) on an
+ongoing basis, if such Contributor fails to notify You of the
+non-compliance by some reasonable means prior to 60 days after You have
+come back into compliance. Moreover, Your grants from a particular
+Contributor are reinstated on an ongoing basis if such Contributor
+notifies You of the non-compliance by some reasonable means, this is the
+first time You have received notice of non-compliance with this License
+from such Contributor, and You become compliant prior to 30 days after
+Your receipt of the notice.
+
+5.2. If You initiate litigation against any entity by asserting a patent
+infringement claim (excluding declaratory judgment actions,
+counter-claims, and cross-claims) alleging that a Contributor Version
+directly or indirectly infringes any patent, then the rights granted to
+You by any and all Contributors for the Covered Software under Section
+2.1 of this License shall terminate.
+
+5.3. In the event of termination under Sections 5.1 or 5.2 above, all
+end user license agreements (excluding distributors and resellers) which
+have been validly granted by You or Your distributors under this License
+prior to termination shall survive termination.
+
+************************************************************************
+*                                                                      *
+*  6. Disclaimer of Warranty                                           *
+*  -------------------------                                           *
+*                                                                      *
+*  Covered Software is provided under this License on an "as is"       *
+*  basis, without warranty of any kind, either expressed, implied, or  *
+*  statutory, including, without limitation, warranties that the       *
+*  Covered Software is free of defects, merchantable, fit for a        *
+*  particular purpose or non-infringing. The entire risk as to the     *
+*  quality and performance of the Covered Software is with You.        *
+*  Should any Covered Software prove defective in any respect, You     *
+*  (not any Contributor) assume the cost of any necessary servicing,   *
+*  repair, or correction. This disclaimer of warranty constitutes an   *
+*  essential part of this License. No use of any Covered Software is   *
+*  authorized under this License except under this disclaimer.         *
+*                                                                      *
+************************************************************************
+
+************************************************************************
+*                                                                      *
+*  7. Limitation of Liability                                          *
+*  --------------------------                                          *
+*                                                                      *
+*  Under no circumstances and under no legal theory, whether tort      *
+*  (including negligence), contract, or otherwise, shall any           *
+*  Contributor, or anyone who distributes Covered Software as          *
+*  permitted above, be liable to You for any direct, indirect,         *
+*  special, incidental, or consequential damages of any character      *
+*  including, without limitation, damages for lost profits, loss of    *
+*  goodwill, work stoppage, computer failure or malfunction, or any    *
+*  and all other commercial damages or losses, even if such party      *
+*  shall have been informed of the possibility of such damages. This   *
+*  limitation of liability shall not apply to liability for death or   *
+*  personal injury resulting from such party's negligence to the       *
+*  extent applicable law prohibits such limitation. Some               *
+*  jurisdictions do not allow the exclusion or limitation of           *
+*  incidental or consequential damages, so this exclusion and          *
+*  limitation may not apply to You.                                    *
+*                                                                      *
+************************************************************************
+
+8. Litigation
+-------------
+
+Any litigation relating to this License may be brought only in the
+courts of a jurisdiction where the defendant maintains its principal
+place of business and such litigation shall be governed by laws of that
+jurisdiction, without reference to its conflict-of-law provisions.
+Nothing in this Section shall prevent a party's ability to bring
+cross-claims or counter-claims.
+
+9. Miscellaneous
+----------------
+
+This License represents the complete agreement concerning the subject
+matter hereof. If any provision of this License is held to be
+unenforceable, such provision shall be reformed only to the extent
+necessary to make it enforceable. Any law or regulation which provides
+that the language of a contract shall be construed against the drafter
+shall not be used to construe this License against a Contributor.
+
+10. Versions of the License
+---------------------------
+
+10.1. New Versions
+
+Mozilla Foundation is the license steward. Except as provided in Section
+10.3, no one other than the license steward has the right to modify or
+publish new versions of this License. Each version will be given a
+distinguishing version number.
+
+10.2. Effect of New Versions
+
+You may distribute the Covered Software under the terms of the version
+of the License under which You originally received the Covered Software,
+or under the terms of any subsequent version published by the license
+steward.
+
+10.3. Modified Versions
+
+If you create software not governed by this License, and you want to
+create a new license for such software, you may create and use a
+modified version of this License if you rename the license and remove
+any references to the name of the license steward (except to note that
+such modified license differs from this License).
+
+10.4. Distributing Source Code Form that is Incompatible With Secondary
+Licenses
+
+If You choose to distribute Source Code Form that is Incompatible With
+Secondary Licenses under the terms of this version of the License, the
+notice described in Exhibit B of this License must be attached.
+
+Exhibit A - Source Code Form License Notice
+-------------------------------------------
+
+  This Source Code Form is subject to the terms of the Mozilla Public
+  License, v. 2.0. If a copy of the MPL was not distributed with this
+  file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+If it is not possible or desirable to put the notice in a particular
+file, then You may include the notice in a location (such as a LICENSE
+file in a relevant directory) where a recipient would be likely to look
+for such a notice.
+
+You may add additional accurate notices of copyright ownership.
+
+Exhibit B - "Incompatible With Secondary Licenses" Notice
+---------------------------------------------------------
+
+  This Source Code Form is "Incompatible With Secondary Licenses", as
+  defined by the Mozilla Public License, v. 2.0.
+
+
+
+
requirements.txt
opencv-python>=4.10.0.84
+numpy>=1.26.4
+torch>=2.3.0
+plyfile>=1.0.3
+tqdm>=4.66.4
+
+
+
+
THANKS.txt
Ahmet Hamdi Güzel
+Ahmet Serdar Karadeniz
+David Robert Walton
+David Santiago Morales Norato
+Henry Kam
+Doğa Yılmaz
+Jeanne Beyazian
+Jialun Wu
+Josef Spjut
+Koray Kavaklı
+Liang Shi
+Mustafa Doğa Doğan
+Praneeth Chakravarthula
+Runze Zhu
+Weijie Xie
+Yujie Wang
+Yuta Itoh
+Ziyang Chen
+Yicheng Zhan
+
+
+
+
CODE_OF_CONDUCT.md
# Contributor Covenant Code of Conduct
+
+## Our Pledge
+
+We as members, contributors, and leaders pledge to make participation in our
+community a harassment-free experience for everyone, regardless of age, body
+size, visible or invisible disability, ethnicity, sex characteristics, gender
+identity and expression, level of experience, education, socio-economic status,
+nationality, personal appearance, race, religion, or sexual identity
+and orientation.
+
+We pledge to act and interact in ways that contribute to an open, welcoming,
+diverse, inclusive, and healthy community.
+
+## Our Standards
+
+Examples of behavior that contributes to a positive environment for our
+community include:
+
+* Demonstrating empathy and kindness toward other people
+* Being respectful of differing opinions, viewpoints, and experiences
+* Giving and gracefully accepting constructive feedback
+* Accepting responsibility and apologizing to those affected by our mistakes,
+  and learning from the experience
+* Focusing on what is best not just for us as individuals, but for the
+  overall community
+
+Examples of unacceptable behavior include:
+
+* The use of sexualized language or imagery, and sexual attention or
+  advances of any kind
+* Trolling, insulting or derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or email
+  address, without their explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+  professional setting
+
+## Enforcement Responsibilities
+
+Community leaders are responsible for clarifying and enforcing our standards of
+acceptable behavior and will take appropriate and fair corrective action in
+response to any behavior that they deem inappropriate, threatening, offensive,
+or harmful.
+
+Community leaders have the right and responsibility to remove, edit, or reject
+comments, commits, code, wiki edits, issues, and other contributions that are
+not aligned to this Code of Conduct, and will communicate reasons for moderation
+decisions when appropriate.
+
+## Scope
+
+This Code of Conduct applies within all community spaces, and also applies when
+an individual is officially representing the community in public spaces.
+Examples of representing our community include using an official e-mail address,
+posting via an official social media account, or acting as an appointed
+representative at an online or offline event.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported to the community leaders responsible for enforcement at
+.
+All complaints will be reviewed and investigated promptly and fairly.
+
+All community leaders are obligated to respect the privacy and security of the
+reporter of any incident.
+
+## Enforcement Guidelines
+
+Community leaders will follow these Community Impact Guidelines in determining
+the consequences for any action they deem in violation of this Code of Conduct:
+
+### 1. Correction
+
+**Community Impact**: Use of inappropriate language or other behavior deemed
+unprofessional or unwelcome in the community.
+
+**Consequence**: A private, written warning from community leaders, providing
+clarity around the nature of the violation and an explanation of why the
+behavior was inappropriate. A public apology may be requested.
+
+### 2. Warning
+
+**Community Impact**: A violation through a single incident or series
+of actions.
+
+**Consequence**: A warning with consequences for continued behavior. No
+interaction with the people involved, including unsolicited interaction with
+those enforcing the Code of Conduct, for a specified period of time. This
+includes avoiding interactions in community spaces as well as external channels
+like social media. Violating these terms may lead to a temporary or
+permanent ban.
+
+### 3. Temporary Ban
+
+**Community Impact**: A serious violation of community standards, including
+sustained inappropriate behavior.
+
+**Consequence**: A temporary ban from any sort of interaction or public
+communication with the community for a specified period of time. No public or
+private interaction with the people involved, including unsolicited interaction
+with those enforcing the Code of Conduct, is allowed during this period.
+Violating these terms may lead to a permanent ban.
+
+### 4. Permanent Ban
+
+**Community Impact**: Demonstrating a pattern of violation of community
+standards, including sustained inappropriate behavior,  harassment of an
+individual, or aggression toward or disparagement of classes of individuals.
+
+**Consequence**: A permanent ban from any sort of public interaction within
+the community.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage],
+version 2.0, available at
+https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
+
+Community Impact Guidelines were inspired by [Mozilla's code of conduct
+enforcement ladder](https://github.com/mozilla/diversity).
+
+[homepage]: https://www.contributor-covenant.org
+
+For answers to common questions about this code of conduct, see the FAQ at
+https://www.contributor-covenant.org/faq. Translations are available at
+https://www.contributor-covenant.org/translations.
+
+
+
+
+
+

Background Review

+

Informative · + Media

+

Here, I will review some basic mathematical concepts using equations, images, or code. +Please note that you must understand these concepts to follow this course without difficulty.

+

Convolution Operation

+

Convolution is a mathematical operation used as a building block for describing systems. +It has proven to be highly effective in machine learning and deep learning. +The convolution operation is often denoted with a * symbol. +Assume there is a matrix, A, which we want to convolve with some other matrix, also known as the kernel, K.
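
For reference (an equation added here for clarity, following the standard textbook definition), the discrete two-dimensional convolution of A with the kernel K can be written as:

+
\[ +R[m, n] = (A * K)[m, n] = \sum_{i} \sum_{j} A[i, j] \, K[m - i, n - j]. +\]
+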

+
+

Image title +

+
A sketch showing a matrix and a kernel to be convolved.
+
+

One can define such a matrix and a kernel using Torch in Python:

+
a = torch.tensor(
+                 [
+                  [1, 5, 9, 2, 3],
+                  [4, 8, 2, 3, 6],
+                  [7, 2, 0, 1, 3],
+                  [9, 6, 4, 2, 5],
+                  [2, 3, 5, 7, 4]
+                 ]
+                )
+k = torch.tensor(
+                 [
+                  [-1, 2, -3], 
+                  [ 3, 5,  7], 
+                  [-4, 9, -2]
+                 ]
+                )
+
+

To convolve these two matrices without losing information, we first have to go through a mathematical operation called zero padding.

+
+

Image title +

+
A sketch showing zeropadding operating on a matrix.
+
+

To zeropad the matrix A, you can rely on Odak:

+
import odak
+
+a_zeropad = odak.learn.tools.zero_pad(a, size = [7, 7])
+
+

Note that we pass the size here as [7, 7]; the logic of this is very simple. +Our original matrix was five by five, and if you add a row and a column of zeros on each side, you get seven by seven as the new requested size. +Also note that our kernel is three by three. +There could be cases with a larger kernel size. +In those cases, you want to zeropad by half the kernel size on each side (e.g., the original size plus twice the half kernel size, a.shape[0] + 2 * (k.shape[0] // 2)). +Now we choose the first element in the original matrix A, multiply it with the kernel, and add the result to a matrix R. +Note that we add the result of our multiplication by centering it at the original location of the first element.
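
As a minimal sketch of that rule (an added illustration, reusing the zero_pad helper shown above), the padded size for an arbitrary kernel could be computed as follows:

+
import odak
+import torch
+
+a = torch.rand(5, 5) # Original matrix.
+k = torch.rand(5, 5) # A larger, five by five kernel this time.
+pad = [k.shape[0] // 2, k.shape[1] // 2] # Half kernel size per side.
+new_size = [a.shape[0] + 2 * pad[0], a.shape[1] + 2 * pad[1]] # Nine by nine in this case.
+a_zeropad = odak.learn.tools.zero_pad(a, size = new_size)
+
+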

+
+

Image title +

+
A sketch showing the first step of a convolution operation.
+
+

We have to repeat this operation for each element in our original matrix and accumulate the results.

+
+

Image title +

+
A sketch showing the second step of a convolution operation.
+
+

Note that there are other ways to describe and implement the convolution operation. +Thus far, this definition provides a simplified description of convolution.

+
+Lab work: Implement convolution operation using Numpy +

There are three common ways to implement the convolution operation on a computer. +The first one involves loops visiting each point in the given data. +The second involves formulating the convolution operation as a matrix multiplication, and the final one involves implementing convolution as a multiplication in the Fourier domain. +Implement all three methods using Jupyter Notebooks and visually prove that they all function correctly with various kernels (e.g., convolving an image with a kernel). + The source files listed below may inspire your implementation in various ways. + Note that the code below is based on Torch, not Numpy.
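
As a starting point for the first (loop-based) method, here is a minimal Numpy sketch added for illustration; treat it as a guide rather than a reference implementation:

+
import numpy as np
+
+
+def convolve2d_naive(a, k):
+    # Loop-based 2D convolution with zero padding, producing a "same"-sized output.
+    pad_y, pad_x = k.shape[0] // 2, k.shape[1] // 2
+    a_padded = np.pad(a, ((pad_y, pad_y), (pad_x, pad_x)))
+    k_flipped = np.flip(k) # Convolution flips the kernel (unlike correlation).
+    result = np.zeros(a.shape, dtype = float)
+    for i in range(a.shape[0]):
+        for j in range(a.shape[1]):
+            window = a_padded[i : i + k.shape[0], j : j + k.shape[1]]
+            result[i, j] = np.sum(window * k_flipped)
+    return result
+
+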

+
+
+
+ + +
+ + + + +
+ +

Definition to convolve a field with a kernel by multiplying in frequency space.

+ + +

Parameters:

+
    +
  • + field + – +
    +
          Input field with MxN shape.
    +
    +
    +
  • +
  • + kernel + – +
    +
          Input kernel with MxN shape.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +new_field ( tensor +) – +
    +

    Convolved field.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/matrix.py +
def convolve2d(field, kernel):
+    """
+    Definition to convolve a field with a kernel by multiplying in frequency space.
+
+    Parameters
+    ----------
+    field       : torch.tensor
+                  Input field with MxN shape.
+    kernel      : torch.tensor
+                  Input kernel with MxN shape.
+
+    Returns
+    ----------
+    new_field   : torch.tensor
+                  Convolved field.
+    """
+    fr = torch.fft.fft2(field)
+    fr2 = torch.fft.fft2(torch.flip(torch.flip(kernel, [1, 0]), [0, 1]))
+    m, n = fr.shape
+    new_field = torch.real(torch.fft.ifft2(fr*fr2))
+    new_field = torch.roll(new_field, shifts=(int(n/2+1), 0), dims=(1, 0))
+    new_field = torch.roll(new_field, shifts=(int(m/2+1), 0), dims=(0, 1))
+    return new_field
+
+
+
+ +
+
+ + +
+ + + + +
+ +

Generate 2D Gaussian kernel. Inspired from https://stackoverflow.com/questions/29731726/how-to-calculate-a-gaussian-kernel-matrix-efficiently-in-numpy

+ + +

Parameters:

+
    +
  • + kernel_length + (list, default: + [21, 21] +) + – +
    +
            Length of the Gaussian kernel along X and Y axes.
    +
    +
    +
  • +
  • + nsigma + – +
    +
            Sigma of the Gaussian kernel along X and Y axes.
    +
    +
    +
  • +
  • + mu + – +
    +
            Mu of the Gaussian kernel along X and Y axes.
    +
    +
    +
  • +
  • + normalize + – +
    +
            If set True, normalize the output.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +kernel_2d ( tensor +) – +
    +

    Generated Gaussian kernel.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/matrix.py +
def generate_2d_gaussian(kernel_length = [21, 21], nsigma = [3, 3], mu = [0, 0], normalize = False):
+    """
+    Generate 2D Gaussian kernel. Inspired from https://stackoverflow.com/questions/29731726/how-to-calculate-a-gaussian-kernel-matrix-efficiently-in-numpy
+
+    Parameters
+    ----------
+    kernel_length : list
+                    Length of the Gaussian kernel along X and Y axes.
+    nsigma        : list
+                    Sigma of the Gaussian kernel along X and Y axes.
+    mu            : list
+                    Mu of the Gaussian kernel along X and Y axes.
+    normalize     : bool
+                    If set True, normalize the output.
+
+    Returns
+    ----------
+    kernel_2d     : torch.tensor
+                    Generated Gaussian kernel.
+    """
+    x = torch.linspace(-kernel_length[0]/2., kernel_length[0]/2., kernel_length[0])
+    y = torch.linspace(-kernel_length[1]/2., kernel_length[1]/2., kernel_length[1])
+    X, Y = torch.meshgrid(x, y, indexing='ij')
+    if nsigma[0] == 0:
+        nsigma[0] = 1e-5
+    if nsigma[1] == 0:
+        nsigma[1] = 1e-5
+    kernel_2d = 1. / (2. * torch.pi * nsigma[0] * nsigma[1]) * torch.exp(-((X - mu[0])**2. / (2. * nsigma[0]**2.) + (Y - mu[1])**2. / (2. * nsigma[1]**2.)))
+    if normalize:
+        kernel_2d = kernel_2d / kernel_2d.max()
+    return kernel_2d
+
+
+
+ +
+
+
animation_convolution.py
import odak
+import torch
+import sys
+
+
+def main():
+    filename_image = '../media/10591010993_80c7cb37a6_c.jpg'
+    image = odak.learn.tools.load_image(filename_image, normalizeby = 255., torch_style = True)[0:3].unsqueeze(0)
+    kernel = odak.learn.tools.generate_2d_gaussian(kernel_length = [12, 12], nsigma = [21, 21])
+    kernel = kernel / kernel.max()
+    result = torch.zeros_like(image)
+    result = odak.learn.tools.zero_pad(result, size = [image.shape[-2] + kernel.shape[0], image.shape[-1] + kernel.shape[1]])
+    step = 0
+    for i in range(image.shape[-2]):
+        for j in range(image.shape[-1]):
+            for ch in range(image.shape[-3]):
+                element = image[:, ch, i, j]
+                add = kernel * element
+                result[:, ch, i : i + kernel.shape[0], j : j + kernel.shape[1]] += add
+            if (i * image.shape[-1] + j) % 1e4 == 0:
+                filename = 'step_{:04d}.png'.format(step)
+                odak.learn.tools.save_image( filename, result, cmin = 0., cmax = 100.)
+                step += 1
+    cmd = ['convert', '-delay', '1', '-loop', '0', '*.png', '../media/convolution_animation.gif']
+    odak.tools.shell_command(cmd)
+    cmd = ['rm', '*.png']
+    odak.tools.shell_command(cmd)
+
+
+if __name__ == '__main__':
+    sys.exit(main())
+
+
+
+
+
+
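
Before moving on, note that the frequency-domain definition above can be exercised with only a few lines. +The sketch below is an added illustration and assumes both functions are exposed under odak.learn.tools, as their source paths suggest:

+
import odak
+import torch
+
+field = torch.rand(64, 64) # A random field with MxN shape.
+kernel = odak.learn.tools.generate_2d_gaussian(kernel_length = [64, 64], nsigma = [4, 4]) # Kernel matches the field's shape here.
+result = odak.learn.tools.convolve2d(field, kernel) # Convolution as a multiplication in the Fourier domain.
+
+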

In summary, the convolution operation is heavily used in describing optical systems, computer vision-related algorithms, and state-of-the-art machine learning techniques. +Thus, understanding this mathematical operation is extremely important not only for this course but also for undergraduate and graduate-level courses. +As an example, let's see step by step how a sample image provided below is convolved:

+
+

Image title +

+
An animation showing the steps of convolution operation.
+
+

and the original image is as below:

+
+

Image title +

+
Original image before the convolution operation (Generated by Stable Diffusion).
+
+

Note that the source image shown above is generated with a generative model. +As a side note, I strongly suggest becoming familiar with several models for generating test images, audio, or any other type of media. +This way, you can reduce your dependency on others in various ways.

+
+Lab work: Convolve an image with a Gaussian kernel +

Using Odak and Torch, blur an image using a Gaussian kernel. +Also try putting together an animation like the one shown above using Matplotlib. +Use the solution below only as a last resort; try writing your own code first. +The code below is tested under the Ubuntu operating system.

+
+
+
+
animation_convolution.py
import odak
+import torch
+import sys
+
+
+def main():
+    filename_image = '../media/10591010993_80c7cb37a6_c.jpg'
+    image = odak.learn.tools.load_image(filename_image, normalizeby = 255., torch_style = True)[0:3].unsqueeze(0)
+    kernel = odak.learn.tools.generate_2d_gaussian(kernel_length = [12, 12], nsigma = [21, 21])
+    kernel = kernel / kernel.max()
+    result = torch.zeros_like(image)
+    result = odak.learn.tools.zero_pad(result, size = [image.shape[-2] + kernel.shape[0], image.shape[-1] + kernel.shape[1]])
+    step = 0
+    for i in range(image.shape[-2]):
+        for j in range(image.shape[-1]):
+            for ch in range(image.shape[-3]):
+                element = image[:, ch, i, j]
+                add = kernel * element
+                result[:, ch, i : i + kernel.shape[0], j : j + kernel.shape[1]] += add
+            if (i * image.shape[-1] + j) % 1e4 == 0:
+                filename = 'step_{:04d}.png'.format(step)
+                odak.learn.tools.save_image( filename, result, cmin = 0., cmax = 100.)
+                step += 1
+    cmd = ['convert', '-delay', '1', '-loop', '0', '*.png', '../media/convolution_animation.gif']
+    odak.tools.shell_command(cmd)
+    cmd = ['rm', '*.png']
+    odak.tools.shell_command(cmd)
+
+
+if __name__ == '__main__':
+    sys.exit(main())
+
+
+
+
+
+

Gradient Descent Optimizers

+

Throughout this course, we will have to optimize variables to generate a solution for our problems. +Thus, we need a scalable method to optimize various variables in future problems and tasks. +We will not review optimizers in this section but provide a working solution. +You can learn more about optimizers through other courses offered within our curriculum or through suggested readings. +State-of-the-art Gradient Descent (GD) optimizers could play a key role here. +In particular, Stochastic Gradient Descent (SGD) optimizers can help us solve future problems with a reasonable memory footprint. +This is because GD updates its weights by visiting every sample in a dataset, whereas SGD can update using only randomly chosen data from that dataset. +Thus, SGD requires less memory for each update.
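
To make that difference concrete, here is a small sketch added for illustration (it is not tied to any specific odak function): a stochastic update only touches a random subset of the data at each step.

+
import torch
+
+x = torch.linspace(0., 1., 1000).unsqueeze(-1)
+y = 2. * x + 1. # Data generated by a known line.
+m = torch.tensor([0.], requires_grad = True)
+n = torch.tensor([0.], requires_grad = True)
+optimizer = torch.optim.SGD([m, n], lr = 1e-1)
+loss_function = torch.nn.MSELoss()
+for step in range(1000):
+    optimizer.zero_grad()
+    idx = torch.randint(0, x.shape[0], (32,)) # A random minibatch of 32 samples instead of the full dataset.
+    loss = loss_function(m * x[idx] + n, y[idx])
+    loss.backward()
+    optimizer.step()
+
+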

+
+Where can I read more about the state-of-the-art Stochastic Gradient Descent optimizer? +

To learn more, please read Paszke, Adam, Sam Gross, Soumith Chintala, Gregory Chanan, Edward Yang, Zachary DeVito, Zeming Lin, Alban Desmaison, Luca Antiga, and Adam Lerer. "Automatic differentiation in pytorch." (2017). 1

+
+
+Would you like to code your Gradient Descent based optimizer ground up? +

In case you are interested in coding your Gradient Descent-based optimizer from the ground up, consider watching this tutorial online where I code the optimizer using only Numpy: +

+ If you want to learn more about odak's built-in functions on the matter, visit the below unit test script:

+
+
+
+
test_fit_gradient_descent_1d.py
import numpy as np
+import sys
+import odak
+
+
+def gradient_function(x, y, function, parameters):
+    solution = function(x, parameters)
+    gradient = np.array([
+                         -2 * x**2 * (y - solution),
+                         -2 * x * (y- solution),
+                         -2 * (y - solution)
+                        ])
+    return gradient
+
+
+def function(x, parameters):
+    y = parameters[0] * x**2 + parameters[1] * x + parameters[2]
+    return y
+
+
+def l2_loss(a, b):
+    loss = np.sum((a - b)**2)
+    return loss
+
+
+def test():
+    x = np.linspace(0, 1., 20) 
+    y = function(x, parameters=[2., 1., 10.])
+
+    learning_rate = 5e-1
+    iteration_number = 2000
+    initial_parameters = np.array([10., 10., 0.])
+    estimated_parameters = odak.fit.gradient_descent_1d(
+                                                        input_data=x,
+                                                        ground_truth_data=y,
+                                                        function=function,
+                                                        loss_function=l2_loss,
+                                                        gradient_function=gradient_function,
+                                                        parameters=initial_parameters,
+                                                        learning_rate=learning_rate,
+                                                        iteration_number=iteration_number
+                                                       )
+    assert True == True
+
+
+if __name__ == '__main__':
+   sys.exit(test())
+
+
+
+
+
+

Torch is a blessing for people who optimize or train with their algorithms. +Torch also comes with a set of state-of-the-art optimizers. +One of these optimizers is called the Adam optimizer, torch.optim.Adam. +Let's observe the example below to make sense of how this optimizer can help us optimize various variables.

+
import torch
+import odak  
+import sys # (1)
+
+
+def forward(x, m, n): # (2)
+    y = m * x + n
+    return y
+
+
+def main():
+    m = torch.tensor([100.], requires_grad = True)
+    n = torch.tensor([0.], requires_grad = True) # (3)
+    x_vals = torch.tensor([1., 2., 3., 100.])
+    y_vals = torch.tensor([5., 6., 7., 101.]) # (4)
+    optimizer = torch.optim.Adam([m, n], lr = 5e1) # (5)
+    loss_function = torch.nn.MSELoss() # (6)
+    for step in range(1000):
+        optimizer.zero_grad() # (7)
+        y_estimate = forward(x_vals, m, n) # (8)
+        loss = loss_function(y_estimate, y_vals) # (9)
+        loss.backward(retain_graph = True)
+        optimizer.step() # (10)
+        print('Step: {}, Loss: {}'.format(step, loss.item()))
+    print(m, n)
+
+
+if __name__ == '__main__':
+    sys.exit(main())
+
+
  1. Required libraries are imported.
  2. Let's assume that we are aiming to fit a line to some data (y = mx + n).
  3. As we are aiming to fit a line, we have to find a proper m and n for our line (y = mx + n). Pay attention to the fact that we have to make these variables differentiable by setting requires_grad = True.
  4. Here is a sample dataset of X and Y values.
  5. We define an Adam optimizer and ask our optimizer to optimize m and n.
  6. We need some metric to identify if our optimizer is optimizing correctly. Here, we choose the L2 norm (mean squared error) as our metric.
  7. We clear the gradients before each iteration.
  8. We make our estimation for Y values using the most current m and n values suggested by the optimizer.
  9. We compare our estimation with the original Y values to help our optimizer update m and n values.
  10. Loss and optimizer help us move in the right direction for updating m and n values.
+

Conclusion

+

We covered a lot of ground in terms of coding standards, how to organize a project repository, and how basic things work in odak and Torch. +Please ensure you understand the essential information in this section. +Please note that we will use this information in the following sections and stages of this course.

+
+Consider revisiting this chapter +

Remember that you can always revisit this chapter as you progress with the course and as you need it. +This chapter is vital for establishing a means to complete your assignments and could help formulate a suitable base to collaborate and work with my research group in the future or other experts in the field.

+
+
+Did you know that Computer Science misses basic tool education? +

The classes that Computer Science programs offer around the globe commonly miss basic tool education. +Students often spend a large amount of time learning tools while they are also learning an advanced topic. +This section of our course gave you a quick overview. +But you may want to go beyond it and learn more about many other basic aspects of Computer Science, such as using shell tools, editors, metaprogramming, or security. +The missing semester of your CS education offers an online resource for you to follow up and learn more. +The content of the mentioned course is mostly developed by instructors from the Massachusetts Institute of Technology.

+
+
+

Reminder

+

We host a Slack group with more than 300 members. +This Slack group focuses on the topics of rendering, perception, displays, and cameras. +The group is open to the public, and you can become a member by following this link. +Readers can get in touch with the wider community using this public group.

+
+
+
+
  1. Adam Paszke, Sam Gross, Soumith Chintala, Gregory Chanan, Edward Yang, Zachary DeVito, Zeming Lin, Alban Desmaison, Luca Antiga, and Adam Lerer. Automatic differentiation in pytorch. NIPS 2017 Workshop Autodiff, 2017.
diff --git a/course/geometric_optics/index.html b/course/geometric_optics/index.html
new file mode 100644
index 00000000..23386ea1
--- /dev/null
+++ b/course/geometric_optics/index.html
@@ -0,0 +1,4399 @@

Modeling light with rays

+

Modeling light plays a crucial role in describing events based on light and helps in designing mechanisms based on light (e.g., realistic graphics in a video game, a display, or a camera). +This chapter introduces the most basic description of light using geometric rays, also known as raytracing. +Raytracing has a long history, from ancient times to current Computer Graphics. +Here, we will not cover the history of raytracing. +Instead, we will focus on how we implement simulations to build "things" with raytracing in the future. +As we provide algorithmic examples to support our descriptions, readers should be able to simulate light on their computers using the provided descriptions.

+
+Are there other good resources on modeling light with rays? +

When I first started coding Odak, the first paper I read was on raytracing. +Thus, I recommend that paper for any starter:

+ +

Beyond this paper, there are several resources that I can recommend for curious readers:

+ +
+

Ray description

+

Informative · + Practical

+

We have to define what "a ray" is. +A ray has a starting point in Euclidean space (\(x_0, y_0, z_0 \in \mathbb{R}\)). +We also have to define direction cosines to provide the directions of rays. +Direction cosines are the cosines of the three angles between a ray and the X, Y, and Z axes (\(\theta_x, \theta_y, \theta_z \in \mathbb{R}\)). +To calculate direction cosines, we choose a point on that ray, \(x_1, y_1,\) and \(z_1\), and calculate its distance to the starting point \(x_0, y_0\) and \(z_0\):

+
\[ +x_{distance} = x_1 - x_0, \\ +y_{distance} = y_1 - y_0, \\ +z_{distance} = z_1 - z_0. +\]
+

Then, we can also calculate the Euclidean distance between the starting point and the chosen point:

+
\[ +s = \sqrt{x_{distance}^2 + y_{distance}^2 + z_{distance}^2}. +\]
+

Thus, we describe each direction cosine as:

+
\[ +cos(\theta_x) = \frac{x_{distance}}{s}, \\ +cos(\theta_y) = \frac{y_{distance}}{s}, \\ +cos(\theta_z) = \frac{z_{distance}}{s}. +\]
+
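
As a quick numerical check (a worked example added here, not from the original text), consider a ray from \((0, 0, 0)\) towards \((1, 1, 1)\): the distance is \(s = \sqrt{3}\), so each direction cosine equals \(\frac{1}{\sqrt{3}} \approx 0.577\). +The same result follows from a couple of lines of Torch:

+
import torch
+
+x0y0z0 = torch.tensor([0., 0., 0.]) # Starting point of the ray.
+x1y1z1 = torch.tensor([1., 1., 1.]) # A point chosen along the ray.
+s = torch.linalg.norm(x1y1z1 - x0y0z0) # Euclidean distance, sqrt(3).
+cosines = (x1y1z1 - x0y0z0) / s # tensor([0.5774, 0.5774, 0.5774])
+
+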

Now that we know how to define a ray with a starting point, \(x_0, y_0, z_0\), and direction cosines, \(cos(\theta_x), cos(\theta_y), cos(\theta_z)\), let us carefully analyze the parameters, returns, and source code of the two following functions in odak dedicated to creating a ray or multiple rays.

+
+
+
+ + +
+ + + + +
+ +

Definition to create a ray.

+ + +

Parameters:

+
    +
  • + xyz + – +
    +
           List that contains X,Y and Z start locations of a ray.
    +       Size could be [1 x 3], [3], [m x 3].
    +
    +
    +
  • +
  • + abg + – +
    +
           List that contains angles in degrees with respect to the X,Y and Z axes.
    +       Size could be [1 x 3], [3], [m x 3].
    +
    +
    +
  • +
  • + direction + – +
    +
           If set to True, cosines of `abg` is not calculated.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +ray ( tensor +) – +
    +

    Array that contains starting points and cosines of a created ray. +Size will be either [1 x 3] or [m x 3].

    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/ray.py +
def create_ray(xyz, abg, direction = False):
+    """
+    Definition to create a ray.
+
+    Parameters
+    ----------
+    xyz          : torch.tensor
+                   List that contains X,Y and Z start locations of a ray.
+                   Size could be [1 x 3], [3], [m x 3].
+    abg          : torch.tensor
+                   List that contains angles in degrees with respect to the X,Y and Z axes.
+                   Size could be [1 x 3], [3], [m x 3].
+    direction    : bool
+                   If set to True, cosines of `abg` is not calculated.
+
+    Returns
+    ----------
+    ray          : torch.tensor
+                   Array that contains starting points and cosines of a created ray.
+                   Size will be either [1 x 3] or [m x 3].
+    """
+    points = xyz
+    angles = abg
+    if len(xyz) == 1:
+        points = xyz.unsqueeze(0)
+    if len(abg) == 1:
+        angles = abg.unsqueeze(0)
+    ray = torch.zeros(points.shape[0], 2, 3, device = points.device)
+    ray[:, 0] = points
+    if direction:
+        ray[:, 1] = abg
+    else:
+        ray[:, 1] = torch.cos(torch.deg2rad(abg))
+    return ray
+
+
+
+ +
+
+ + +
+ + + + +
+ +

Definition to create a ray from two given points. Note that both inputs must match in shape.

+ + +

Parameters:

+
    +
  • + x0y0z0 + – +
    +
           List that contains X,Y and Z start locations of a ray.
    +       Size could be [1 x 3], [3], [m x 3].
    +
    +
    +
  • +
  • + x1y1z1 + – +
    +
           List that contains X,Y and Z ending locations of a ray or batch of rays.
    +       Size could be [1 x 3], [3], [m x 3].
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +ray ( tensor +) – +
    +

    Array that contains starting points and cosines of a created ray(s).

    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/ray.py +
def create_ray_from_two_points(x0y0z0, x1y1z1):
+    """
+    Definition to create a ray from two given points. Note that both inputs must match in shape.
+
+    Parameters
+    ----------
+    x0y0z0       : torch.tensor
+                   List that contains X,Y and Z start locations of a ray.
+                   Size could be [1 x 3], [3], [m x 3].
+    x1y1z1       : torch.tensor
+                   List that contains X,Y and Z ending locations of a ray or batch of rays.
+                   Size could be [1 x 3], [3], [m x 3].
+
+    Returns
+    ----------
+    ray          : torch.tensor
+                   Array that contains starting points and cosines of a created ray(s).
+    """
+    if len(x0y0z0.shape) == 1:
+        x0y0z0 = x0y0z0.unsqueeze(0)
+    if len(x1y1z1.shape) == 1:
+        x1y1z1 = x1y1z1.unsqueeze(0)
+    xdiff = x1y1z1[:, 0] - x0y0z0[:, 0]
+    ydiff = x1y1z1[:, 1] - x0y0z0[:, 1]
+    zdiff = x1y1z1[:, 2] - x0y0z0[:, 2]
+    s = (xdiff ** 2 + ydiff ** 2 + zdiff ** 2) ** 0.5
+    s[s == 0] = float('nan')
+    cosines = torch.zeros_like(x0y0z0 * x1y1z1)
+    cosines[:, 0] = xdiff / s
+    cosines[:, 1] = ydiff / s
+    cosines[:, 2] = zdiff / s
+    ray = torch.zeros(xdiff.shape[0], 2, 3, device = x0y0z0.device)
+    ray[:, 0] = x0y0z0
+    ray[:, 1] = cosines
+    return ray
+
+
+
+ +
+
+
+

In the future, we must find out where a ray lands after a certain amount of propagation distance for various purposes, which we will describe in this chapter. +For that purpose, let us also create a utility function that propagates a ray to some distance, \(d\), using \(x_0, y_0, z_0\) and \(cos(\theta_x), cos(\theta_y), cos(\theta_z)\):

+
\[ +x_{new} = x_0 + cos(\theta_x) d,\\ +y_{new} = y_0 + cos(\theta_y) d,\\ +z_{new} = z_0 + cos(\theta_z) d. +\]
+
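
As a quick sanity check (an added example), propagating the ray from the earlier worked example, which starts at \((0, 0, 0)\) with direction cosines \(\frac{1}{\sqrt{3}}\), by a distance \(d = \sqrt{3}\) lands it exactly at \((1, 1, 1)\), as expected.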

Let us also check the function provided below to understand its source code, parameters, and returns. +This function will serve as a utility function to propagate a ray or a batch of rays in our future simulations.

+
+
+
+ + +
+ + + + +
+ +

Definition to propagate a ray at a certain given distance.

+ + +

Parameters:

+
    +
  • + ray + – +
    +
         A ray with a size of [2 x 3], [1 x 2 x 3] or a batch of rays with [m x 2 x 3].
    +
    +
    +
  • +
  • + distance + – +
    +
         Distance with a size of [1], [1, m] or distances with a size of [m], [1, m].
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +new_ray ( tensor +) – +
    +

    Propagated ray with a size of [1 x 2 x 3] or batch of rays with [m x 2 x 3].

    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/ray.py +
def propagate_ray(ray, distance):
+    """
+    Definition to propagate a ray at a certain given distance.
+
+    Parameters
+    ----------
+    ray        : torch.tensor
+                 A ray with a size of [2 x 3], [1 x 2 x 3] or a batch of rays with [m x 2 x 3].
+    distance   : torch.tensor
+                 Distance with a size of [1], [1, m] or distances with a size of [m], [1, m].
+
+    Returns
+    ----------
+    new_ray    : torch.tensor
+                 Propagated ray with a size of [1 x 2 x 3] or batch of rays with [m x 2 x 3].
+    """
+    if len(ray.shape) == 2:
+        ray = ray.unsqueeze(0)
+    if len(distance.shape) == 2:
+        distance = distance.squeeze(-1)
+    new_ray = torch.zeros_like(ray)
+    new_ray[:, 0, 0] = distance * ray[:, 1, 0] + ray[:, 0, 0]
+    new_ray[:, 0, 1] = distance * ray[:, 1, 1] + ray[:, 0, 1]
+    new_ray[:, 0, 2] = distance * ray[:, 1, 2] + ray[:, 0, 2]
+    return new_ray
+
+
+
+ +
+
+
+

It is now time for us to put what we have learned so far into actual code. +We can create many rays using the two functions, odak.learn.raytracing.create_ray_from_two_points and odak.learn.raytracing.create_ray. +However, to do so, we need to have many points in both cases. +For that purpose, let's carefully review the utility function provided below. +This utility function can generate grid samples from a plane with some tilt, and we can also define the center of our samples to position points anywhere in Euclidean space.

+
+
+
+ + +
+ + + + +
+ +

Definition to generate samples over a surface.

+ + +

Parameters:

+
    +
  • + no + – +
    +
          Number of samples.
    +
    +
    +
  • +
  • + size + – +
    +
          Physical size of the surface.
    +
    +
    +
  • +
  • + center + – +
    +
          Center location of the surface.
    +
    +
    +
  • +
  • + angles + – +
    +
          Tilt of the surface.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +samples ( tensor +) – +
    +

    Samples generated.

    +
    +
  • +
  • +rotx ( tensor +) – +
    +

    Rotation matrix at X axis.

    +
    +
  • +
  • +roty ( tensor +) – +
    +

    Rotation matrix at Y axis.

    +
    +
  • +
  • +rotz ( tensor +) – +
    +

    Rotation matrix at Z axis.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/sample.py +
def grid_sample(
+                no = [10, 10],
+                size = [100., 100.], 
+                center = [0., 0., 0.], 
+                angles = [0., 0., 0.]):
+    """
+    Definition to generate samples over a surface.
+
+    Parameters
+    ----------
+    no          : list
+                  Number of samples.
+    size        : list
+                  Physical size of the surface.
+    center      : list
+                  Center location of the surface.
+    angles      : list
+                  Tilt of the surface.
+
+    Returns
+    -------
+    samples     : torch.tensor
+                  Samples generated.
+    rotx        : torch.tensor
+                  Rotation matrix at X axis.
+    roty        : torch.tensor
+                  Rotation matrix at Y axis.
+    rotz        : torch.tensor
+                  Rotation matrix at Z axis.
+    """
+    center = torch.tensor(center)
+    angles = torch.tensor(angles)
+    size = torch.tensor(size)
+    samples = torch.zeros((no[0], no[1], 3))
+    x = torch.linspace(-size[0] / 2., size[0] / 2., no[0])
+    y = torch.linspace(-size[1] / 2., size[1] / 2., no[1])
+    X, Y = torch.meshgrid(x, y, indexing='ij')
+    samples[:, :, 0] = X.detach().clone()
+    samples[:, :, 1] = Y.detach().clone()
+    samples = samples.reshape((samples.shape[0] * samples.shape[1], samples.shape[2]))
+    samples, rotx, roty, rotz = rotate_points(samples, angles = angles, offset = center)
+    return samples, rotx, roty, rotz
+
+
+
+ +
+
+
+

The below script provides a sample use case for the functions provided above. +I also leave comments near some lines explaining the code in steps.

+
+
+
+
import sys
+import odak
+import torch # (1)
+
+
+def test(directory = 'test_output'):
+    odak.tools.check_directory(directory)
+    starting_point = torch.tensor([[5., 5., 0.]]) # (2)
+    end_points, _, _, _ = odak.learn.tools.grid_sample(
+                                                       no = [2, 2], 
+                                                       size = [20., 20.], 
+                                                       center = [0., 0., 10.]
+                                                      ) # (3)
+    rays_from_points = odak.learn.raytracing.create_ray_from_two_points(
+                                                                        starting_point,
+                                                                        end_points
+                                                                       ) # (4)
+
+
+    starting_points, _, _, _ = odak.learn.tools.grid_sample(
+                                                            no = [3, 3], 
+                                                            size = [100., 100.], 
+                                                            center = [0., 0., 10.],
+                                                           )
+    angles = torch.randn_like(starting_points) * 180. # (5)
+    rays_from_angles = odak.learn.raytracing.create_ray(
+                                                        starting_points,
+                                                        angles
+                                                       ) # (6)
+
+
+    distances = torch.ones(rays_from_points.shape[0]) * 12.5
+    propagated_rays = odak.learn.raytracing.propagate_ray(
+                                                          rays_from_points,
+                                                          distances
+                                                         ) # (7)
+
+
+
+
+    visualize = False # (8)
+    if visualize:
+        ray_diagram = odak.visualize.plotly.rayshow(line_width = 3., marker_size = 3.)
+        ray_diagram.add_point(starting_point, color = 'red')
+        ray_diagram.add_point(end_points[0], color = 'blue')
+        ray_diagram.add_line(starting_point, end_points[0], color = 'green')
+        x_axis = starting_point.clone()
+        x_axis[0, 0] = end_points[0, 0]
+        ray_diagram.add_point(x_axis, color = 'black')
+        ray_diagram.add_line(starting_point, x_axis, color = 'black', dash = 'dash')
+        y_axis = starting_point.clone()
+        y_axis[0, 1] = end_points[0, 1]
+        ray_diagram.add_point(y_axis, color = 'black')
+        ray_diagram.add_line(starting_point, y_axis, color = 'black', dash = 'dash')
+        z_axis = starting_point.clone()
+        z_axis[0, 2] = end_points[0, 2]
+        ray_diagram.add_point(z_axis, color = 'black')
+        ray_diagram.add_line(starting_point, z_axis, color = 'black', dash = 'dash')
+        html = ray_diagram.save_offline()
+        markdown_file = open('{}/ray.txt'.format(directory), 'w')
+        markdown_file.write(html)
+        markdown_file.close()
+    assert True == True
+
+
+if __name__ == '__main__':
+    sys.exit(test())
+
+
    +
  1. Required libraries are imported.
  2. Defining a starting point, in order X, Y and Z locations. Size of the starting point could be [1] or [1, 1].
  3. Defining some end points on a plane in a grid fashion.
  4. odak.learn.raytracing.create_ray_from_two_points is verified with an example! Let's move on to odak.learn.raytracing.create_ray.
  5. Creating starting points with odak.learn.tools.grid_sample and defining some angles as the direction using torch.randn. Note that the angles are in degrees.
  6. odak.learn.raytracing.create_ray is verified with an example!
  7. odak.learn.raytracing.propagate_ray is verified with an example!
  8. Set it to True to enable visualization.
+
+
+
+

The above code also has parts that are disabled (see the visualize variable). We disabled these lines intentionally to avoid running them on every run. Let me also talk about these disabled functions. Odak offers a tidy approach to simple visualizations through the Plotly and Kaleido packages. To make these lines work by setting visualize = True, you must first install these packages in your work environment. This installation is as simple as pip3 install plotly kaleido on a Linux system. Once you install these packages and enable these lines, the code will produce a visualization similar to the one below. Note that this is an interactive visualization where you can use your mouse to rotate, shift, and zoom. In this visualization, we visualize a single ray (green line) starting from our defined starting point (red dot) and ending at one of the end_points (blue dot). We also highlight three axes with black lines to provide a reference frame. Although odak.visualize.plotly offers us methods to visualize rays quickly for debugging, we highly suggest sticking to a low number of rays when using it (e.g., not exceeding 100 rays in total). The proper way to draw many rays lies in modern path-tracing renderers such as Blender.

+
+
+ +
+How can I learn more about more sophisticated renderers like Blender? +

Blender is a widely used open-source renderer that comes with sophisticated features. Its user interface could be challenging for newcomers. A blog post published by the SIGGRAPH Research Career Development Committee offers a neat entry-level guide titled Rendering a paper figure with Blender, written by Silvia Sellán.

+

In addition to Blender, there are various other renderers you may want to know about if you are curious about Computer Graphics. Mitsuba 3 is another sophisticated rendering system based on a SIGGRAPH paper titled Dr.Jit: A Just-In-Time Compiler for Differentiable Rendering 4 from Wenzel Jakob.

+

If you know of any others, please share them with the class so that everyone can also learn about other renderers.

+
+
+Challenge: Blender meets Odak +

In light of the given information, we challenge readers to create a new submodule for Odak. Note that Odak has an odak.visualize.blender submodule. However, at the time of this writing, this submodule works as a server that sends commands to a program that has to be manually triggered inside Blender. Odak seeks an upgrade to this submodule so that users can draw rays, meshes, or parametric surfaces easily in Blender with commands from Odak. This newly upgraded submodule should require no manual steps. To add these to odak, you can rely on the pull request feature on GitHub. You can also create a new engineering note for your new submodule in docs/notes/odak_meets_blender.md.

+
+

Intersecting rays with surfaces

+

Informative · + Practical

+

The rays we have described so far help us explore light and matter interactions. Often in simulations, these rays interact with surfaces. In a simulation environment for optical design, surfaces are often described by continuous equations. These surface equations typically contain a number of parameters. For example, let us consider a sphere, which follows the standard equation,

+
\[ +r^2 = (x - x_0)^2 + (y - y_0)^2 + (z - z_0)^2, +\]
+

Where \(r\) represents the radius of that sphere, \(x_0, y_0, z_0\) defines the center location of that sphere, and \(x, y, z\) are points on the surface of the sphere. When testing if a point is on a sphere, we insert the point to be tested as \(x, y, z\) into that equation. In other words, to find a ray and sphere intersection, we must identify a propagation distance that lands our rays on a point satisfying the sphere equation above. As long as the surface equation is well defined, the same strategy can be used for any surface. In addition, if needed for future purposes (e.g., reflecting or refracting light off the surface of that sphere), we can also calculate the surface normal of that sphere by defining a ray starting from the center of that sphere and propagating towards the intersection point. Let us see how we can identify intersection points for a set of given rays and a sphere by examining the function below.
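To make this strategy concrete, we can write a ray as \(\vec{p}(t) = \vec{o} + t \hat{d}\), where \(\vec{o}\) is its starting point, \(\hat{d}\) its direction cosines, and \(t\) the propagation distance, and insert this parametric form into the sphere equation. With \(\vec{c} = (x_0, y_0, z_0)\) denoting the sphere center, this standard substitution yields a quadratic in \(t\),

\[
(\hat{d} \cdot \hat{d}) \, t^2 + 2 \, \hat{d} \cdot (\vec{o} - \vec{c}) \, t + \|\vec{o} - \vec{c}\|^2 - r^2 = 0,
\]

whose real roots, if any, are the distances at which the ray meets the sphere; no real root means the ray misses it. The function below instead finds these distances with an optimizer.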

+
+
+
+ + +
+ + + + +
+ +

Definition to find the intersection between ray(s) and sphere(s).

+ + +

Parameters:

+
    +
  • + ray + – +
    +
                  Input ray(s).
    +              Expected size is [1 x 2 x 3] or [m x 2 x 3].
    +
    +
    +
  • +
  • + sphere + – +
    +
                  Input sphere.
    +              Expected size is [1 x 4].
    +
    +
    +
  • +
  • + learning_rate + – +
    +
                  Learning rate used in the optimizer for finding the propagation distances of the rays.
    +
    +
    +
  • +
  • + number_of_steps + – +
    +
                  Number of steps used in the optimizer.
    +
    +
    +
  • +
  • + error_threshold + – +
    +
                  The error threshold that will help deciding intersection or no intersection.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +intersecting_ray ( tensor +) – +
    +

    Ray(s) that intersecting with the given sphere. +Expected size is [n x 2 x 3], where n could be any real number.

    +
    +
  • +
  • +intersecting_normal ( tensor +) – +
    +

    Normal(s) for the ray(s) intersecting with the given sphere +Expected size is [n x 2 x 3], where n could be any real number.

    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/boundary.py +
def intersect_w_sphere(ray, sphere, learning_rate = 2e-1, number_of_steps = 5000, error_threshold = 1e-2):
+    """
+    Definition to find the intersection between ray(s) and sphere(s).
+
+    Parameters
+    ----------
+    ray                 : torch.tensor
+                          Input ray(s).
+                          Expected size is [1 x 2 x 3] or [m x 2 x 3].
+    sphere              : torch.tensor
+                          Input sphere.
+                          Expected size is [1 x 4].
+    learning_rate       : float
+                          Learning rate used in the optimizer for finding the propagation distances of the rays.
+    number_of_steps     : int
+                          Number of steps used in the optimizer.
+    error_threshold     : float
+                          The error threshold that will help deciding intersection or no intersection.
+
+    Returns
+    -------
+    intersecting_ray    : torch.tensor
+                          Ray(s) that intersecting with the given sphere.
+                          Expected size is [n x 2 x 3], where n could be any real number.
+    intersecting_normal : torch.tensor
+                          Normal(s) for the ray(s) intersecting with the given sphere
+                          Expected size is [n x 2 x 3], where n could be any real number.
+
+    """
+    if len(ray.shape) == 2:
+        ray = ray.unsqueeze(0)
+    if len(sphere.shape) == 1:
+        sphere = sphere.unsqueeze(0)
+    distance = torch.zeros(ray.shape[0], device = ray.device, requires_grad = True)
+    loss_l2 = torch.nn.MSELoss(reduction = 'sum')
+    optimizer = torch.optim.AdamW([distance], lr = learning_rate)    
+    t = tqdm(range(number_of_steps), leave = False, dynamic_ncols = True)
+    for step in t:
+        optimizer.zero_grad()
+        propagated_ray = propagate_ray(ray, distance)
+        test = torch.abs((propagated_ray[:, 0, 0] - sphere[:, 0]) ** 2 + (propagated_ray[:, 0, 1] - sphere[:, 1]) ** 2 + (propagated_ray[:, 0, 2] - sphere[:, 2]) ** 2 - sphere[:, 3] ** 2)
+        loss = loss_l2(
+                       test,
+                       torch.zeros_like(test)
+                      )
+        loss.backward(retain_graph = True)
+        optimizer.step()
+        t.set_description('Sphere intersection loss: {}'.format(loss.item()))
+    check = test < error_threshold
+    intersecting_ray = propagate_ray(ray[check == True], distance[check == True])
+    intersecting_normal = create_ray_from_two_points(
+                                                     sphere[:, 0:3],
+                                                     intersecting_ray[:, 0]
+                                                    )
+    return intersecting_ray, intersecting_normal, distance, check
+
+
+
+ +
+
+
+

The odak.learn.raytracing.intersect_w_sphere function uses an optimizer to identify intersection points for each ray. Instead, a closed-form solution could accomplish the same task without iterating over the intersection test, which would be much faster than the current function. If you are curious about how to fix this issue, you may want to see the challenge provided below.
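For readers curious about the closed-form alternative, here is a minimal sketch. It assumes the conventions documented above, rays as [m x 2 x 3] tensors and spheres as [1 x 4] tensors; the helper name closed_form_sphere_intersection is hypothetical and not part of Odak.

```python
import torch


def closed_form_sphere_intersection(ray, sphere):
    """
    Hypothetical sketch (not part of Odak): closed-form ray and sphere intersection.
    Rays follow the [m x 2 x 3] convention, spheres the [1 x 4] convention.
    """
    origin = ray[:, 0]                       # ray starting points [m x 3]
    direction = ray[:, 1]                    # ray direction cosines [m x 3]
    center = sphere[:, 0:3]                  # sphere center [1 x 3]
    radius = sphere[:, 3]                    # sphere radius [1]
    oc = origin - center                     # vectors from the center to the ray origins
    a = (direction * direction).sum(dim = 1)
    b = 2. * (direction * oc).sum(dim = 1)
    c = (oc * oc).sum(dim = 1) - radius ** 2
    discriminant = b ** 2 - 4. * a * c
    hit = discriminant >= 0.                 # rays with real roots intersect the sphere
    distance = torch.full_like(a, float('nan'))
    sqrt_discriminant = torch.sqrt(torch.clamp(discriminant, min = 0.))
    distance[hit] = ((-b - sqrt_discriminant) / (2. * a))[hit]
    return distance, hit
```

Note that this sketch returns the nearer root even when it is negative (i.e., behind a ray's starting point); a complete replacement should filter such cases and then propagate the rays with odak.learn.raytracing.propagate_ray, as intersect_w_sphere does.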

+

Let us examine how we can use the provided sphere intersection function with an example provided at the end of this subsection.

+
+
+
+
import sys
+import odak
+import torch
+
+def test(output_directory = 'test_output'):
+    odak.tools.check_directory(output_directory)
+    starting_points, _, _, _ = odak.learn.tools.grid_sample(
+                                                            no = [5, 5],
+                                                            size = [3., 3.],
+                                                            center = [0., 0., 0.]
+                                                           )
+    end_points, _, _, _ = odak.learn.tools.grid_sample(
+                                                       no = [5, 5],
+                                                       size = [0.1, 0.1],
+                                                       center = [0., 0., 5.]
+                                                      )
+    rays = odak.learn.raytracing.create_ray_from_two_points(
+                                                            starting_points,
+                                                            end_points
+                                                           )
+    center = torch.tensor([[0., 0., 5.]])
+    radius = torch.tensor([[3.]])
+    sphere = odak.learn.raytracing.define_sphere(
+                                                 center = center,
+                                                 radius = radius
+                                                ) # (1)
+    intersecting_rays, intersecting_normals, _, check = odak.learn.raytracing.intersect_w_sphere(rays, sphere)
+
+
+    visualize = False # (2)
+    if visualize:
+        ray_diagram = odak.visualize.plotly.rayshow(line_width = 3., marker_size = 3.)
+        ray_diagram.add_point(rays[:, 0], color = 'blue')
+        ray_diagram.add_line(rays[:, 0][check == True], intersecting_rays[:, 0], color = 'blue')
+        ray_diagram.add_sphere(sphere, color = 'orange')
+        ray_diagram.add_point(intersecting_normals[:, 0], color = 'green')
+        html = ray_diagram.save_offline()
+        markdown_file = open('{}/ray.txt'.format(output_directory), 'w')
+        markdown_file.write(html)
+        markdown_file.close()
+    assert True == True
+
+
+if __name__ == '__main__':
+    sys.exit(test())
+
+
    +
  1. Here we provide an example use case for odak.learn.raytracing.intersect_w_sphere by providing a sphere and a batch of sample rays.
  2. Set visualize = True to enable visualization.
+
+
+
+
+

Image title +

+
Screenshot showing a sphere and ray intersections generated by the "test_learn_ray_intersect_w_a_sphere.py" script.
+
+

This section shows us how to operate with known geometric shapes, specifically spheres. However, not every shape can be defined using parametric modeling (e.g., nonlinearities such as discontinuities on a surface). In the next section, we will look into another method, an approach widely used by folks working in Computer Graphics.

+
+Challenge: Raytracing arbitrary surfaces +

In light of the given information, we challenge readers to create a new function inside the odak.learn.raytracing submodule that replaces the current intersect_w_sphere function. In addition, the current unit test test/test_learn_ray_intersect_w_a_sphere.py has to adopt this new function. The odak.learn.raytracing submodule also needs new functions for supporting arbitrary parametric surfaces. New unit tests are needed to improve the submodule accordingly. To add these to odak, you can rely on the pull request feature on GitHub. You can also create a new engineering note for arbitrary surfaces in docs/notes/raytracing_arbitrary_surfaces.md.

+
+

Intersecting rays with meshes

+

Informative · + Practical

+

Parametric surfaces provide ease in defining shapes and geometries in various fields, including Optics and Computer Graphics. However, not every object in a given scene can easily be described using parametric surfaces. In many cases, including modern Computer Graphics, triangles form the smallest primitive of an object or a shape. These triangles altogether form meshes that define objects and shapes. For this purpose, we will review the source code, parameters, and returns of three utility functions here. We will first review odak.learn.raytracing.intersect_w_surface to understand how one can calculate the intersection of a ray with a given plane. Later, we review the odak.learn.raytracing.is_it_on_triangle function, which checks if an intersection point on a given surface is inside a triangle on that surface. Finally, we will review the odak.learn.raytracing.intersect_w_triangle function. This last function combines both reviewed functions into a single function to identify the intersection between rays and a triangle.

+
+
+
+ + +
+ + + + +
+ +

Definition to find intersection point inbetween a surface and a ray. For more see: http://geomalgorithms.com/a06-_intersect-2.html

+ + +

Parameters:

+
    +
  • + ray + – +
    +
           A vector/ray.
    +
    +
    +
  • +
  • + points + – +
    +
           Set of points in X,Y and Z to define a planar surface.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +normal ( tensor +) – +
    +

    Surface normal at the point of intersection.

    +
    +
  • +
  • +distance ( float +) – +
    +

    Distance in between starting point of a ray with it's intersection with a planar surface.

    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/boundary.py +
def intersect_w_surface(ray, points):
+    """
+    Definition to find intersection point inbetween a surface and a ray. For more see: http://geomalgorithms.com/a06-_intersect-2.html
+
+    Parameters
+    ----------
+    ray          : torch.tensor
+                   A vector/ray.
+    points       : torch.tensor
+                   Set of points in X,Y and Z to define a planar surface.
+
+    Returns
+    ----------
+    normal       : torch.tensor
+                   Surface normal at the point of intersection.
+    distance     : float
+                   Distance in between starting point of a ray with it's intersection with a planar surface.
+    """
+    normal = get_triangle_normal(points)
+    if len(ray.shape) == 2:
+        ray = ray.unsqueeze(0)
+    if len(points.shape) == 2:
+        points = points.unsqueeze(0)
+    if len(normal.shape) == 2:
+        normal = normal.unsqueeze(0)
+    f = normal[:, 0] - ray[:, 0]
+    distance = (torch.mm(normal[:, 1], f.T) / torch.mm(normal[:, 1], ray[:, 1].T)).T
+    new_normal = torch.zeros_like(ray)
+    new_normal[:, 0] = ray[:, 0] + distance * ray[:, 1]
+    new_normal[:, 1] = normal[:, 1]
+    new_normal = torch.nan_to_num(
+                                  new_normal,
+                                  nan = float('nan'),
+                                  posinf = float('nan'),
+                                  neginf = float('nan')
+                                 )
+    distance = torch.nan_to_num(
+                                distance,
+                                nan = float('nan'),
+                                posinf = float('nan'),
+                                neginf = float('nan')
+                               )
+    return new_normal, distance
+
+
+
+ +
+
+ + +
+ + + + +
+ +

Definition to check if a given point is inside a triangle. +If the given point is inside a defined triangle, this definition returns True. +For more details, visit: https://blackpawn.com/texts/pointinpoly/.

+ + +

Parameters:

+
    +
  • + point_to_check + – +
    +
              Point(s) to check.
    +          Expected size is [3], [1 x 3] or [m x 3].
    +
    +
    +
  • +
  • + triangle + – +
    +
              Triangle described with three points.
    +          Expected size is [3 x 3], [1 x 3 x 3] or [m x 3 x3].
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( tensor +) – +
    +

    Is it on a triangle? Returns NaN if condition not satisfied. +Expected size is [1] or [m] depending on the input.

    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/primitives.py +
def is_it_on_triangle(point_to_check, triangle):
+    """
+    Definition to check if a given point is inside a triangle. 
+    If the given point is inside a defined triangle, this definition returns True.
+    For more details, visit: [https://blackpawn.com/texts/pointinpoly/](https://blackpawn.com/texts/pointinpoly/).
+
+    Parameters
+    ----------
+    point_to_check  : torch.tensor
+                      Point(s) to check.
+                      Expected size is [3], [1 x 3] or [m x 3].
+    triangle        : torch.tensor
+                      Triangle described with three points.
+                      Expected size is [3 x 3], [1 x 3 x 3] or [m x 3 x3].
+
+    Returns
+    -------
+    result          : torch.tensor
+                      Is it on a triangle? Returns NaN if condition not satisfied.
+                      Expected size is [1] or [m] depending on the input.
+    """
+    if len(point_to_check.shape) == 1:
+        point_to_check = point_to_check.unsqueeze(0)
+    if len(triangle.shape) == 2:
+        triangle = triangle.unsqueeze(0)
+    v0 = triangle[:, 2] - triangle[:, 0]
+    v1 = triangle[:, 1] - triangle[:, 0]
+    v2 = point_to_check - triangle[:, 0]
+    if len(v0.shape) == 1:
+        v0 = v0.unsqueeze(0)
+    if len(v1.shape) == 1:
+        v1 = v1.unsqueeze(0)
+    if len(v2.shape) == 1:
+        v2 = v2.unsqueeze(0)
+    dot00 = torch.mm(v0, v0.T)
+    dot01 = torch.mm(v0, v1.T)
+    dot02 = torch.mm(v0, v2.T) 
+    dot11 = torch.mm(v1, v1.T)
+    dot12 = torch.mm(v1, v2.T)
+    invDenom = 1. / (dot00 * dot11 - dot01 * dot01)
+    u = (dot11 * dot02 - dot01 * dot12) * invDenom
+    v = (dot00 * dot12 - dot01 * dot02) * invDenom
+    result = (u >= 0.) & (v >= 0.) & ((u + v) < 1)
+    return result
+
+
+
+ +
+
+ + +
+ + + + +
+ +

Definition to find intersection point of a ray with a triangle.

+ + +

Parameters:

+
    +
  • + ray + – +
    +
                  A ray [1 x 2 x 3] or a batch of ray [m x 2 x 3].
    +
    +
    +
  • +
  • + triangle + – +
    +
                  Set of points in X,Y and Z to define a single triangle [1 x 3 x 3].
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +normal ( tensor +) – +
    +

    Surface normal at the point of intersection with the surface of triangle. +This could also involve surface normals that are not on the triangle. +Expected size is [1 x 2 x 3] or [m x 2 x 3] depending on the input.

    +
    +
  • +
  • +distance ( float +) – +
    +

    Distance in between a starting point of a ray and the intersection point with a given triangle. +Expected size is [1 x 1] or [m x 1] depending on the input.

    +
    +
  • +
  • +intersecting_ray ( tensor +) – +
    +

    Rays that intersect with the triangle plane and on the triangle. +Expected size is [1 x 2 x 3] or [m x 2 x 3] depending on the input.

    +
    +
  • +
  • +intersecting_normal ( tensor +) – +
    +

    Normals that intersect with the triangle plane and on the triangle. +Expected size is [1 x 2 x 3] or [m x 2 x 3] depending on the input.

    +
    +
  • +
  • +check ( tensor +) – +
    +

    A list that provides a bool as True or False for each ray used as input. +A test to see is a ray could be on the given triangle. +Expected size is [1] or [m].

    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/boundary.py +
def intersect_w_triangle(ray, triangle):
+    """
+    Definition to find intersection point of a ray with a triangle. 
+
+    Parameters
+    ----------
+    ray                 : torch.tensor
+                          A ray [1 x 2 x 3] or a batch of ray [m x 2 x 3].
+    triangle            : torch.tensor
+                          Set of points in X,Y and Z to define a single triangle [1 x 3 x 3].
+
+    Returns
+    ----------
+    normal              : torch.tensor
+                          Surface normal at the point of intersection with the surface of triangle.
+                          This could also involve surface normals that are not on the triangle.
+                          Expected size is [1 x 2 x 3] or [m x 2 x 3] depending on the input.
+    distance            : float
+                          Distance in between a starting point of a ray and the intersection point with a given triangle.
+                          Expected size is [1 x 1] or [m x 1] depending on the input.
+    intersecting_ray    : torch.tensor
+                          Rays that intersect with the triangle plane and on the triangle.
+                          Expected size is [1 x 2 x 3] or [m x 2 x 3] depending on the input.
+    intersecting_normal : torch.tensor
+                          Normals that intersect with the triangle plane and on the triangle.
+                          Expected size is [1 x 2 x 3] or [m x 2 x 3] depending on the input.
+    check               : torch.tensor
+                          A list that provides a bool as True or False for each ray used as input.
+                          A test to see is a ray could be on the given triangle.
+                          Expected size is [1] or [m].
+    """
+    if len(triangle.shape) == 2:
+       triangle = triangle.unsqueeze(0)
+    if len(ray.shape) == 2:
+       ray = ray.unsqueeze(0)
+    normal, distance = intersect_w_surface(ray, triangle)
+    check = is_it_on_triangle(normal[:, 0], triangle)
+    intersecting_ray = ray.unsqueeze(0)
+    intersecting_ray = intersecting_ray.repeat(triangle.shape[0], 1, 1, 1)
+    intersecting_ray = intersecting_ray[check == True]
+    intersecting_normal = normal.unsqueeze(0)
+    intersecting_normal = intersecting_normal.repeat(triangle.shape[0], 1, 1, 1)
+    intersecting_normal = intersecting_normal[check ==  True]
+    return normal, distance, intersecting_ray, intersecting_normal, check
+
+
+
+ +
+
+
+

Using the provided utility functions above, let us build an example below that helps us find intersections between a triangle and a batch of rays.

+
+
+
+
import sys
+import odak
+import torch
+
+
+def test(output_directory = 'test_output'):
+    odak.tools.check_directory(output_directory)
+    starting_points, _, _, _ = odak.learn.tools.grid_sample(
+                                                            no = [5, 5],
+                                                            size = [10., 10.],
+                                                            center = [0., 0., 0.]
+                                                           )
+    end_points, _, _, _ = odak.learn.tools.grid_sample(
+                                                       no = [5, 5],
+                                                       size = [6., 6.],
+                                                       center = [0., 0., 10.]
+                                                      )
+    rays = odak.learn.raytracing.create_ray_from_two_points(
+                                                            starting_points,
+                                                            end_points
+                                                           )
+    triangle = torch.tensor([[
+                              [-5., -5., 10.],
+                              [ 5., -5., 10.],
+                              [ 0.,  5., 10.]
+                            ]])
+    normals, distance, _, _, check = odak.learn.raytracing.intersect_w_triangle(
+                                                                                rays,
+                                                                                triangle
+                                                                               ) # (2)
+
+
+
+    visualize = False # (1)
+    if visualize:
+        ray_diagram = odak.visualize.plotly.rayshow(line_width = 3., marker_size = 3.) # (1)
+        ray_diagram.add_triangle(triangle, color = 'orange')
+        ray_diagram.add_point(rays[:, 0], color = 'blue')
+        ray_diagram.add_line(rays[:, 0], normals[:, 0], color = 'blue')
+        colors = []
+        for color_id in range(check.shape[1]):
+            if check[0, color_id] == True:
+                colors.append('green')
+            elif check[0, color_id] == False:
+                colors.append('red')
+        ray_diagram.add_point(normals[:, 0], color = colors)
+        html = ray_diagram.save_offline()
+        markdown_file = open('{}/ray.txt'.format(output_directory), 'w')
+        markdown_file.write(html)
+        markdown_file.close()
+    assert True == True
+
+
+if __name__ == '__main__':
+    sys.exit(test())
+
+
    +
  1. Set visualize = True to enable visualization.
  2. Returning intersection normals as new rays, distances from the starting points of the input rays, and a check that returns True if an intersection point is inside the triangle.
+
+
+
+
+
+ +
+Why should we be interested in ray and triangle intersections? +

Modern Computer Graphics uses various representations for defining three-dimensional objects and scenes. These representations include:

  • Point Clouds: a series of XYZ coordinates sampled from the surface of a three-dimensional object,
  • Meshes: a soup of triangles that represents the surface of a three-dimensional object,
  • Signed Distance Functions: a function providing the distance between an XYZ point and the surface of a three-dimensional object,
  • Neural Radiance Fields: a machine learning approach to learning ray patterns from various perspectives.

Historically, meshes have mainly been used to represent three-dimensional objects. Thus, intersecting rays with triangles is important for most of Computer Graphics.

+
+
+Challenge: Many triangles! +

The example provided above deals with a single triangle and a batch of rays. However, objects represented with triangles are typically described with many triangles, not just one. Note that odak.learn.raytracing.intersect_w_triangle deals with one triangle at a time, which may lead to slow execution times as the function has to visit each triangle separately. Given this information, we challenge readers to create a new function inside the odak.learn.raytracing submodule named intersect_w_mesh. This new function has to be able to work with multiple triangles (meshes) and has to be aware of "occlusions" (e.g., a triangle blocking another triangle); the sketch below may serve as a starting point. In addition, a new unit test, test/test_learn_ray_intersect_w_mesh.py, has to adopt this new function. To add these to odak, you can rely on the pull request feature on GitHub. You can also create a new engineering note for meshes in docs/notes/raytracing_meshes.md.
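As a starting point only, here is a minimal sketch of the occlusion logic under the conventions documented above. The helper name intersect_w_mesh_naive is hypothetical, the sketch is not the requested implementation, and its Python loop over triangles is exactly the inefficiency the challenge asks readers to remove.

```python
import torch
import odak


def intersect_w_mesh_naive(rays, triangles):
    """
    Hypothetical sketch: visits each triangle one by one and keeps the nearest
    valid hit per ray, so closer triangles occlude farther ones.
    rays is [m x 2 x 3], triangles is [t x 3 x 3].
    """
    number_of_rays = rays.shape[0]
    best_distance = torch.full((number_of_rays,), float('inf'))
    best_normal = torch.zeros_like(rays)
    hit_any = torch.zeros(number_of_rays, dtype = torch.bool)
    for triangle in triangles:
        normal, distance, _, _, check = odak.learn.raytracing.intersect_w_triangle(
                                                                                   rays,
                                                                                   triangle.unsqueeze(0)
                                                                                  )
        distance = distance.squeeze(-1)      # [m x 1] -> [m]
        check = check.squeeze(0)             # [1 x m] -> [m]
        closer = check & (distance < best_distance)
        best_distance[closer] = distance[closer]
        best_normal[closer] = normal[closer]
        hit_any = hit_any | closer
    return best_normal, best_distance, hit_any
```

A proper intersect_w_mesh should also reject intersections with negative distances (behind the ray origins) and batch the triangles instead of looping, but the sketch conveys the nearest-hit occlusion idea.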

+
+

Refracting and reflecting rays

+

Informative · + Practical

+

In the previous subsections, we reviewed ray intersections with various surface representations, including parametric (e.g., spheres) and non-parametric (e.g., meshes) ones. Please remember that raytracing is the most simplistic model of light. Thus, raytracing often does not account for any wave- or quantum-related nature of light. To our knowledge, light refracts, reflects, or diffracts when it interfaces with a surface or, in other words, a changing medium (e.g., light traveling from air to glass). In that case, our next step should be identifying a methodology to help us model these events using rays. We compiled two utility functions that could help us model a refraction or a reflection. These functions are named odak.learn.raytracing.refract 1 and odak.learn.raytracing.reflect 1. The first one, odak.learn.raytracing.refract, follows Snell's law of refraction, while odak.learn.raytracing.reflect follows a perfect reflection case. We will not go into the details of this theory, as its simplest form in the way we discuss it could now be considered common knowledge. However, for curious readers, the work by Bell et al. 5 provides a generalized solution for the laws of refraction and reflection. Let us carefully examine these two utility functions to understand their internal workings.
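For reference, and for a unit surface normal \(\hat{n}\), an incoming direction \(\hat{d}\), an incidence angle \(\theta_i\), a transmission angle \(\theta_t\), and refractive indices \(n_1\) and \(n_2\), these two behaviors reduce to the well-known relations

\[
n_1 \sin(\theta_i) = n_2 \sin(\theta_t), \qquad \hat{d}_{reflect} = \hat{d} - 2 \, (\hat{d} \cdot \hat{n}) \, \hat{n}.
\]

The two utility functions below realize these behaviors using the formulation of Spencer and Murty cited in their docstrings.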

+
+
+
+ + +
+ + + + +
+ +

Definition to refract an incoming ray. +Used method described in G.H. Spencer and M.V.R.K. Murty, "General Ray-Tracing Procedure", 1961.

+ + +

Parameters:

+
    +
  • + vector + – +
    +
             Incoming ray.
    +         Expected size is [2, 3], [1, 2, 3] or [m, 2, 3].
    +
    +
    +
  • +
  • + normvector + – +
    +
             Normal vector.
    +         Expected size is [2, 3], [1, 2, 3] or [m, 2, 3]].
    +
    +
    +
  • +
  • + n1 + – +
    +
             Refractive index of the incoming medium.
    +
    +
    +
  • +
  • + n2 + – +
    +
             Refractive index of the outgoing medium.
    +
    +
    +
  • +
  • + error + – +
    +
             Desired error.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +output ( tensor +) – +
    +

    Refracted ray. +Expected size is [1, 2, 3]

    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/boundary.py +
def refract(vector, normvector, n1, n2, error = 0.01):
+    """
+    Definition to refract an incoming ray.
+    Used method described in G.H. Spencer and M.V.R.K. Murty, "General Ray-Tracing Procedure", 1961.
+
+
+    Parameters
+    ----------
+    vector         : torch.tensor
+                     Incoming ray.
+                     Expected size is [2, 3], [1, 2, 3] or [m, 2, 3].
+    normvector     : torch.tensor
+                     Normal vector.
+                     Expected size is [2, 3], [1, 2, 3] or [m, 2, 3]].
+    n1             : float
+                     Refractive index of the incoming medium.
+    n2             : float
+                     Refractive index of the outgoing medium.
+    error          : float 
+                     Desired error.
+
+    Returns
+    -------
+    output         : torch.tensor
+                     Refracted ray.
+                     Expected size is [1, 2, 3]
+    """
+    if len(vector.shape) == 2:
+        vector = vector.unsqueeze(0)
+    if len(normvector.shape) == 2:
+        normvector = normvector.unsqueeze(0)
+    mu    = n1 / n2
+    div   = normvector[:, 1, 0] ** 2  + normvector[:, 1, 1] ** 2 + normvector[:, 1, 2] ** 2
+    a     = mu * (vector[:, 1, 0] * normvector[:, 1, 0] + vector[:, 1, 1] * normvector[:, 1, 1] + vector[:, 1, 2] * normvector[:, 1, 2]) / div
+    b     = (mu ** 2 - 1) / div
+    to    = - b * 0.5 / a
+    num   = 0
+    eps   = torch.ones(vector.shape[0], device = vector.device) * error * 2
+    while len(eps[eps > error]) > 0:
+       num   += 1
+       oldto  = to
+       v      = to ** 2 + 2 * a * to + b
+       deltav = 2 * (to + a)
+       to     = to - v / deltav
+       eps    = abs(oldto - to)
+    output = torch.zeros_like(vector)
+    output[:, 0, 0] = normvector[:, 0, 0]
+    output[:, 0, 1] = normvector[:, 0, 1]
+    output[:, 0, 2] = normvector[:, 0, 2]
+    output[:, 1, 0] = mu * vector[:, 1, 0] + to * normvector[:, 1, 0]
+    output[:, 1, 1] = mu * vector[:, 1, 1] + to * normvector[:, 1, 1]
+    output[:, 1, 2] = mu * vector[:, 1, 2] + to * normvector[:, 1, 2]
+    return output
+
+
+
+ +
+
+ + +
+ + + + +
+ +

Definition to reflect an incoming ray from a surface defined by a surface normal. +Used method described in G.H. Spencer and M.V.R.K. Murty, "General Ray-Tracing Procedure", 1961.

+ + +

Parameters:

+
    +
  • + input_ray + – +
    +
           A ray or rays.
    +       Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].
    +
    +
    +
  • +
  • + normal + – +
    +
           A surface normal(s).
    +       Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +output_ray ( tensor +) – +
    +

    Array that contains starting points and cosines of a reflected ray. +Expected size is [1 x 2 x 3] or [m x 2 x 3].

    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/boundary.py +
def reflect(input_ray, normal):
+    """ 
+    Definition to reflect an incoming ray from a surface defined by a surface normal. 
+    Used method described in G.H. Spencer and M.V.R.K. Murty, "General Ray-Tracing Procedure", 1961.
+
+
+    Parameters
+    ----------
+    input_ray    : torch.tensor
+                   A ray or rays.
+                   Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].
+    normal       : torch.tensor
+                   A surface normal(s).
+                   Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].
+
+    Returns
+    ----------
+    output_ray   : torch.tensor
+                   Array that contains starting points and cosines of a reflected ray.
+                   Expected size is [1 x 2 x 3] or [m x 2 x 3].
+    """
+    if len(input_ray.shape) == 2:
+        input_ray = input_ray.unsqueeze(0)
+    if len(normal.shape) == 2:
+        normal = normal.unsqueeze(0)
+    mu = 1
+    div = normal[:, 1, 0]**2 + normal[:, 1, 1]**2 + normal[:, 1, 2]**2 + 1e-8
+    a = mu * (input_ray[:, 1, 0] * normal[:, 1, 0] + input_ray[:, 1, 1] * normal[:, 1, 1] + input_ray[:, 1, 2] * normal[:, 1, 2]) / div
+    a = a.unsqueeze(1)
+    n = int(torch.amax(torch.tensor([normal.shape[0], input_ray.shape[0]])))
+    output_ray = torch.zeros((n, 2, 3)).to(input_ray.device)
+    output_ray[:, 0] = normal[:, 0]
+    output_ray[:, 1] = input_ray[:, 1] - 2 * a * normal[:, 1]
+    return output_ray
+
+
+
+ +
+
+
+

Please note that we provide two refractive indices as inputs to odak.learn.raytracing.refract. These inputs represent the refractive indices of two mediums (e.g., air and glass). However, the refractive index of a medium depends on light's wavelength (color). In the following example, where we showcase a sample use case of these utility functions, we will assume that light has a single wavelength. But bear in mind that when you need to raytrace with many wavelengths (multi-color RGB or hyperspectral), you must raytrace once for each wavelength (color). Thus, the computational complexity of raytracing increases dramatically as we aim for growing realism in the simulations (e.g., describing scenes per color, raytracing for each color). Let's dive deep into how we use these functions in an actual example by observing the code below.
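To illustrate the wavelength point before moving to the full example, here is a minimal, self-contained sketch that refracts the same incoming ray three times with three assumed refractive indices, one per color channel. The index values are arbitrary placeholders rather than measured dispersion data.

```python
import torch
import odak


# A single incoming ray and a surface normal at its intersection point,
# both following the [1 x 2 x 3] convention (starting point, direction cosines).
incoming_ray = torch.tensor([[[0., 0., 0.], [0., 0., 1.]]])
surface_normal = torch.tensor([[[0., 0., 10.], [0.3, 0., -0.95]]])

n_air = 1.0
n_glass_per_color = {'red': 1.513, 'green': 1.515, 'blue': 1.521} # assumed values

refracted_rays_per_color = {}
for color, n_glass in n_glass_per_color.items():
    # One refraction call per wavelength: each color bends by a slightly different amount.
    refracted_rays_per_color[color] = odak.learn.raytracing.refract(
                                                                    incoming_ray,
                                                                    surface_normal,
                                                                    n_air,
                                                                    n_glass
                                                                   )
```

In a full multi-color simulation, the same per-wavelength loop would also wrap the intersection tests and any further bounces.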

+
+
+
+
import sys
+import odak
+import torch
+
+def test(output_directory = 'test_output'):
+    odak.tools.check_directory(output_directory)
+    starting_points, _, _, _ = odak.learn.tools.grid_sample(
+                                                            no = [5, 5],
+                                                            size = [15., 15.],
+                                                            center = [0., 0., 0.]
+                                                           )
+    end_points, _, _, _ = odak.learn.tools.grid_sample(
+                                                       no = [5, 5],
+                                                       size = [6., 6.],
+                                                       center = [0., 0., 10.]
+                                                      )
+    rays = odak.learn.raytracing.create_ray_from_two_points(
+                                                            starting_points,
+                                                            end_points
+                                                           )
+    triangle = torch.tensor([[
+                              [-5., -5., 10.],
+                              [ 5., -5., 10.],
+                              [ 0.,  5., 10.]
+                            ]])
+    normals, distance, intersecting_rays, intersecting_normals, check = odak.learn.raytracing.intersect_w_triangle(
+                                                                                    rays,
+                                                                                    triangle
+                                                                                   ) 
+    n_air = 1.0 # (1)
+    n_glass = 1.51 # (2)
+    refracted_rays = odak.learn.raytracing.refract(intersecting_rays, intersecting_normals, n_air, n_glass) # (3)
+    reflected_rays = odak.learn.raytracing.reflect(intersecting_rays, intersecting_normals) # (4)
+    refract_distance = 11.
+    reflect_distance = 7.2
+    propagated_refracted_rays = odak.learn.raytracing.propagate_ray(
+                                                                    refracted_rays, 
+                                                                    torch.ones(refracted_rays.shape[0]) * refract_distance
+                                                                   )
+    propagated_reflected_rays = odak.learn.raytracing.propagate_ray(
+                                                                    reflected_rays,
+                                                                    torch.ones(reflected_rays.shape[0]) * reflect_distance
+                                                                   )
+
+
+
+    visualize = False
+    if visualize:
+        ray_diagram = odak.visualize.plotly.rayshow(
+                                                    columns = 2,
+                                                    line_width = 3.,
+                                                    marker_size = 3.,
+                                                    subplot_titles = ['Refraction example', 'Reflection example']
+                                                   ) # (1)
+        ray_diagram.add_triangle(triangle, column = 1, color = 'orange')
+        ray_diagram.add_triangle(triangle, column = 2, color = 'orange')
+        ray_diagram.add_point(rays[:, 0], column = 1, color = 'blue')
+        ray_diagram.add_point(rays[:, 0], column = 2, color = 'blue')
+        ray_diagram.add_line(rays[:, 0], normals[:, 0], column = 1, color = 'blue')
+        ray_diagram.add_line(rays[:, 0], normals[:, 0], column = 2, color = 'blue')
+        ray_diagram.add_line(refracted_rays[:, 0], propagated_refracted_rays[:, 0], column = 1, color = 'blue')
+        ray_diagram.add_line(reflected_rays[:, 0], propagated_reflected_rays[:, 0], column = 2, color = 'blue')
+        colors = []
+        for color_id in range(check.shape[1]):
+            if check[0, color_id] == True:
+                colors.append('green')
+            elif check[0, color_id] == False:
+                colors.append('red')
+        ray_diagram.add_point(normals[:, 0], column = 1, color = colors)
+        ray_diagram.add_point(normals[:, 0], column = 2, color = colors)
+        html = ray_diagram.save_offline()
+        markdown_file = open('{}/ray.txt'.format(output_directory), 'w')
+        markdown_file.write(html)
+        markdown_file.close()
+    assert True == True
+
+
+if __name__ == '__main__':
+    sys.exit(test())
+
+
    +
  1. Refractive index of air (arbitrary and regardless of wavelength), the medium before the ray and triangle intersection.
  2. Refractive index of glass (arbitrary and regardless of wavelength), the medium after the ray and triangle intersection.
  3. Refraction process.
  4. Reflection process.
+
+
+
+
+
+ +
+Challenge: Diffracted interfering rays +

This subsection covered simulating refraction and reflection events. However, diffraction and interference 6 are not introduced in this raytracing model. This is because diffraction and interference would require another layer of complexity. In other words, rays would have to carry an extra dimension beyond their starting points and direction cosines, namely the quantity known as the phase of light. This fact makes a typical ray have dimensions of [1 x 3 x 3] instead of [1 x 2 x 3], where only direction cosines and starting points are defined. Given this information, we challenge readers to create a new submodule, odak.learn.raytracing.diffraction, extending rays to diffraction and interference. In addition, a new set of unit tests should be derived to adopt this new submodule. To add these to odak, you can rely on the pull request feature on GitHub. You can also create a new engineering note for diffraction and interference in docs/notes/raytracing_diffraction_interference.md.

+
+

Optimization with rays

+

Informative · + Practical

+

We learned about refraction, reflection, rays, and surface intersections in the previous subsections. We did not mention it then, but these functions are differentiable 7. In other words, a modern machine learning library can keep a graph of a variable passing through each one of these functions (see chain rule). This differentiability feature is vital because it makes our raytracing-based simulations of light compatible with modern machine learning frameworks such as Torch. In this subsection, we will use an off-the-shelf optimizer from Torch to optimize variables in our raytracing simulations. In the first example, we will see that the optimizer helps us define the proper tilt angles for a triangle-shaped mirror so that it redirects light from a point light source towards a given target. Our first example represents a straightforward case for optimization by containing only a batch of rays and a single triangle. The problem highlighted in the first example has a closed-form solution, and using an optimizer is obviously overkill; we want our readers to treat it as a warm-up scenario for learning how to interact with rays and triangles in the context of an optimization problem. In our second example, we will deal with a more sophisticated case, where a batch of rays arriving from a point light source bounces off a surface composed of multiple triangles (a mesh) and arrives at some point on a final target plane. This time, we will ask our optimizer to optimize the shape of the triangles so that most of the light bouncing off the optimized surface ends up at a location close to a target we define in our simulation. This way, we show our readers that a more sophisticated shape can be optimized using our framework, Odak. In real life, the second example could be a lens or mirror shape to be optimized. More specifically, as an application example, it could be a mirror or a lens that focuses light from the Sun onto a solar cell to increase the efficiency of a solar power system, or it could be a lens helping you focus at a specific depth given your eye prescription. Let us start with our first example and examine how we can tilt a surface using an optimizer; in the second example, we will see how an optimizer helps us define and optimize the shape of a given mesh.
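Before the two full examples, here is a minimal sketch of the differentiability claim above: a propagation distance with requires_grad = True passes through odak.learn.raytracing.propagate_ray, and Torch traces a gradient of a simple loss back to that distance, mirroring what odak.learn.raytracing.intersect_w_sphere does internally. The target point below is arbitrary and only serves the illustration.

```python
import torch
import odak


# A single ray travelling along +z, following the [1 x 2 x 3] ray convention.
ray = odak.learn.raytracing.create_ray_from_two_points(
                                                       torch.tensor([[0., 0., 0.]]),
                                                       torch.tensor([[0., 0., 1.]])
                                                      )
distance = torch.tensor([5.], requires_grad = True) # the variable we want gradients for
propagated_ray = odak.learn.raytracing.propagate_ray(ray, distance)

# Penalize the gap between the propagated point and an arbitrary target on the z axis.
target = torch.tensor([[0., 0., 8.]])
loss = torch.nn.functional.mse_loss(propagated_ray[:, 0], target)
loss.backward()
print(distance.grad) # non-zero, so gradients flow through the raytracing functions
```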

+
+
+
+
import sys
+import odak
+import torch
+from tqdm import tqdm
+
+
+def test(output_directory = 'test_output'):
+    odak.tools.check_directory(output_directory)
+    final_surface = torch.tensor([[
+                                   [-5., -5., 0.],
+                                   [ 5., -5., 0.],
+                                   [ 0.,  5., 0.]
+                                 ]])
+    final_target = torch.tensor([[3., 3., 0.]])
+    triangle = torch.tensor([
+                             [-5., -5., 10.],
+                             [ 5., -5., 10.],
+                             [ 0.,  5., 10.]
+                            ])
+    starting_points, _, _, _ = odak.learn.tools.grid_sample(
+                                                            no = [5, 5],
+                                                            size = [1., 1.],
+                                                            center = [0., 0., 0.]
+                                                           )
+    end_point = odak.learn.raytracing.center_of_triangle(triangle)
+    rays = odak.learn.raytracing.create_ray_from_two_points(
+                                                            starting_points,
+                                                            end_point
+                                                           )
+    angles = torch.zeros(1, 3, requires_grad = True)
+    learning_rate = 2e-1
+    optimizer = torch.optim.Adam([angles], lr = learning_rate)
+    loss_function = torch.nn.MSELoss()
+    number_of_steps = 100
+    t = tqdm(range(number_of_steps), leave = False, dynamic_ncols = True)
+    for step in t:
+        optimizer.zero_grad()
+        rotated_triangle, _, _, _ = odak.learn.tools.rotate_points(
+                                                                   triangle, 
+                                                                   angles = angles, 
+                                                                   origin = end_point
+                                                                  )
+        _, _, intersecting_rays, intersecting_normals, check = odak.learn.raytracing.intersect_w_triangle(
+                                                                                                          rays,
+                                                                                                          rotated_triangle
+                                                                                                         )
+        reflected_rays = odak.learn.raytracing.reflect(intersecting_rays, intersecting_normals)
+        final_normals, _ = odak.learn.raytracing.intersect_w_surface(reflected_rays, final_surface)
+        if step == 0:
+            start_rays = rays.detach().clone()
+            start_rotated_triangle = rotated_triangle.detach().clone()
+            start_intersecting_rays = intersecting_rays.detach().clone()
+            start_intersecting_normals = intersecting_normals.detach().clone()
+            start_final_normals = final_normals.detach().clone()
+        final_points = final_normals[:, 0]
+        target = final_target.repeat(final_points.shape[0], 1)
+        loss = loss_function(final_points, target)
+        loss.backward(retain_graph = True)
+        optimizer.step()
+        t.set_description('Loss: {}'.format(loss.item()))
+    print('Loss: {}, angles: {}'.format(loss.item(), angles))
+
+
+    visualize = False
+    if visualize:
+        ray_diagram = odak.visualize.plotly.rayshow(
+                                                    columns = 2,
+                                                    line_width = 3.,
+                                                    marker_size = 3.,
+                                                    subplot_titles = [
+                                                                       'Surface before optimization', 
+                                                                       'Surface after optimization',
+                                                                       'Hits at the target plane before optimization',
+                                                                       'Hits at the target plane after optimization',
+                                                                     ]
+                                                   ) 
+        ray_diagram.add_triangle(start_rotated_triangle, column = 1, color = 'orange')
+        ray_diagram.add_triangle(rotated_triangle, column = 2, color = 'orange')
+        ray_diagram.add_point(start_rays[:, 0], column = 1, color = 'blue')
+        ray_diagram.add_point(rays[:, 0], column = 2, color = 'blue')
+        ray_diagram.add_line(start_intersecting_rays[:, 0], start_intersecting_normals[:, 0], column = 1, color = 'blue')
+        ray_diagram.add_line(intersecting_rays[:, 0], intersecting_normals[:, 0], column = 2, color = 'blue')
+        ray_diagram.add_line(start_intersecting_normals[:, 0], start_final_normals[:, 0], column = 1, color = 'blue')
+        ray_diagram.add_line(start_intersecting_normals[:, 0], final_normals[:, 0], column = 2, color = 'blue')
+        ray_diagram.add_point(final_target, column = 1, color = 'red')
+        ray_diagram.add_point(final_target, column = 2, color = 'green')
+        html = ray_diagram.save_offline()
+        markdown_file = open('{}/ray.txt'.format(output_directory), 'w')
+        markdown_file.write(html)
+        markdown_file.close()
+    assert True == True
+
+
+if __name__ == '__main__':
+    sys.exit(test())
+
+
+
+
+
+
+ +

Let us also look into the more sophisticated second example, where a triangular mesh is optimized to meet a specific demand, redirecting rays to a particular target.

+
+
+
+
import sys
+import odak
+import torch
+from tqdm import tqdm
+
+
+def test(output_directory = 'test_output'):
+    odak.tools.check_directory(output_directory)
+    device = torch.device('cpu')
+    final_target = torch.tensor([-2., -2., 10.], device = device)
+    final_surface = odak.learn.raytracing.define_plane(point = final_target)
+    mesh = odak.learn.raytracing.planar_mesh(
+                                             size = torch.tensor([1.1, 1.1]), 
+                                             number_of_meshes = torch.tensor([9, 9]), 
+                                             device = device
+                                            )
+    start_points, _, _, _ = odak.learn.tools.grid_sample(
+                                                         no = [11, 11],
+                                                         size = [1., 1.],
+                                                         center = [2., 2., 10.]
+                                                        )
+    end_points, _, _, _ = odak.learn.tools.grid_sample(
+                                                       no = [11, 11],
+                                                       size = [1., 1.],
+                                                       center = [0., 0., 0.]
+                                                      )
+    start_points = start_points.to(device)
+    end_points = end_points.to(device)
+    loss_function = torch.nn.MSELoss(reduction = 'sum')
+    learning_rate = 2e-3
+    optimizer = torch.optim.AdamW([mesh.heights], lr = learning_rate)
+    rays = odak.learn.raytracing.create_ray_from_two_points(start_points, end_points)
+    number_of_steps = 100
+    t = tqdm(range(number_of_steps), leave = False, dynamic_ncols = True)
+    for step in t:
+        optimizer.zero_grad()
+        triangles = mesh.get_triangles()
+        reflected_rays, reflected_normals = mesh.mirror(rays)
+        final_normals, _ = odak.learn.raytracing.intersect_w_surface(reflected_rays, final_surface)
+        final_points = final_normals[:, 0]
+        target = final_target.repeat(final_points.shape[0], 1)
+        if step == 0:
+            start_triangles = triangles.detach().clone()
+            start_reflected_rays = reflected_rays.detach().clone()
+            start_final_normals = final_normals.detach().clone()
+        loss = loss_function(final_points, target)
+        loss.backward(retain_graph = True)
+        optimizer.step() 
+        description = 'Loss: {}'.format(loss.item())
+        t.set_description(description)
+    print(description)
+
+
+    visualize = False
+    if visualize:
+        ray_diagram = odak.visualize.plotly.rayshow(
+                                                    rows = 1,
+                                                    columns = 2,
+                                                    line_width = 3.,
+                                                    marker_size = 1.,
+                                                    subplot_titles = ['Before optimization', 'After optimization']
+                                                   ) 
+        for triangle_id in range(triangles.shape[0]):
+            ray_diagram.add_triangle(
+                                     start_triangles[triangle_id], 
+                                     row = 1, 
+                                     column = 1, 
+                                     color = 'orange'
+                                    )
+            ray_diagram.add_triangle(triangles[triangle_id], row = 1, column = 2, color = 'orange')
+        html = ray_diagram.save_offline()
+        markdown_file = open('{}/ray.txt'.format(output_directory), 'w')
+        markdown_file.write(html)
+        markdown_file.close()
+    assert True == True
+
+
+if __name__ == '__main__':
+    sys.exit(test())
+
+
+
+
+
+ +
+Challenge: Differentiable detector +

In our examples, where we try bouncing light towards a fixed target, the target is defined as a single point along the XYZ axes. However, in many cases in Optics and Computer Graphics, we may want to design surfaces to resemble a specific distribution of intensities over a plane (e.g., a detector or a camera sensor). For example, the work by Schwartzburg et al. 8 designs optical surfaces such that when light refracts through them, the distribution of these intensities forms an image at some target plane. To replicate such works with Odak, it needs a differentiable detector. This detector could be added as a class in the odak.learn.raytracing submodule, and a new unit test could be added as test/test_learn_detector.py. To add these to odak, you can rely on the pull request feature on GitHub.

+
+

Rendering scenes

+

Informative · + Practical

+

This section shows how one can use raytracing for rendering purposes in Computer Graphics. Note that the provided example is simple, aiming to introduce a newcomer to how raytracing could be used for rendering. The example uses a single perspective camera and relies on a concept called splatting, where rays originate from the camera towards a scene. The scene is composed of randomly colored triangles, and each time a ray hits a randomly colored triangle, the corresponding pixel of our perspective camera is painted with the color of that triangle. Let us review our simple example by reading the code and observing its outcome.

+
+
+
+
import sys
+import odak
+import torch
+from tqdm import tqdm
+
+def test(output_directory = 'test_output'):
+    odak.tools.check_directory(output_directory)
+    final_surface_point = torch.tensor([0., 0., 10.])
+    final_surface = odak.learn.raytracing.define_plane(point = final_surface_point)
+    no = [500, 500]
+    start_points, _, _, _ = odak.learn.tools.grid_sample(
+                                                         no = no,
+                                                         size = [10., 10.],
+                                                         center = [0., 0., -10.]
+                                                        )
+    end_point = torch.tensor([0., 0., 0.])
+    rays = odak.learn.raytracing.create_ray_from_two_points(start_points, end_point)
+    mesh = odak.learn.raytracing.planar_mesh(
+                                             size = torch.tensor([10., 10.]),
+                                             number_of_meshes = torch.tensor([40, 40]),
+                                             angles = torch.tensor([  0., -70., 0.]),
+                                             offset = torch.tensor([ -2.,   0., 5.]),
+                                            )
+    triangles = mesh.get_triangles()
+    play_button = torch.tensor([[
+                                 [  1.,  0.5, 3.],
+                                 [  0.,  0.5, 3.],
+                                 [ 0.5, -0.5, 3.],
+                                ]])
+    triangles = torch.cat((play_button, triangles), dim = 0)
+    background_color = torch.rand(3)
+    triangles_color = torch.rand(triangles.shape[0], 3)
+    image = torch.zeros(rays.shape[0], 3) 
+    for triangle_id, triangle in enumerate(triangles):
+        _, _, _, _, check = odak.learn.raytracing.intersect_w_triangle(rays, triangle)
+        check = check.squeeze(0).unsqueeze(-1).repeat(1, 3)
+        color = triangles_color[triangle_id].unsqueeze(0).repeat(check.shape[0], 1)
+        image[check == True] = color[check == True] * check[check == True]
+    image[image == [0., 0., 0]] = background_color
+    image = image.view(no[0], no[1], 3)
+    odak.learn.tools.save_image('{}/image.png'.format(output_directory), image, cmin = 0., cmax = 1.)
+    assert True == True
+
+
+if __name__ == '__main__':
+    sys.exit(test())
+
+
+
+
+
+

Image title +

+
Rendered result of the renderer script "/test/test_learn_ray_render.py".
+
+

A modern raytracer used in gaming is far more sophisticated than the example we provide here. There are aspects such as material properties, tracing rays from their source to a camera, or allowing rays to interact with multiple materials. Covering these aspects in a crash course like this one would take much more space. Instead, we suggest our readers follow the resources provided in other classes, the references provided at the end, or any other materials available online.
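To give a flavor of one such missing ingredient, the short sketch below shows Lambertian (diffuse) shading, where a hit surface's color is scaled by the cosine between its normal and the direction towards a light source. This sketch is pure PyTorch, independent of the example above, and all values are made up for illustration.

import torch

surface_normal = torch.tensor([0., 0., -1.])                # assumed unit normal of a hit triangle
light_direction = torch.tensor([1., 1., -1.])
light_direction = light_direction / light_direction.norm()  # unit direction towards the light source
base_color = torch.tensor([1.0, 0.5, 0.2])                  # assumed triangle color

# Lambert's cosine law: clamp to zero so surfaces facing away receive no light.
cosine_term = torch.clamp(torch.dot(surface_normal, light_direction), min = 0.)
shaded_color = base_color * cosine_term
print(shaded_color)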

+

Conclusion

+

Informative

+

We can simulate light on a computer using various methods. We explain "raytracing" as one of these methods. Often, raytracing deals with light intensities, omitting many other aspects of light, such as its phase or polarization. In addition, sending the right number of rays from a light source into a scene remains an outstanding sampling problem in raytracing. Raytracing has created many success stories in gaming (e.g., NVIDIA RTX or AMD Radeon Rays) and optical component design (e.g., Zemax or Ansys Speos).

+

Overall, we covered a basic introduction to how to model light as rays and how to use rays to optimize against a given target. Note that our examples represent simple cases. This section aims to provide readers with a suitable basis to get started with raytracing light in simulations. A dedicated and motivated reader could scale up from this knowledge to advanced concepts in displays, cameras, visual perception, optical computing, and many other light-based applications.

+
+

Reminder

+

We host a Slack group with more than 300 members. This Slack group focuses on the topics of rendering, perception, displays, and cameras. The group is open to the public, and you can become a member by following this link. Readers can get in touch with the wider community using this public group.

+
+
+
+
1. GH Spencer and MVRK Murty. General ray-tracing procedure. JOSA, 52(6):672–678, 1962.
2. Peter Shirley. Ray tracing in one weekend. Amazon Digital Services LLC, 1:4, 2018.
3. Morgan McGuire. The graphics codex. 2018.
4. Wenzel Jakob, Sébastien Speierer, Nicolas Roussel, and Delio Vicini. Dr.Jit: a just-in-time compiler for differentiable rendering. ACM Transactions on Graphics (TOG), 41(4):1–19, 2022.
5. Robert J Bell, Kendall R Armstrong, C Stephen Nichols, and Roger W Bradley. Generalized laws of refraction and reflection. JOSA, 59(2):187–189, 1969.
6. Max Born and Emil Wolf. Principles of optics: electromagnetic theory of propagation, interference and diffraction of light. Elsevier, 2013.
7. Adam Paszke, Sam Gross, Soumith Chintala, Gregory Chanan, Edward Yang, Zachary DeVito, Zeming Lin, Alban Desmaison, Luca Antiga, and Adam Lerer. Automatic differentiation in PyTorch. NIPS 2017 Workshop Autodiff, 2017.
8. Yuliy Schwartzburg, Romain Testuz, Andrea Tagliasacchi, and Mark Pauly. High-contrast computational caustic design. ACM Transactions on Graphics (TOG), 33(4):1–11, 2014.
+
+ + + + + + + + + + + + + +
+
+ + + + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/course/index.html b/course/index.html new file mode 100644 index 00000000..fafb6cf7 --- /dev/null +++ b/course/index.html @@ -0,0 +1,1920 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Prerequisites and general information - Odak + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + +
+ + + + + + +
+ + + + +
+ +
+ + + + +
+
+ + + +
+
+
+ + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +
+Narrate section +

+
+

Prerequisites and general information

+

You have reached the website for the Computational Light Course.

+

This page is the starting point for the Computational Light course. +Readers can follow the course material found on these pages to learn more about the field of Computational Light. +I encourage readers to carefully read this page to decide if they want to continue with the course.

+

Brief course description

+

Computational Light is a term that brings together concepts from computational methods and the characteristics of light. In other words, wherever we can program the qualities of light, such as its intensity or direction, we enter the topics of Computational Light. Some well-known subfields of Computational Light are Computer Graphics, Computational Displays, Computational Photography, Computational Imaging and Sensing, Computational Optics and Fabrication, Optical Communication, and All-optical Machine Learning.

+
+

Image title +

+
The future is yet to be decided. Will you help me build it? A rendering from the Telelife vision paper.
+
+

1

+

Computational Light Course bridges the gap between Computer Science and physics. In other words, the Computational Light Course offers students a gateway to get familiar with various aspects of the physics of light, the human visual system, computational methods in designing light-based devices, and applications of light. More precisely, students will familiarize themselves with designing and implementing graphics, display, sensor, and camera systems using state-of-the-art deep learning and optimization methods. A deep understanding of these topics can help students become experts in the computational design of new graphics, displays, sensors, and camera systems.

+

Prerequisites

+

These are the prerequisites of the Computational Light course:

+
    +
  • Background knowledge. First and foremost, fluency in programming with the Python programming language and a graduate-level understanding of Linear Algebra and Machine Learning are required.
  • +
  • Skills and abilities. Throughout the entire course, three libraries will be used, and these libraries include odak, numpy, and torch. +Familiarity with these libraries is a big plus.
  • +
  • Required Resources. Readers need a computer with decent computational resources (e.g., a GPU) when working on the provided materials, laboratory work, and projects. In case you do not have access to such resources, consider using Google's Colab service, as it is free to use. Note that in each section of the course, you will be provided with the relevant reading materials on the spot.
  • +
  • Expectations. Readers also need sustainable motivation to learn new things related to the topics of Computational Light and a willingness to advance the field by developing, innovating, and researching. In other words, you are someone motivated to create a positive impact in society with light-related innovations. You can also be someone eager to understand and learn the physics behind light and how to simulate light-related phenomena.
  • +
+

Questions and Answers

+

Here are some questions and answers related to the course that readers may ask:

+
+What is the overarching rationale for the module? +

Historically, physics and electronics departments in various universities study and teach the physics of light. This way, new light-based devices and equipment have been invented, such as displays, cameras, and fiber networks, and these devices continuously serve our societies. However, emerging topics from mathematics and computer science departments, such as deep learning and advanced optimization methods, unlocked new capabilities for existing light-based devices and started to play a crucial role in designing the next generation of these devices. The Computational Light Course aims to bridge this gap between Computer Science and physics by providing a fundamental understanding of light and the computational methods that help explore new possibilities with light.

+
+
+Who is the target audience of Computational Light course? +

The Computational Light course is designed for individuals willing to learn how to develop and invent light-based practical systems for next-generation human-computer interfaces. +This course targets a graduate-level audience in Computer Science, Physics and Electrical and Electronics Engineering departments. +However, you do not have to be strictly from one of the highlighted targeted audiences. +Simply put, if you think you can learn and are eager to learn, no one will stop you.

+
+
+How can I learn Python programming, linear Algebra and machine learning? +

There isn't much point in providing references on how to learn Python programming, Linear Algebra, and Machine Learning, as there is a vast amount of resources available online or in your previous university courses. Your favorite search engine is your friend in this case.

+
+
+How do I install Python, numpy and torch? +

The installation guides for Python, numpy, and torch are available on their respective websites.

+
+
+How do I install odak? +

Odak's installation page and README provide the most up-to-date information on installing odak. In a nutshell, all you need is to run pip3 install odak in a terminal for the latest release, or, if you want to install the latest code from the source, use pip3 install git+https://github.com/kaanaksit/odak. A quick way to verify your installation follows below.
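As a quick sanity check after installation, you can import odak in a Python session and print where it lives; the version attribute below is an assumption and can be skipped if it is not present in your installed copy.

import odak

print(odak.__file__)     # shows where odak is installed
print(odak.__version__)  # assumes odak exposes __version__; skip this line if it does not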

+
+
+Which Python environment and operating system should I use? +

I use the Python distribution shipped with a traditional Linux distribution (e.g., Ubuntu). Again, there isn't one correct answer here for everyone. You can use any operating system (e.g., Windows, Mac) and Python distribution (e.g., conda).

+
+
+Which text editor should I use for programming? +

I use vim as my text editor. However, I understand that vim could be challenging to adopt, especially as a newcomer. The pattern I observe among collaborators and students is that they use Microsoft's Visual Studio, a competent text editor with artificial intelligence support through a subscription that works across various operating systems. I encourage you to make your choice depending on how comfortable you are with sharing your data with companies. Please also remember that I am only making suggestions here. If another text editor works better for you, please use that one (e.g., nano, Sublime Text, Atom, Notepad++, Jupyter Notebooks).

+
+
+Which terminal program to use? +

You are highly encouraged to use the terminal that you feel most comfortable with. +This terminal could be the default terminal in your operating system. +I use terminator as it enables my workflow with incredible features and is open source.

+
+
+What is the method of delivery? +

The proposed course, Computational Light Course, comprises multiple elements of delivery. We list these elements as follows:

+
    +
  • Prerequisites and general information. Students will be provided with a written description of requirements related to the course as in this document.
  • +
  • Lectures. The students will attend two hours of classes each week, which will be in-person, virtual, or hybrid, depending on the circumstances (e.g., global pandemic, strikes).
  • +
  • Supplementary Lectures. Beyond weekly classes, students will be encouraged to follow several other sources through online video recordings.
  • +
  • Background review. Students often lack a clear development guideline or a stable production pipeline. Thus, in every class and project, a phase of trial and error causes students to lose interest in the topic, and they often need help getting ready for the course and finding the right recipe to complete their work. For this reason, we formulate a special session to review the course's basics and requirements. This way, we hope to overcome the challenges related to the "warming up" stage of the class.
  • +
  • Lecture content. We will provide the students with a lecture book composed of chapters. These chapters will be discussed at each weekly lecture. The book chapters will be distributed online using Moodle (requires UCL access), and a free copy of this book will also be reachable without requiring UCL access.
  • +
  • Laboratory work. Students will be provided with questions about their weekly class topics. These questions will require them to code for a specific task. After each class, students will have an hour-long laboratory session to address these questions by coding. The teaching assistants of the lecture will support each laboratory session.
  • +
  • Supporting tools. We continuously develop new tools for the emerging fields of Computational Light. Our development tools will be used in the delivery. These tools are publicly available in Odak, our research toolkit with Mozilla Public License 2.0. Students will get a chance to use these tools in their laboratory works and projects. In the meantime, they will also get the opportunity to contribute to the next versions of the tool.
  • +
  • Project Assignments. Students will be evaluated on their projects. The lecturer will suggest projects related to the topics of Computational Light. However, the students will also be highly encouraged to propose projects for successfully finishing their course. These projects are expected to address a research question related to the topic discussed. Thus, there are multiple components of a project. These are implementation in coding, manuscript in a modern paper format, a website to promote the work to wider audiences, and presentation of the work to other students and the lecturer.
  • +
  • Office hours. There will be office hours for students willing to talk to the course lecturer, Kaan Akşit, in a one-on-one setting. Each week, the lecturer will schedule two hours for such cases.
  • +
+
+
+What is the aim of this course? +

Computational Light Course aims to train individuals who could potentially help invent and develop the next generation of light-based devices, systems, and software. To achieve this goal, the Computational Light Course aims:

+
    +
  • To educate students on the physics of light, the human visual system, and computational methods relevant to the physics of light based on optimization and machine learning techniques,
  • +
  • To equip students with the right set of practical skills in coding and design for the next generation of light-based systems,
  • +
  • And to increase literacy on light-based technologies among students and professionals.
  • +
+
+
+What are the intended learning outcomes of this course? +

Students who have completed Computational Light Course successfully will have literacy and practical skills on the following items:

+
    +
  • Physics of Light and applications of Computational Light,
  • +
  • Fundamental knowledge of managing a software project (e.g., version and authoring tools, unit tests, coding style, and grammar),
  • +
  • Fundamental knowledge of optimization methods and state-of-the-art libraries aiming at relevant topics,
  • +
  • Fundamental knowledge of visual perception and the human visual system,
  • +
  • Simulating light as geometric rays, as continuous waves, and at the quantum level,
  • +
  • Simulating imaging and display systems, including Computer-Generated Holography,
  • +
  • Designing and optimizing imaging and display systems,
  • +
  • Designing and optimizing all-optical machine learning systems.
  • +
+

Note that the above list is always subject to change in order or topic as society's needs move in various directions.

+
+
+How to cite this course? +

For citing this course using LaTeX's BibTeX bibliography system:

@book{aksit2024computationallight,
+  title = {Computational Light},
+  author = {Ak{\c{s}}it, Kaan and Kam, Henry},
+  booktitle = {Computational Light Course Notes},
+  year = {2024}
+}
+
+For plain text citation: Kaan Akşit, "Computational Light Course", 2024.

+
+

Team

+
+ +
+

Kaan Akşit

+

Instructor

+

E-mail +

+
+ +
+

Henry Kam

+

Contributor

+

E-mail +

+
+

Contact Us

+

The preferred way of communication is through the discussions section of odak. Please only reach us through email if what you want to achieve, establish, or ask is not possible through the suggested route.

+
+

Outreach

+

We host a Slack group with more than 300 members. This Slack group focuses on the topics of rendering, perception, displays, and cameras. The group is open to the public, and you can become a member by following this link. Readers can get in touch with the wider community using this public group.

+

Acknowledgements

+
+

Acknowledgements

+

We thank our readers. +We also thank Yicheng Zhan for his feedback.

+
+
+

Interested in supporting?

+

Enjoyed our course material and want us to do better in the future? +Please consider supporting us monetarily, citing our work in your next scientific work, or leaving us a star for odak.

+
+
+
+
1. Jason Orlosky, Misha Sra, Kenan Bektaş, Huaishu Peng, Jeeeun Kim, Nataliya Kos'myna, Tobias Höllerer, Anthony Steed, Kiyoshi Kiyokawa, and Kaan Akşit. Telelife: the future of remote living. Frontiers in Virtual Reality, 2:763340, 2021.
+
+ + + + + + + + + + + + + +
+
+ + + + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/course/media/10591010993_80c7cb37a6_c.jpg b/course/media/10591010993_80c7cb37a6_c.jpg new file mode 100644 index 00000000..8007cff6 Binary files /dev/null and b/course/media/10591010993_80c7cb37a6_c.jpg differ diff --git a/course/media/analog_hologram_example.gif b/course/media/analog_hologram_example.gif new file mode 100644 index 00000000..0d580a44 Binary files /dev/null and b/course/media/analog_hologram_example.gif differ diff --git a/course/media/computational_light.mp3 b/course/media/computational_light.mp3 new file mode 100644 index 00000000..37ed6efc Binary files /dev/null and b/course/media/computational_light.mp3 differ diff --git a/course/media/computer_generated_hologram_example.png b/course/media/computer_generated_hologram_example.png new file mode 100644 index 00000000..f38d2a7c Binary files /dev/null and b/course/media/computer_generated_hologram_example.png differ diff --git a/course/media/computer_generated_holography.mp3 b/course/media/computer_generated_holography.mp3 new file mode 100644 index 00000000..fc7818ff Binary files /dev/null and b/course/media/computer_generated_holography.mp3 differ diff --git a/course/media/convolution.png b/course/media/convolution.png new file mode 100644 index 00000000..39b5af39 Binary files /dev/null and b/course/media/convolution.png differ diff --git a/course/media/convolution_animation.gif b/course/media/convolution_animation.gif new file mode 100644 index 00000000..a1f72ef4 Binary files /dev/null and b/course/media/convolution_animation.gif differ diff --git a/course/media/convolve_0.png b/course/media/convolve_0.png new file mode 100644 index 00000000..51894b04 Binary files /dev/null and b/course/media/convolve_0.png differ diff --git a/course/media/convolve_1.png b/course/media/convolve_1.png new file mode 100644 index 00000000..bb67824f Binary files /dev/null and b/course/media/convolve_1.png differ diff --git a/course/media/diffraction_1d.gif b/course/media/diffraction_1d.gif new file mode 100644 index 00000000..c46df099 Binary files /dev/null and b/course/media/diffraction_1d.gif differ diff --git a/course/media/diffraction_2d.gif b/course/media/diffraction_2d.gif new file mode 100644 index 00000000..0819b818 Binary files /dev/null and b/course/media/diffraction_2d.gif differ diff --git a/course/media/emspectrum.png b/course/media/emspectrum.png new file mode 100644 index 00000000..4a7b8b3f Binary files /dev/null and b/course/media/emspectrum.png differ diff --git a/course/media/fundamentals.mp3 b/course/media/fundamentals.mp3 new file mode 100644 index 00000000..bbc8b31a Binary files /dev/null and b/course/media/fundamentals.mp3 differ diff --git a/course/media/geometric_optics.mp3 b/course/media/geometric_optics.mp3 new file mode 100644 index 00000000..eafc8870 Binary files /dev/null and b/course/media/geometric_optics.mp3 differ diff --git a/course/media/git_clone.png b/course/media/git_clone.png new file mode 100644 index 00000000..2183c4e9 Binary files /dev/null and b/course/media/git_clone.png differ diff --git a/course/media/hologram_generation.png b/course/media/hologram_generation.png new file mode 100644 index 00000000..10024bae Binary files /dev/null and b/course/media/hologram_generation.png differ diff --git a/course/media/holographic_display.png b/course/media/holographic_display.png new file mode 100644 index 00000000..d59c4ee5 Binary files /dev/null and b/course/media/holographic_display.png differ diff --git 
a/course/media/holographic_display_simulation.png b/course/media/holographic_display_simulation.png new file mode 100644 index 00000000..ef9e7cdf Binary files /dev/null and b/course/media/holographic_display_simulation.png differ diff --git a/course/media/image_0000.jpeg b/course/media/image_0000.jpeg new file mode 100644 index 00000000..b616499d Binary files /dev/null and b/course/media/image_0000.jpeg differ diff --git a/course/media/image_0001.jpeg b/course/media/image_0001.jpeg new file mode 100644 index 00000000..4dbd0c39 Binary files /dev/null and b/course/media/image_0001.jpeg differ diff --git a/course/media/image_0002.jpeg b/course/media/image_0002.jpeg new file mode 100644 index 00000000..f70ed4ef Binary files /dev/null and b/course/media/image_0002.jpeg differ diff --git a/course/media/image_0003.jpeg b/course/media/image_0003.jpeg new file mode 100644 index 00000000..d5e58e5e Binary files /dev/null and b/course/media/image_0003.jpeg differ diff --git a/course/media/image_0004.jpeg b/course/media/image_0004.jpeg new file mode 100644 index 00000000..8478e425 Binary files /dev/null and b/course/media/image_0004.jpeg differ diff --git a/course/media/image_0005.jpeg b/course/media/image_0005.jpeg new file mode 100644 index 00000000..ddebd19b Binary files /dev/null and b/course/media/image_0005.jpeg differ diff --git a/course/media/image_0006.jpeg b/course/media/image_0006.jpeg new file mode 100644 index 00000000..be82182c Binary files /dev/null and b/course/media/image_0006.jpeg differ diff --git a/course/media/image_0007.jpeg b/course/media/image_0007.jpeg new file mode 100644 index 00000000..7d5e324a Binary files /dev/null and b/course/media/image_0007.jpeg differ diff --git a/course/media/image_lms_rgb.png b/course/media/image_lms_rgb.png new file mode 100644 index 00000000..3a011e02 Binary files /dev/null and b/course/media/image_lms_rgb.png differ diff --git a/course/media/image_lms_second_stage.png b/course/media/image_lms_second_stage.png new file mode 100644 index 00000000..82872615 Binary files /dev/null and b/course/media/image_lms_second_stage.png differ diff --git a/course/media/image_lms_third_stage.png b/course/media/image_lms_third_stage.png new file mode 100644 index 00000000..d522a04c Binary files /dev/null and b/course/media/image_lms_third_stage.png differ diff --git a/course/media/index.mp3 b/course/media/index.mp3 new file mode 100644 index 00000000..f1613beb Binary files /dev/null and b/course/media/index.mp3 differ diff --git a/course/media/intensities_before_and_after_propagation.png b/course/media/intensities_before_and_after_propagation.png new file mode 100644 index 00000000..d3a9a999 Binary files /dev/null and b/course/media/intensities_before_and_after_propagation.png differ diff --git a/course/media/interference_examples.png b/course/media/interference_examples.png new file mode 100644 index 00000000..6064c20c Binary files /dev/null and b/course/media/interference_examples.png differ diff --git a/course/media/lms_graph.png b/course/media/lms_graph.png new file mode 100644 index 00000000..b8c92732 Binary files /dev/null and b/course/media/lms_graph.png differ diff --git a/course/media/phase_only_hologram_example.png b/course/media/phase_only_hologram_example.png new file mode 100644 index 00000000..8da62e94 Binary files /dev/null and b/course/media/phase_only_hologram_example.png differ diff --git a/course/media/phase_only_hologram_reconstruction_example.png b/course/media/phase_only_hologram_reconstruction_example.png new file mode 100644 index 
00000000..f9ee3b48 Binary files /dev/null and b/course/media/phase_only_hologram_reconstruction_example.png differ diff --git a/course/media/photoneb.png b/course/media/photoneb.png new file mode 100644 index 00000000..4b0d9ab7 Binary files /dev/null and b/course/media/photoneb.png differ diff --git a/course/media/photonebalone.png b/course/media/photonebalone.png new file mode 100644 index 00000000..26791f06 Binary files /dev/null and b/course/media/photonebalone.png differ diff --git a/course/media/photonintro.png b/course/media/photonintro.png new file mode 100644 index 00000000..ce7df92e Binary files /dev/null and b/course/media/photonintro.png differ diff --git a/course/media/photonpol.png b/course/media/photonpol.png new file mode 100644 index 00000000..5b2621c5 Binary files /dev/null and b/course/media/photonpol.png differ diff --git a/course/media/photoreceptors_rods_and_cones.png b/course/media/photoreceptors_rods_and_cones.png new file mode 100644 index 00000000..8703985e Binary files /dev/null and b/course/media/photoreceptors_rods_and_cones.png differ diff --git a/course/media/polfilter.png b/course/media/polfilter.png new file mode 100644 index 00000000..b2ff580c Binary files /dev/null and b/course/media/polfilter.png differ diff --git a/course/media/raytracing.png b/course/media/raytracing.png new file mode 100644 index 00000000..fc9f90d4 Binary files /dev/null and b/course/media/raytracing.png differ diff --git a/course/media/reflection.png b/course/media/reflection.png new file mode 100644 index 00000000..c101c67a Binary files /dev/null and b/course/media/reflection.png differ diff --git a/course/media/retinal_photoreceptor_distribution.png b/course/media/retinal_photoreceptor_distribution.png new file mode 100644 index 00000000..a578cba2 Binary files /dev/null and b/course/media/retinal_photoreceptor_distribution.png differ diff --git a/course/media/rods_and_cones_closeup.jpg b/course/media/rods_and_cones_closeup.jpg new file mode 100644 index 00000000..8458a2b1 Binary files /dev/null and b/course/media/rods_and_cones_closeup.jpg differ diff --git a/course/media/sig19_foveated_display_color.png b/course/media/sig19_foveated_display_color.png new file mode 100644 index 00000000..bb26c9ab Binary files /dev/null and b/course/media/sig19_foveated_display_color.png differ diff --git a/course/media/sphere_rays.png b/course/media/sphere_rays.png new file mode 100644 index 00000000..01e797f3 Binary files /dev/null and b/course/media/sphere_rays.png differ diff --git a/course/media/telelife.png b/course/media/telelife.png new file mode 100644 index 00000000..00540f44 Binary files /dev/null and b/course/media/telelife.png differ diff --git a/course/media/zeropad.png b/course/media/zeropad.png new file mode 100644 index 00000000..4551e5b0 Binary files /dev/null and b/course/media/zeropad.png differ diff --git a/course/narrate.py b/course/narrate.py new file mode 100644 index 00000000..babf1253 --- /dev/null +++ b/course/narrate.py @@ -0,0 +1,74 @@ +from TTS.api import TTS +import odak +import sys + + +def main(): + files = sorted(odak.tools.list_files('./', key = '*.md')) + files = ['computer_generated_holography.md'] + tts = TTS(model_name = "tts_models/en/jenny/jenny", progress_bar = True, gpu = True) + cache_fn = 'cache.txt' + wav_file = 'cache.wav' + for file in files: + print(file) + cmd = ['cp', str(file), cache_fn] + odak.tools.shell_command(cmd) + f = open(cache_fn, 'r+') + contents_list = f.readlines() + f.close() + if contents_list == []: + sys.exit() + mp3_file = 
str(file.replace('.md', '.mp3')) + contents = clear_text(contents_list) + tts.tts_to_file( + text = contents, + file_path = wav_file, + emotion = 'Happy', + speed = 0.8 + ) + cmd = ['ffmpeg', '-i', wav_file, mp3_file, '-y'] + odak.tools.shell_command(cmd) + cmd = ['mv', mp3_file, './media/'] + odak.tools.shell_command(cmd) + cmd = ['rm', cache_fn] + odak.tools.shell_command(cmd) + cmd = ['rm', wav_file] + odak.tools.shell_command(cmd) + + +def clear_text(text): + found_ids = [] + for item_id, item in enumerate(text): + for word in ['', 'mesh3d', 'showlegend']: + if item.find(word) != -1: + found_ids.append(item_id) + new_list = [] + for item_id, item in enumerate(text): + if not item_id in found_ids: + new_list.append(item) + text = new_list + output_text = ''.join(text) + output_text = output_text.replace('???', '') + output_text = output_text.replace('Narrate section', '') + output_text = output_text.replace(':material-alert-decagram:{ .mdx-pulse title="Too important!" }', 'Too important!') + output_text = output_text.replace(':octicons-beaker-24:', '') + output_text = output_text.replace(':octicons-info-24:', '') + output_text = output_text.replace('quote end', '') + output_text = output_text.replace('question end', '') + output_text = output_text.replace('information end', '') + output_text = output_text.replace('success end', '') + output_text = output_text.replace('Warning end', '') + output_text = output_text.replace('!!!', '') + output_text = output_text.replace('#', '') + output_text = output_text.replace('##', '') + output_text = output_text.replace('###', '') + output_text = output_text.replace('####', '') + output_text = output_text.replace('$', '') + output_text = output_text.replace('$$', '') + output_text = output_text.replace('*', '') + output_text = output_text.replace('**', '') + return output_text + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/course/optimize_sample.py b/course/optimize_sample.py new file mode 100644 index 00000000..f4d9aa02 --- /dev/null +++ b/course/optimize_sample.py @@ -0,0 +1,29 @@ +import torch +import odak +import sys + + +def forward(x, m, n): + y = m * x + n + return y + + +def main(): + m = torch.tensor([100.], requires_grad = True) + n = torch.tensor([0.], requires_grad = True) + x_vals = torch.tensor([1., 2., 3., 100.]) + y_vals = torch.tensor([5., 6., 7., 101.]) + optimizer = torch.optim.Adam([m, n], lr = 5e1) + loss_function = torch.nn.MSELoss() + for step in range(1000): + optimizer.zero_grad() + y_estimate = forward(x_vals, m, n) + loss = loss_function(y_estimate, y_vals) + loss.backward(retain_graph = True) + optimizer.step() + print('Step: {}, Loss: {}'.format(step, loss.item())) + print(m, n) + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/course/photonic_computers/index.html b/course/photonic_computers/index.html new file mode 100644 index 00000000..17af9b18 --- /dev/null +++ b/course/photonic_computers/index.html @@ -0,0 +1,1683 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + All-optical Machine Learning - Odak + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ +
+
+ + + +
+ + + + + + +
+ + + + +
+ +
+ + + + +
+
+ + + +
+
+
+ + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

All-optical Machine Learning

+ + + + + + + + + + + + + + + +
+
+ + + + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/course/ray.txt b/course/ray.txt new file mode 100644 index 00000000..0b3f9fbe --- /dev/null +++ b/course/ray.txt @@ -0,0 +1,2 @@ +
+
\ No newline at end of file diff --git a/course/source/animation_convolution.py b/course/source/animation_convolution.py new file mode 100644 index 00000000..81664c6a --- /dev/null +++ b/course/source/animation_convolution.py @@ -0,0 +1,31 @@ +import odak +import torch +import sys + + +def main(): + filename_image = '../media/10591010993_80c7cb37a6_c.jpg' + image = odak.learn.tools.load_image(filename_image, normalizeby = 255., torch_style = True)[0:3].unsqueeze(0) + kernel = odak.learn.tools.generate_2d_gaussian(kernel_length = [12, 12], nsigma = [21, 21]) + kernel = kernel / kernel.max() + result = torch.zeros_like(image) + result = odak.learn.tools.zero_pad(result, size = [image.shape[-2] + kernel.shape[0], image.shape[-1] + kernel.shape[1]]) + step = 0 + for i in range(image.shape[-2]): + for j in range(image.shape[-1]): + for ch in range(image.shape[-3]): + element = image[:, ch, i, j] + add = kernel * element + result[:, ch, i : i + kernel.shape[0], j : j + kernel.shape[1]] += add + if (i * image.shape[-1] + j) % 1e4 == 0: + filename = 'step_{:04d}.png'.format(step) + odak.learn.tools.save_image( filename, result, cmin = 0., cmax = 100.) + step += 1 + cmd = ['convert', '-delay', '1', '-loop', '0', '*.png', '../media/convolution_animation.gif'] + odak.tools.shell_command(cmd) + cmd = ['rm', '*.png'] + odak.tools.shell_command(cmd) + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/course/source/convolution.svg b/course/source/convolution.svg new file mode 100644 index 00000000..0708ca4c --- /dev/null +++ b/course/source/convolution.svg @@ -0,0 +1,3356 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + * + 1 + 4 + 7 + 9 + 2 + 1 + x + + + 5 + 8 + 2 + 3 + 6 + + + + + 9 + 2 + 0 + 4 + 5 + + + + + 2 + 3 + 1 + 2 + 7 + + + + + 3 + 6 + 3 + 5 + 4 + + + -1 + -2 + -3 + -4 + 5 + 3 + 2 + 7 + 9 + A + R + K + + + + + + -1 + -2 + -3 + -4 + 5 + 3 + 2 + 7 + 9 + K + + + + + + + + + + + + 1 + 4 + 7 + 9 + 2 + + + + 5 + 8 + 2 + 3 + 6 + + + + + 9 + 2 + 0 + 4 + 5 + + + + + 2 + 3 + 1 + 2 + 7 + + + + + 3 + 6 + 3 + 5 + 4 + + + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + + + + + + + + + + + + + + + 5 + 9 + 0 + 0 + 0 + 7 + -2 + 0 + 0 + 0 + + + 0 + 0 + 0 + 0 + 0 + + + + + 0 + 0 + 0 + 0 + 0 + + + + + 0 + 0 + 0 + 0 + 0 + + + 2 + -3 + 0 + 0 + 0 + 0 + -1 + 3 + -4 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + + + + + + + A[0, 0] + 5 + x + R + + + + + + -1 + -2 + -3 + -4 + 5 + 3 + 2 + 7 + 9 + K + + + + + + + + + + 20 + -11 + 0 + 0 + 0 + 32 + 43 + 0 + 0 + 0 + 35 + -10 + 0 + 0 + 0 + + + 0 + 0 + 0 + 0 + 0 + + + + + 0 + 0 + 0 + 0 + 0 + + + -3 + 7 + -15 + 0 + 0 + 0 + -1 + 3 + -4 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + 0 + + + + + + + A[1, 0] + + + + + + diff --git a/course/source/drawings.svg b/course/source/drawings.svg new file mode 100644 index 00000000..f1880370 --- /dev/null +++ b/course/source/drawings.svg @@ -0,0 +1,2760 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + image/svg+xml + + + + + + + + + + + + Z + X + Y + + + Photon + + + + + + Z + X + Y + + + Photon + + + + + + E(r,t) + B(r,t) + + + + + + + + Z + X + Y + + + Photon + + + Polarization axis + λ + + + + + + + + + Z + X + Y 
+ + + Photon + + + + + + + + + + Z + X + Y + + + Photon + + + + + + + + + + + + + + + + + Z + X + Y + + + Photon + + + Polarization axis + + + + + + + + + + + + + Z + X + Y + + + Photon + + + Polarization axis + + + + + + + + + + + + + + Polarizationfilter + + Polarizationfilter + + + + Collimated lightsource + + + Path of the light + + + + α + α + Dielectric mirror + + + + + Collimated lightsource + Lambertian Surface + + + + + + + + + + + + + + Path of the light + + ApertureSize + + a + + + + 45ᐤ + 45ᐤ + + 2D Retro-reflector + Collimated lightsource 1 + + + + + diff --git a/course/source/emspectrum.svg b/course/source/emspectrum.svg new file mode 100644 index 00000000..68feae31 --- /dev/null +++ b/course/source/emspectrum.svg @@ -0,0 +1,379 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + image/svg+xml + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Buildings + + + + + + + + + Humans + + Butterflies + Needle Point + Protozoans + Molecules + Atoms + Atomic Nuclei + + + + + + + + + + + 10 + 4 + + + 10 + 8 + + 10 + 12 + + + 10 + 15 + + + + 10 + 16 + + + 10 + 18 + + + 10 + 20 + + + + + + + 1 K + 100 K + 10,000 K + 10,000,000 K + + Penetrates Earth'sAtmosphere? + + Radio + Microwave + Infrared + Visible + Ultraviolet + X-ray + Gamma ray + + 10 + 3 + + + 10 + −2 + + + 10 + −5 + + + 0.5×10 + −6 + + + 10 + −8 + + + 10 + −10 + + + 10 + −12 + + Radiation Type + Wavelength (m) + + Approximate Scaleof Wavelength + Frequency (Hz) + Temperature ofobjects at which this radiation is themost intensewavelength emitted + −272 °C + −173 °C + 9,727 °C + ~10,000,000 °C + + \ No newline at end of file diff --git a/course/source/photoreceptors_rods_and_cones.svg b/course/source/photoreceptors_rods_and_cones.svg new file mode 100644 index 00000000..36f3fe1c --- /dev/null +++ b/course/source/photoreceptors_rods_and_cones.svg @@ -0,0 +1,20841 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/course/source/photoreceptors_rods_and_cones_license.pdf b/course/source/photoreceptors_rods_and_cones_license.pdf new file mode 100644 index 00000000..1e77fdec Binary files /dev/null and b/course/source/photoreceptors_rods_and_cones_license.pdf differ diff --git a/course/source/retinal_photoreceptor_distribution.svg b/course/source/retinal_photoreceptor_distribution.svg new file mode 100644 index 00000000..f849a29e --- /dev/null +++ b/course/source/retinal_photoreceptor_distribution.svg @@ -0,0 +1,8493 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/course/visual_perception/index.html b/course/visual_perception/index.html new file mode 100644 index 00000000..e1cd56e7 --- /dev/null +++ b/course/visual_perception/index.html @@ -0,0 +1,2036 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Visual Perception - Odak + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + +
+ + + + + + +
+ + + + +
+ +
+ + + + +
+
+ + + +
+
+
+ + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Visual Perception

+ +

Color Perception

+

Informative · + Practical

+

We can establish an understanding of color perception by studying its physical and perceptual meaning. This way, we can gather more information on its relation to technologies and devices, including displays, cameras, sensors, communication devices, computers, and computer graphics.

+

Color, a perceptual phenomenon, can be explained in both a physical and a visual perception capacity. In the physical sense, color is a quantity representing the response to the wavelength of light. The human visual system can perceive colors within a certain range of the electromagnetic spectrum, from around 400 nanometers to 700 nanometers. For greater detail on the electromagnetic spectrum and the concept of wavelength, we recommend revisiting the Light, Computation, and Computational Light section of our course. For the human visual system, color is a perceptual phenomenon created by our brain when specific wavelengths of light are emitted, reflected, or transmitted by objects. The perception of color originates from the absorption of light by photoreceptors in the eye. These photoreceptor cells convert the light into electrical signals to be interpreted by the brain1. Here, you can see a close-up photograph of these photoreceptor cells found in the eye.

+
+

Image title +

+
Micrograph of retinal photoreceptor cells, with rods and cones highlighted in green (top row). Image courtesy of NIH, licensed under CC PDM 1.0. View source.
+
+

The photoreceptors, where color perception originates, are called rods and cones2. Here, we provide a sketch showing where these rods and cones are located inside the eye. By closely observing this sketch, you can also understand the basic average geometry of a human eye and the parts that help redirect light from an actual scene towards the retinal cells.

+
+

Image title +

+
Anatomy of an Eye (Designed with BioRender.com).
+
+

Rods, which are relatively more common in the periphery, help people see in low-light (scotopic) conditions. The current understanding is that the rods can only provide greyscale vision. Cones, which are denser in the fovea, are pivotal for color perception in brighter (photopic) environments. We highlight the distribution of these photoreceptor cells, rods and cones, at changing eccentricities in the eye. Here, the word eccentricity refers to the angle with respect to our gaze direction. For instance, if a person is not directly gazing at a location or an object in a given scene, that location or object would have some angle to the gaze of that person. Thus, there would be some eccentricity, an angle between the gaze of that person and that location or object in the scene.

+
+

Image title +

+
Retinal Photoreceptor Distribution, adapted from the work by Goldstein et al [3].
+
+

In the above sketch, we introduced various regions of the retina, including the fovea, parafovea, perifovea, and peripheral vision. Note that these regions are defined by angles, in other words, eccentricities. Please also note that there is a region on our retina where no rods or cones are available. This region is found in every human eye and is known as the blind spot on the retina. Visual acuity and contrast sensitivity decrease progressively across these identified regions, with the most detail in the fovea, diminishing toward the periphery.

+
+

Image title +

+
Spectral Sensitivities of LMS cones
+
+

The cones are categorized into three types based on their sensitivity to specific wavelengths of light, corresponding to long (L), medium (M), and short (S) wavelength cones. These three types of cones3 allow us to better understand the trichromatic theory4, suggesting that human color perception stems from combining the stimulations of the LMS cones. Scientists have graphically represented how sensitive each type of cone is to different wavelengths of light, which is known as the spectral sensitivity function5. In practical applications such as display technologies and computational imaging, the LMS cone response can be approximated with the following formula:

+
\[ +LMS = \sum_{i=1}^{3} \text{RGB}_i \cdot \text{Spectrum}_i \cdot \text{Sensitivity}_i +\]
+

Where:

+
    +
  • \(RGB_i\): The i-th color channel (Red, Green, or Blue) of the image.
  • +
  • \(Spectrum_i\): The spectral power distribution of the corresponding primary.
  • +
  • \(Sensitivity_i\): The sensitivity of the L, M, and S cones for each wavelength.
  • +
+

This formula gives us more insight into how we perceive colors from different digital and physical inputs.
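To make the formula concrete, here is a tiny sketch of one way to evaluate it for a single pixel; the spectra and sensitivities below are random placeholder values rather than measured data. Each primary's contribution is weighted by its emission spectrum and by each cone's sensitivity curve, then summed over primaries and wavelengths.

import torch

number_of_wavelengths = 301                         # 301 spectral samples, matching the example below
rgb = torch.rand(3)                                 # a single RGB pixel (placeholder values)
spectrum = torch.rand(3, number_of_wavelengths)     # emission spectrum of each primary (placeholder values)
sensitivity = torch.rand(3, number_of_wavelengths)  # L, M, and S cone sensitivity curves (placeholder values)

# lms[j] = sum over primaries i and wavelengths w of rgb[i] * spectrum[i, w] * sensitivity[j, w]
lms = torch.einsum('i,iw,jw->j', rgb, spectrum, sensitivity)
print(lms)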

+
+Looking for more reading to expand your understanding on human visual system? +

We recommend these papers, which we find insightful:
- B. P. Schmidt, M. Neitz, and J. Neitz, "Neurobiological hypothesis of color appearance and hue perception," J. Opt. Soc. Am. A 31(4), A195–207 (2014) +
- Biomimetic Eye Modeling & Deep Neuromuscular Oculomotor Control

+
+

The story of color perception only deepens with the concept of color opponency6. +This theory reveals that our perception of color is not just a matter of additive combinations of primary colors but also involves a dynamic interplay of opposing colors: red versus green, blue versus yellow. +This phenomenon is rooted in the neural pathways of the eye and brain, where certain cells are excited or inhibited by specific wavelengths, enhancing our ability to distinguish between subtle shades and contrasts. +Below is a mathematical formulation for the color opponency model proposed by Schmidt et al.3

+
\[\begin{bmatrix}
I_{(M+S)-L} \\
I_{(L+S)-M} \\
I_{(L+M+S)}
\end{bmatrix}
=
\begin{bmatrix}
(I_M + I_S) - I_L \\
(I_L + I_S) - I_M \\
I_L + I_M + I_S
\end{bmatrix}\]
+

In this equation, \(I_L\), \(I_M\), and \(I_S\) represent the intensities received by the long, medium, and short cone cells, respectively. Opponent signals are represented by the differences between combinations of cone responses.
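As a quick numerical illustration of this equation, the opponent channels can be computed directly from an LMS triplet; the cone intensities below are made-up values.

import torch

i_l, i_m, i_s = torch.tensor([0.6, 0.3, 0.1])        # example L, M, S cone intensities (made up)
opponent_channels = torch.stack([
                                 (i_m + i_s) - i_l,  # (M + S) - L opponent channel
                                 (i_l + i_s) - i_m,  # (L + S) - M opponent channel
                                 i_l + i_m + i_s     # combined L + M + S channel
                                ])
print(opponent_channels)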

+

We can exercise our understanding of trichromat sensation with LMS cones and the concept of color opponency by visiting the functions available in our toolkit, odak. The utility function we will review is odak.learn.perception.display_color_hvs.primaries_to_lms() from odak.learn.perception. Let us use this test to demonstrate how we can obtain the LMS sensation from the color primaries of an image.

+
+
+
+
+
+
import odak # (1)
+import torch
+import sys
+from odak.learn.perception.color_conversion import display_color_hvs
+
+
+def test(
+         device = torch.device('cpu'),
+         output_directory = 'test_output'
+        ):
+    odak.tools.check_directory(output_directory)
+    torch.manual_seed(0)
+
+    image_rgb = odak.learn.tools.load_image(
+                                            'test/data/fruit_lady.png',
+                                            normalizeby = 255.,
+                                            torch_style = True
+                                           ).unsqueeze(0).to(device) # (2)
+
+    the_number_of_primaries = 3
+    multi_spectrum = torch.zeros(
+                                 the_number_of_primaries,
+                                 301
+                                ) # (3)
+    multi_spectrum[0, 200:250] = 1.
+    multi_spectrum[1, 130:145] = 1.
+    multi_spectrum[2, 0:50] = 1.
+
+    display_color = display_color_hvs(
+                                      read_spectrum ='tensor',
+                                      primaries_spectrum=multi_spectrum,
+                                      device = device
+                                     ) # (4)
+
+    image_lms_second_stage = display_color.primaries_to_lms(image_rgb) # (5)
+    image_lms_third_stage = display_color.second_to_third_stage(image_lms_second_stage) # (6)
+
+
+    odak.learn.tools.save_image(
+                                '{}/image_rgb.png'.format(output_directory),
+                                image_rgb,
+                                cmin = 0.,
+                                cmax = image_rgb.max()
+                               )
+
+
+    odak.learn.tools.save_image(
+                                '{}/image_lms_second_stage.png'.format(output_directory),
+                                image_lms_second_stage,
+                                cmin = 0.,
+                                cmax = image_lms_second_stage.max()
+                               )
+
+    odak.learn.tools.save_image(
+                                '{}/image_lms_third_stage.png'.format(output_directory),
+                                image_lms_third_stage,
+                                cmin = 0.,
+                                cmax = image_lms_third_stage.max()
+                               )
+
+
+    image_rgb_noisy = image_rgb * 0.6 + torch.rand_like(image_rgb) * 0.4 # (7)
+    loss_lms = display_color(image_rgb, image_rgb_noisy) # (8)
+    print('The third stage LMS sensation difference between two input images is {:.10f}.'.format(loss_lms))
+    assert True == True
+
+if __name__ == "__main__":
+    sys.exit(test())
+
+
  1. Adding odak to our imports.
  2. Loading an existing RGB image.
  3. Defining the spectrum of the primaries of our imaginary display. These values are defined for each primary from 400 nm to 701 nm (301 elements).
  4. Obtaining the LMS cone sensations for the primaries of our imaginary display.
  5. Calculating the LMS sensation of our input RGB image at the second stage of color perception using our imaginary display.
  6. Calculating the LMS sensation of our input RGB image at the third stage of color perception using our imaginary display.
  7. We intentionally add some noise to the input RGB image here.
  8. We calculate the perceptual loss/difference between the two input images (original RGB vs. noisy RGB). This is a visualization of a randomly generated image and its LMS cone sensation.
+

Our code above saves three different images. +The very first saved image is the ground truth RGB image as depicted below.

+
+

Image title +

+
Original ground truth image.
+
+

We process this ground truth image by accounting for the human visual system's cones and the display backlight's spectrum. This way, we can calculate how our ground truth image is sensed by the LMS cones. The LMS sensation, in other words, the ground truth image in LMS color space, is provided below. Note that each color here represents a different cone; for instance, the green color channel of the image below represents the medium cones and the blue channel represents the short cones. Keep in mind that LMS sensation is also known as trichromat sensation in the literature.

+
+

Image title +

+
Image in LMS cones trichromat space.
+
+

Earlier, we discussed the color opponency theory. Following this theory, our code utilizes the trichromat values to derive the image representation below.

+
+

Image title +

+
Image representation of color opponency.
+
+
+Lab work: Observing the effect of display spectrum +

We introduce our unit test, test_learn_perception_display_color_hvs.py, to provide an example of how to convert an RGB image to trichromat values as sensed by the retinal cone cells. Note that during this exercise, we define a variable named multi_spectrum to represent the spectrum of each of our color primaries. These values are stored in a vector for each primary and provide the intensity of the corresponding wavelength from 400 nm to 701 nm. The trichromat values that we derive from our original ground truth RGB image are highly correlated with these spectrum values. To observe this correlation, we encourage you to find the spectra of actual display types (e.g., OLEDs, LEDs, LCDs) and map multi_spectrum to their spectra to observe the difference in color perception across various display technologies. In addition, we believe that this will also give you a practical sandbox to examine the correlation between wavelengths and trichromat values. A starting-point sketch follows below.
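As a starting point for this lab work, the sketch below replaces the box-shaped multi_spectrum of the unit test with smoother, Gaussian-shaped primaries before computing the LMS sensation. The peak wavelengths and widths are illustrative assumptions, not measured display spectra, and the random input image stands in for a loaded photograph.

import torch
from odak.learn.perception.color_conversion import display_color_hvs

wavelengths = torch.linspace(400., 700., 301)               # 301 spectral samples, matching the test above


def gaussian_primary(center_nm, width_nm):
    # Normalized Gaussian spectral profile for a single, hypothetical primary.
    return torch.exp(-0.5 * ((wavelengths - center_nm) / width_nm) ** 2)


multi_spectrum = torch.stack([
                              gaussian_primary(610., 20.),  # red-like primary (assumed peak and width)
                              gaussian_primary(540., 25.),  # green-like primary (assumed peak and width)
                              gaussian_primary(460., 15.)   # blue-like primary (assumed peak and width)
                             ])
display_color = display_color_hvs(
                                  read_spectrum = 'tensor',
                                  primaries_spectrum = multi_spectrum,
                                  device = torch.device('cpu')
                                 )
image_rgb = torch.rand(1, 3, 64, 64)                        # stand-in for a loaded RGB image
image_lms = display_color.primaries_to_lms(image_rgb)
print(image_lms.shape)

Swapping the Gaussian centers and widths for the published spectra of a specific display technology lets you compare how the same RGB image is sensed under different backlights.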

+
+ + +

Closing remarks

+

As we dive deeper into light and color perception, it becomes evident that the task of replicating the natural spectrum of colors in technology is still an evolving journey. +This exploration into the nature of color sets the stage for a deeper examination of how our biological systems perceive color and how technology strives to emulate that perception.

+
+Consider revisiting this chapter +

Remember that you can always revisit this chapter as you progress through the course and as you need it. +This chapter is vital for establishing a means to complete your assignments, and it could help you build a suitable base for collaborating with my research group or other experts in the field in the future.

+
+
+

Reminder

+

We host a Slack group with more than 300 members. +This Slack group focuses on the topics of rendering, perception, displays and cameras. +The group is open to the public, and you can become a member by following this link. +Readers can get in touch with the wider community using this public group.

+
+
+
+
    +
  1. +

    Jeremy Freeman and Eero P Simoncelli. Metamers of the ventral stream. Nature Neuroscience, 14:1195–1201, 2011. doi:10.1038/nn.2889

    +
  2. +
  3. +

    Trevor D Lamb. Why rods and cones? Eye, 30:179–185, 2015. doi:10.1038/eye.2015.236

    +
  4. +
  5. +

    Brian P Schmidt, Maureen Neitz, and Jay Neitz. Neurobiological hypothesis of color appearance and hue perception. Journal of the Optical Society of America A, 31(4):A195–A207, 2014. doi:10.1364/JOSAA.31.00A195

    +
  6. +
  7. +

    H. V. Walters. Some experiments on the trichromatic theory of vision. Proceedings of the Royal Society of London. Series B - Biological Sciences, 131:27–50, 1942. doi:10.1098/rspb.1942.0016

    +
  8. +
  9. +

    Andrew Stockman and Lindsay T Sharpe. The spectral sensitivities of the middle- and long-wavelength-sensitive cones derived from measurements in observers of known genotype. Vision Research, 40:1711–1737, 2000. doi:10.1016/S0042-6989(00)00021-3

    +
  10. +
  11. +

    Steven K Shevell and Paul R Martin. Color opponency: tutorial. Journal of the Optical Society of America A, 34(8):1099–1110, 2017. doi:10.1364/JOSAA.34.001099

    +
  12. +
+
+ + + + + + + + + + + + + +
+
+ + + + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/index.html b/index.html new file mode 100644 index 00000000..75dbc9b5 --- /dev/null +++ b/index.html @@ -0,0 +1,1758 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + Introduction - Odak + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + +
+ + + + + + +
+ + + + +
+ +
+ + + + +
+
+ + + +
+
+
+ + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Getting started

+

Informative

+

Odak (pronounced "O-dac") is the fundamental library for scientific computing in optical sciences, computer graphics, and visual perception. +We designed this page to help first-time users, new contributors, and existing users understand where to go within this documentation when they need help with certain aspects of Odak. +If you think you need a refresher or are a beginner willing to learn more about light and computation, we have created an entire course named Computational Light to help you get up to speed with the computational aspects of light.

+

Absolute Beginners

+

Informative · + Practical

+

Computational Light Course: Learn Odak and Physics of Light

+

New Users

+

Informative

+ +

Use cases

+

Informative

+ +

New contributors

+

Informative

+ +

Additional information

+

Informative

+ +
+

Reminder

+

We host a Slack group with more than 300 members. +This Slack group focuses on the topics of rendering, perception, displays and cameras. +The group is open to the public, and you can become a member by following this link. +Readers can get in touch with the wider community using this public group.

+
+ + + + + + + + + + + + + +
+
+ + + + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/installation/index.html b/installation/index.html new file mode 100644 index 00000000..736274d9 --- /dev/null +++ b/installation/index.html @@ -0,0 +1,1827 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Installation - Odak + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + +
+ + + + + + +
+ + + + +
+ +
+ + + + +
+
+ + + +
+
+
+ + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Installation

+

We develop and test Odak on Linux operating systems. +Therefore, we cannot guarantee that it works on Windows or Mac operating systems. +Odak can be installed in multiple ways. +However, our recommended method for installing Odak is using the pip distribution system. +We update Odak on pip with each new version. +Thus, the most straightforward way to install Odak is to use the command below in a Linux shell:

+

pip3 install odak
+
+Note that Odak is in constant development. +One may want to install the latest and greatest Odak from the source repository for their own reasons. +In this case, our recommended method is to rely on pip to install Odak from the source using:

+
pip3 install git+https://github.com/kaanaksit/odak
+
+

One can also install Odak without pip by first getting a local copy and installing it with Python. +Such an installation can be conducted using:

+
git clone git@github.com:kaanaksit/odak.git
+cd odak
+pip3 install -r requirements.txt
+pip3 install -e .
+
+

Uninstalling the Development version

+

If you have to remove the development version of odak, you can first try:

+
pip3 uninstall odak
+sudo pip3 uninstall odak
+
+

And if, for some reason, you are still able to import odak after that, check the easy-install.pth file, which is typically found in ~/.local/lib/pythonX/site-packages, where ~ refers to your home directory and X refers to your Python version. +If you see odak's directory listed in that file, delete it. +This will help you remove the development version of odak.
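To double check which copy of odak Python resolves to after cleaning up, a quick check like the one below can help (this is generic Python, not an odak utility):

import odak
print(odak.__file__) # prints the path of the odak package that Python imports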

+

Notes before running

+

Some notes should be highlighted to users, and these include:

+
    +
  • Odak installs a version of PyTorch that only uses the CPU. +To properly install PyTorch with GPU support, please consult the PyTorch website.
  • +
+

Testing an installation

+

After installing Odak, one can test if Odak has been appropriately installed with its dependencies by running the unit tests. +To be able to run unit tests, make sure to have pytest installed:

+
pip3 install -U pytest
+
+

Once pytest is installed, unit tests can be run by calling:

+

cd odak
+pytest
+
+The tests should return no error. +However, if an error is encountered, please start a new issue to make us aware of it.

+ + + + + + + + + + + + + +
+
+ + + + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/javascripts/config.js b/javascripts/config.js new file mode 100644 index 00000000..06dbf38b --- /dev/null +++ b/javascripts/config.js @@ -0,0 +1,16 @@ +window.MathJax = { + tex: { + inlineMath: [["\\(", "\\)"]], + displayMath: [["\\[", "\\]"]], + processEscapes: true, + processEnvironments: true + }, + options: { + ignoreHtmlClass: ".*|", + processHtmlClass: "arithmatex" + } +}; + +document$.subscribe(() => { + MathJax.typesetPromise() +}) diff --git a/lensless/index.html b/lensless/index.html new file mode 100644 index 00000000..95a1309c --- /dev/null +++ b/lensless/index.html @@ -0,0 +1,1693 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Getting Started - Odak + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + +
+ + + + + + +
+ + + + +
+ +
+ + + + +
+
+ + + +
+
+
+ + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Lensless Imaging

+

Odak contains essential ingredients for research and development targeting Lensless Imaging.

+ + + + + + + + + + + + + +
+
+ + + + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/machine_learning/index.html b/machine_learning/index.html new file mode 100644 index 00000000..b78b9bb0 --- /dev/null +++ b/machine_learning/index.html @@ -0,0 +1,1694 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Introduction - Odak + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + +
+ + + + + + +
+ + + + +
+ +
+ + + + +
+
+ + + +
+
+
+ + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Machine learning

+

Odak provides a set of functions that implement classical methods in machine learning. +Note that these functions are typically based on Numpy. +Thus, they do not take advantage of the automatic differentiation found in Torch. +The sole reason these functions exist is to serve as examples for implementing basic methods in machine learning.

+ + + + + + + + + + + + + +
+
+ + + + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/notes/holographic_light_transport/index.html b/notes/holographic_light_transport/index.html new file mode 100644 index 00000000..faa27b2d --- /dev/null +++ b/notes/holographic_light_transport/index.html @@ -0,0 +1,1794 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Holographic Light Transport - Odak + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + +
+ + + + + + +
+ + + + +
+ +
+ + + + +
+
+ + + +
+
+
+ + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Holographic light transport

+

Odak contains essential ingredients for research and development targeting Computer-Generated Holography. +We refer beginners in this matter to Goodman's Introduction to Fourier Optics book (ISBN-13: 978-0974707723) and Principles of optics: electromagnetic theory of propagation, interference and diffraction of light by Max Born and Emil Wolf (ISBN 0-08-26482-4). +This engineering note will provide a crash course on how light travels from a phase-only hologram to an image plane.

+ + + + + + + + + + + +
Holographic image reconstruction. A collimated beam with a homogeneous amplitude distribution (\(A=1\)) illuminates a phase-only hologram \(u_0(x,y)\). Light from this hologram diffracts and arrives at an image plane \(u(x,y)\) at a distance of \(z\). Diffracted beams from each hologram pixel interfere at the image plane and, finally, reconstruct a target image.
+

As depicted in the figure above, when such holograms are illuminated with collimated coherent light (e.g., a laser), they can reconstruct an intended optical field at target depth levels. +How light travels from a hologram to a parallel image plane is commonly described using Rayleigh-Sommerfeld diffraction integrals (For more, consult Heurtley, J. C. (1973). Scalar Rayleigh–Sommerfeld and Kirchhoff diffraction integrals: a comparison of exact evaluations for axial points. JOSA, 63(8), 1003-1008.). +The first solution of the Rayleigh-Sommerfeld integral, also known as the Huygens-Fresnel principle, is expressed as follows:

+

\(u(x,y)=\frac{1}{j\lambda} \int\!\!\!\!\int u_0(x,y)\frac{e^{jkr}}{r}\cos(\theta)\,dx\,dy,\)

+

where the field at a target image plane, \(u(x,y)\), is calculated by integrating over every point of the hologram's field, \(u_0(x,y)\). +Note that, in the above equation, \(r\) represents the optical path between a selected point on the hologram and a selected point in the image plane, \(\theta\) represents the angle between these two points, \(k\) represents the wavenumber (\(\frac{2\pi}{\lambda}\)), and \(\lambda\) represents the wavelength of light. +In this light transport model, the optical fields, \(u_0(x,y)\) and \(u(x,y)\), are represented with complex values,

+

\(u_0(x,y)=A(x,y)e^{j\phi(x,y)},\)

+

where \(A\) represents the spatial distribution of amplitude and \(\phi\) represents the spatial distribution of phase across the hologram plane. +The described holographic light transport model is often simplified into a single convolution with a fixed, spatially invariant complex kernel, \(h(x,y)\) (Sypek, Maciej. "Light propagation in the Fresnel region. New numerical approach." Optics communications 116.1-3 (1995): 43-48.).

+

\(u(x,y)=u_0(x,y) * h(x,y) =\mathcal{F}^{-1}(\mathcal{F}(u_0(x,y)) \mathcal{F}(h(x,y)))\)

+

There are multiple variants of this simplified approach:

+
    +
  • Matsushima, Kyoji, and Tomoyoshi Shimobaba. "Band-limited angular spectrum method for numerical simulation of free-space propagation in far and near fields." Optics express 17.22 (2009): 19662-19673.,
  • +
  • Zhang, Wenhui, Hao Zhang, and Guofan Jin. "Band-extended angular spectrum method for accurate diffraction calculation in a wide propagation range." Optics letters 45.6 (2020): 1543-1546.,
  • +
  • Zhang, Wenhui, Hao Zhang, and Guofan Jin. "Adaptive-sampling angular spectrum method with full utilization of space-bandwidth product." Optics Letters 45.16 (2020): 4416-4419.
  • +
+

In many cases, people choose to use the most common form of \(h\), described as

+

\(h(x,y)=\frac{e^{jkz}}{j\lambda z} e^{\frac{jk}{2z} (x^2+y^2)},\)

+

where \(z\) represents the distance between the hologram plane and the target image plane. +Note that beam propagation can also be learned for physical setups to account for imperfections in a setup and to improve the image quality at the image plane (a minimal numerical sketch of the simplified transfer-function propagation follows the references below):

+
    +
  • Peng, Yifan, et al. "Neural holography with camera-in-the-loop training." ACM Transactions on Graphics (TOG) 39.6 (2020): 1-14.,
  • +
  • Chakravarthula, Praneeth, et al. "Learned hardware-in-the-loop phase retrieval for holographic near-eye displays." ACM Transactions on Graphics (TOG) 39.6 (2020): 1-18.,
  • +
  • Kavaklı, Koray, Hakan Urey, and Kaan Akşit. "Learned holographic light transport." Applied Optics (2021).
  • +
+
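For readers who prefer seeing the simplified convolutional model in code, here is a minimal numerical sketch written in plain torch. It is not odak's implementation (odak provides these propagation models under odak.learn.wave), it ignores practical details such as zero padding and the band limiting discussed in the variants above, and the wavelength, pixel pitch, distance, and random hologram below are placeholder assumptions.

import torch

wavelength = 532e-9              # wavelength of light in meters (assumed)
z = 0.15                         # propagation distance in meters (assumed)
dx = 8e-6                        # hologram pixel pitch in meters (assumed)
resolution = [1080, 1920]        # hologram resolution (assumed)
k = 2. * torch.pi / wavelength   # wavenumber

# Spatial grid centered on the optical axis.
y = (torch.arange(resolution[0]) - resolution[0] // 2) * dx
x = (torch.arange(resolution[1]) - resolution[1] // 2) * dx
Y, X = torch.meshgrid(y, x, indexing = 'ij')

# Fresnel convolution kernel h(x, y) from the equation above.
h = torch.exp(1j * k / (2. * z) * (X ** 2 + Y ** 2)) * torch.exp(torch.tensor(1j * k * z)) / (1j * wavelength * z)

# A placeholder phase-only hologram u_0(x, y) with unit amplitude.
u0 = torch.exp(1j * 2. * torch.pi * torch.rand(resolution[0], resolution[1]))

# u = u_0 * h evaluated as a multiplication in the Fourier domain.
U = torch.fft.fft2(torch.fft.ifftshift(u0))
H = torch.fft.fft2(torch.fft.ifftshift(h))
u = torch.fft.fftshift(torch.fft.ifft2(U * H))
intensity = u.abs() ** 2         # what a camera would capture at the image plane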

See also

+

For more engineering notes, follow:

+ + + + + + + + + + + + + + +
+
+ + + + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/notes/holographic_light_transport_files/hologram_generation.png b/notes/holographic_light_transport_files/hologram_generation.png new file mode 100644 index 00000000..10024bae Binary files /dev/null and b/notes/holographic_light_transport_files/hologram_generation.png differ diff --git a/notes/optimizing_holograms_using_odak/index.html b/notes/optimizing_holograms_using_odak/index.html new file mode 100644 index 00000000..36388e65 --- /dev/null +++ b/notes/optimizing_holograms_using_odak/index.html @@ -0,0 +1,1901 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Hologram Optimization - Odak + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + +
+ + + + + + +
+ + + + +
+ +
+ + + + +
+
+ + + +
+
+
+ + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Optimizing holograms using Odak

+

This engineering note will give you an idea about how to optimize phase-only holograms using Odak. +We refer beginners in this matter to Goodman's Introduction to Fourier Optics (ISBN-13: 978-0974707723) and Principles of optics: electromagnetic theory of propagation, interference and diffraction of light by Max Born and Emil Wolf (ISBN 0-08-26482-4). +Note that the creators of this documentation are from the Computational Displays domain. +However, the provided submodules can potentially aid other lines of research as well, such as Computational Imaging or Computational Microscopy.

+

The optimization that is referred to in this document is the one that generates a phase-only hologram that can reconstruct a target image. +There are multiple ways in the literature to optimize a phase-only hologram for a single plane, and these include:

+

Gerchberg-Saxton and Yang-Gu algorithms: +- Yang, G. Z., Dong, B. Z., Gu, B. Y., Zhuang, J. Y., & Ersoy, O. K. (1994). Gerchberg–Saxton and Yang–Gu algorithms for phase retrieval in a nonunitary transform system: a comparison. Applied optics, 33(2), 209-218.

+

Stochastic Gradient Descent based optimization: +- Chen, Y., Chi, Y., Fan, J., & Ma, C. (2019). Gradient descent with random initialization: Fast global convergence for nonconvex phase retrieval. Mathematical Programming, 176(1), 5-37.

+

Odak provides functions to optimize phase-only holograms using the Gerchberg-Saxton algorithm or the Stochastic Gradient Descent based approach. +The relevant functions here are odak.learn.wave.stochastic_gradient_descent and odak.learn.wave.gerchberg_saxton. +We will review both of these definitions in this document. +But first, let's get prepared.

+

Preparation

+

We first start with imports, here is all you need:

+
from odak.learn.wave import stochastic_gradient_descent, calculate_amplitude, calculate_phase
+import torch
+
+

We will also need some variables that define the wavelength of light that we work with:

+
wavelength = 0.000000532
+
+

Pixel pitch and resolution of the phase-only hologram or a phase-only spatial light modulator that we are simulating:

+
dx = 0.0000064
+resolution = [1080, 1920]
+
+

Define the distance that the light will travel from the optimized hologram:

+
distance = 0.15
+
+

We have to set a target image. +You can either load a sample image here or paint a white rectangle on a black background, as in this example.

+
target = torch.zeros(resolution[0],resolution[1])
+target[500:600,400:450] = 1.
+
+

Surely, we also have to set the number of iterations and the learning rate for our optimization. +If you want GPU support, you also have to set cuda to True. +The propagation type has to be defined as well. +In this example, we will use the transfer function Fresnel approach. +For more on propagation types, curious readers can consult +Computational Fourier Optics by David Voelz (ISBN13:9780819482044).

+
iteration_number = 100
+learning_rate = 0.1
+cuda = True
+propagation_type = 'TR Fresnel'
+
+

This step concludes our preparations. +Let's dive into optimizing our phase-only holograms. +Depending on your choice, you can either optimize using the Gerchberg-Saxton approach or the Stochastic Gradient Descent approach. +This document will only show you the Stochastic Gradient Descent approach, as it is the state of the art. +However, optimizing a phase-only hologram with the Gerchberg-Saxton approach is as simple as importing:

+
from odak.learn.wave import gerchberg_saxton
+
+

and almost as easy as replacing stochastic_gradient_descent with gerchberg_saxton in the hologram optimization routine described below. +For greater detail, consult the documentation of odak.learn.wave.

+

Stochastic Gradient Descent approach

+

We have prepared a function for you so that you can avoid building a differentiable hologram optimizer from scratch.

+
hologram, reconstructed = stochastic_gradient_descent(
+        target,
+        wavelength,
+        distance,
+        dx,
+        resolution,
+        'TR Fresnel',
+        iteration_number,
+        learning_rate=learning_rate,
+        cuda=cuda
+    )
+
+
Iteration: 99 loss:0.0003
+
+

Congratulations! You have just optimized a phase-only hologram that reconstructs your target image at the target depth.

+

Surely, you want to see what kind of image is being reconstructed with this newly optimized hologram. +You can save the outcome to an image file easily. +Odak provides tools to save and load images. +First, you have to import:

+
from odak.learn.tools import save_image,load_image
+
+

As you may recall, we created a target image earlier that is normalized between zero and one. +The same is true for our result, reconstructed. +Therefore, we have to save it correctly by taking that into account. +Note that reconstructed is the complex field generated by our optimized hologram variable. +So, we need to save the reconstructed intensity, as humans and cameras capture intensity, not a complex field with phase and amplitude.

+
reconstructed_intensity = calculate_amplitude(reconstructed)**2
+save_image('reconstructed_image.png',reconstructed_intensity,cmin=0.,cmax=1.)
+
+
True
+
+

To save our hologram as an image so that we can load it onto a spatial light modulator, we have to normalize it between zero and 255 (the dynamic range of a typical image on a computer).

+

P.S. Depending on your SLM's calibration and dynamic range, things may vary.

+
slm_range = 2*3.14
+dynamic_range = 255
+phase_hologram = calculate_phase(hologram)
+phase_only_hologram = (phase_hologram%slm_range)/(slm_range)*dynamic_range
+
+

It is now time for saving our hologram:

+
save_image('phase_only_hologram.png',phase_only_hologram)
+
+
True
+
+

In some cases, you may want to add a grating term to your hologram as you will display it on a spatial light modulator. +There are various reasons for that, but the most obvious is getting rid of zeroth-order reflections that are not modulated by your hologram. +In case you need it, it is as simple as below:

+
from odak.learn.wave import linear_grating
+grating = linear_grating(resolution[0],resolution[1],axis='y').to(phase_hologram.device)
+phase_only_hologram_w_grating = phase_hologram+calculate_phase(grating)
+
+

And let's save what we got from this step:

+
phase_only_hologram_w_grating = (phase_only_hologram_w_grating%slm_range)/(slm_range)*dynamic_range
+save_image('phase_only_hologram_w_grating.png',phase_only_hologram_w_grating)
+
+
True
+
+

See also

+

For more engineering notes, follow:

+ + + + + + + + + + + + + + +
+
+ + + + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/notes/using_metameric_loss/index.html b/notes/using_metameric_loss/index.html new file mode 100644 index 00000000..c7b0be24 --- /dev/null +++ b/notes/using_metameric_loss/index.html @@ -0,0 +1,1763 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Using metameric loss - Odak + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + +
+ + + + + + +
+ + + + +
+ +
+ + + + +
+
+ + + +
+
+
+ + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Using metameric loss

+ +

This engineering note will give you an idea about using the metameric perceptual loss in odak. +This note is compiled by David Walton. +If you have further questions regarding this note, please email David at david.walton.13@ucl.ac.uk.

+

Our metameric loss function works in a very similar way to built-in loss functions in pytorch, such as torch.nn.MSELoss(). +However, it has a number of parameters which can be adjusted on creation (see the documentation). +Additionally, when calculating the loss, a gaze location must be specified. For example:

+
loss_func = odak.learn.perception.MetamericLoss()
+loss = loss_func(my_image, gt_image, gaze=[0.7, 0.3])
+
+

The loss function caches some information, and performs most efficiently when repeatedly calculating losses for the same image size, with the same gaze location and foveation settings.

+

We recommend adjusting the parameters of the loss function to match your application. +Most importantly, please set the real_image_width and real_viewing_distance parameters to correspond to how your image will be displayed to the user. +The alpha parameter controls the intensity of the foveation effect. +You should only need to set alpha once - you can then adjust the width and viewing distance to achieve the same apparent foveation effect on a range of displays & viewing conditions. +Note that we assume the pixels in the displayed image are square, and derive the height from the image dimensions.
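As a sketch of how these settings might look inside an optimization loop: the keyword names real_image_width, real_viewing_distance, and alpha follow the parameter names mentioned above, while the image sizes, values, and learning rate are arbitrary placeholders; please check everything against the MetamericLoss documentation before relying on it.

import torch
from odak.learn.perception import MetamericLoss

loss_func = MetamericLoss(
                          real_image_width = 0.3,      # assumed physical display width in meters
                          real_viewing_distance = 0.6, # assumed viewing distance in meters
                          alpha = 0.2                  # assumed foveation strength
                         )
my_image = torch.rand(1, 3, 512, 512, requires_grad = True) # image being optimized
gt_image = torch.rand(1, 3, 512, 512)                       # reference image
optimizer = torch.optim.Adam([my_image], lr = 1e-2)
for step in range(100):
    optimizer.zero_grad()
    loss = loss_func(my_image, gt_image, gaze = [0.7, 0.3])
    loss.backward()
    optimizer.step()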

+

We also provide two baseline loss functions, BlurLoss and MetamerMSELoss, which function in much the same way.

+

At the present time the loss functions are implemented only for images displayed to a user on a flat 2D display (e.g. an LCD computer monitor). +Support for equirectangular 3D images is planned for the future.

+

See also

+

Visual perception

+ + + + + + + + + + + + + +
+
+ + + + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/objects.inv b/objects.inv new file mode 100644 index 00000000..48136224 Binary files /dev/null and b/objects.inv differ diff --git a/odak.png b/odak.png new file mode 100644 index 00000000..d335a299 Binary files /dev/null and b/odak.png differ diff --git a/odak/fit/index.html b/odak/fit/index.html new file mode 100644 index 00000000..0ae90311 --- /dev/null +++ b/odak/fit/index.html @@ -0,0 +1,2374 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + odak.fit - Odak + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + +
+ + + + + + +
+ + + + +
+ +
+ + + + +
+
+ + + +
+
+
+ + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

odak.fit

+ +
+ + + + +
+ +

odak.fit

+

Provides functions to fit models to provided data. These functions could be best described as a catalog of machine learning models.

+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ gradient_descent_1d(input_data, ground_truth_data, parameters, function, gradient_function, loss_function, learning_rate=0.1, iteration_number=10) + +

+ + +
+ +

Vanilla Gradient Descent algorithm for 1D data.

+ + +

Parameters:

+
    +
  • + input_data + – +
    +
                One-dimensional input data.
    +
    +
    +
  • +
  • + ground_truth_data + (array) + – +
    +
                One-dimensional ground truth data.
    +
    +
    +
  • +
  • + parameters + – +
    +
                Parameters to be optimized.
    +
    +
    +
  • +
  • + function + – +
    +
                Function to estimate an output using the parameters.
    +
    +
    +
  • +
  • + gradient_function + (function) + – +
    +
                Function used in estimating gradient to update parameters at each iteration.
    +
    +
    +
  • +
  • + learning_rate + – +
    +
                Learning rate.
    +
    +
    +
  • +
  • + iteration_number + – +
    +
                Iteration number.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +parameters ( array +) – +
    +

    Optimized parameters.

    +
    +
  • +
+ +
+ Source code in odak/fit/__init__.py +
def gradient_descent_1d(
+                        input_data,
+                        ground_truth_data,
+                        parameters,
+                        function,
+                        gradient_function,
+                        loss_function,
+                        learning_rate = 1e-1,
+                        iteration_number = 10
+                       ):
+    """
+    Vanilla Gradient Descent algorithm for 1D data.
+
+    Parameters
+    ----------
+    input_data        : numpy.array
+                        One-dimensional input data.
+    ground_truth_data : numpy.array
+                        One-dimensional ground truth data.
+    parameters        : numpy.array
+                        Parameters to be optimized.
+    function          : function
+                        Function to estimate an output using the parameters.
+    gradient_function : function
+                        Function used in estimating gradient to update parameters at each iteration.
+    learning_rate     : float
+                        Learning rate.
+    iteration_number  : int
+                        Iteration number.
+
+
+    Returns
+    -------
+    parameters        : numpy.array
+                        Optimized parameters.
+    """
+    t = tqdm(range(iteration_number))
+    for i in t:
+        gradient = np.zeros(parameters.shape[0])
+        for j in range(input_data.shape[0]):
+            x = input_data[j]
+            y = ground_truth_data[j]
+            gradient = gradient + gradient_function(x, y, function, parameters)
+        parameters = parameters - learning_rate * gradient / input_data.shape[0]
+        loss = loss_function(ground_truth_data, function(input_data, parameters))
+        description = 'Iteration number:{}, loss:{:0.4f}, parameters:{}'.format(i, loss, np.round(parameters, 2))
+        t.set_description(description)
+    return parameters
+
+
+
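As a usage sketch: gradient_descent_1d itself comes from odak.fit (as shown in the source above), while the toy data and the helper functions below are made up for illustration.

import numpy as np
from odak.fit import gradient_descent_1d

def line(x, parameters):
    # Simple linear model y = m * x + n.
    return parameters[0] * x + parameters[1]

def gradient(x, y, function, parameters):
    # Gradient of the squared error 0.5 * (f(x) - y)^2 with respect to [m, n].
    error = function(x, parameters) - y
    return np.array([error * x, error])

def l2_loss(ground_truth, estimate):
    return np.mean((ground_truth - estimate) ** 2)

x = np.linspace(0., 1., 50)
y = 2. * x + 1. + np.random.normal(0., 0.05, x.shape) # noisy samples of y = 2x + 1
parameters = gradient_descent_1d(
                                 x,
                                 y,
                                 np.zeros(2),
                                 line,
                                 gradient,
                                 l2_loss,
                                 learning_rate = 0.5,
                                 iteration_number = 200
                                )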
+ +
+ +
+ + +

+ least_square_1d(x, y) + +

+ + +
+ +

A function to fit a line to given x and y data (y=mx+n). Inspired from: https://mmas.github.io/least-squares-fitting-numpy-scipy

+ + +

Parameters:

+
    +
  • + x + – +
    +
         1D input data.
    +
    +
    +
  • +
  • + y + – +
    +
         1D output data.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +parameters ( array +) – +
    +

    Parameters of m and n in a line (y=mx+n).

    +
    +
  • +
+ +
+ Source code in odak/fit/__init__.py +
def least_square_1d(x, y):
+    """
+    A function to fit a line to given x and y data (y=mx+n). Inspired from: https://mmas.github.io/least-squares-fitting-numpy-scipy
+
+    Parameters
+    ----------
+    x          : numpy.array
+                 1D input data.
+    y          : numpy.array
+                 1D output data.
+
+    Returns
+    -------
+    parameters : numpy.array
+                 Parameters of m and n in a line (y=mx+n).
+    """
+    w = np.vstack([x, np.ones(x.shape[0])]).T
+    parameters = np.dot(np.linalg.inv(np.dot(w.T, w)), np.dot(w.T, y))
+    return parameters
+
+
+
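A brief usage sketch with fabricated noisy data (least_square_1d itself comes from odak.fit):

import numpy as np
from odak.fit import least_square_1d

x = np.linspace(0., 10., 100)
y = 3. * x + 2. + np.random.normal(0., 0.1, x.shape) # noisy samples of y = 3x + 2
m, n = least_square_1d(x, y)                         # estimated slope and intercept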
+ +
+ +
+ + +

+ perceptron(x, y, learning_rate=0.1, iteration_number=100) + +

+ + +
+ +

A function to train a perceptron model.

+ + +

Parameters:

+
    +
  • + x + – +
    +
               Input X-Y pairs [m x 2].
    +
    +
    +
  • +
  • + y + – +
    +
               Labels for the input data [m x 1]
    +
    +
    +
  • +
  • + learning_rate + – +
    +
               Learning rate.
    +
    +
    +
  • +
  • + iteration_number + (int, default: + 100 +) + – +
    +
               Iteration number.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +weights ( array +) – +
    +

    Trained weights of our model [3 x 1].

    +
    +
  • +
+ +
+ Source code in odak/fit/__init__.py +
def perceptron(x, y, learning_rate = 0.1, iteration_number = 100):
+    """
+    A function to train a perceptron model.
+
+    Parameters
+    ----------
+    x                : numpy.array
+                       Input X-Y pairs [m x 2].
+    y                : numpy.array
+                       Labels for the input data [m x 1]
+    learning_rate    : float
+                       Learning rate.
+    iteration_number : int
+                       Iteration number.
+
+    Returns
+    -------
+    weights          : numpy.array
+                       Trained weights of our model [3 x 1].
+    """
+    weights = np.zeros((x.shape[1] + 1, 1))
+    t = tqdm(range(iteration_number))
+    for step in t:
+        unsuccessful = 0
+        for data_id in range(x.shape[0]):
+            x_i = np.insert(x[data_id], 0, 1).reshape(-1, 1)
+            y_i = y[data_id]
+            y_hat = threshold_linear_model(x_i, weights)
+            if y_hat - y_i != 0:
+                unsuccessful += 1
+                weights = weights + learning_rate * (y_i - y_hat) * x_i 
+            description = 'Unsuccessful count: {}/{}'.format(unsuccessful, x.shape[0])
+    return weights
+
+
+
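A usage sketch combining perceptron with threshold_linear_model (both come from odak.fit; the toy, linearly separable dataset below is fabricated for illustration):

import numpy as np
from odak.fit import perceptron, threshold_linear_model

# Two linearly separable clusters of 2D points with labels zero and one.
class_zero = np.random.normal([-1., -1.], 0.2, (50, 2))
class_one = np.random.normal([1., 1.], 0.2, (50, 2))
x = np.vstack([class_zero, class_one])
y = np.hstack([np.zeros(50), np.ones(50)]).reshape(-1, 1)

weights = perceptron(x, y, learning_rate = 0.1, iteration_number = 100)

# Classify a new sample by prepending the bias term, as perceptron does internally.
sample = np.insert(np.array([0.9, 1.1]), 0, 1).reshape(-1, 1)
estimated_class = threshold_linear_model(sample, weights)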
+ +
+ +
+ + +

+ threshold_linear_model(x, w, threshold=0) + +

+ + +
+ +

A function for thresholding a linear model described with a dot product.

+ + +

Parameters:

+
    +
  • + x + – +
    +
               Input data [3 x 1].
    +
    +
    +
  • +
  • + w + – +
    +
               Weights [3 x 1].
    +
    +
    +
  • +
  • + threshold + – +
    +
               Value for thresholding.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( int +) – +
    +

    Estimated class of the input data. It could either be one or zero.

    +
    +
  • +
+ +
+ Source code in odak/fit/__init__.py +
def threshold_linear_model(x, w, threshold = 0):
+    """
+    A function for thresholding a linear model described with a dot product.
+
+    Parameters
+    ----------
+    x                : numpy.array
+                       Input data [3 x 1].
+    w                : numpy.array
+                       Weights [3 x 1].
+    threshold        : float
+                       Value for thresholding.
+
+    Returns
+    -------
+    result           : int
+                       Estimated class of the input data. It could either be one or zero.
+    """
+    value = np.dot(x.T, w)
+    result = 0
+    if value >= threshold:
+       result = 1
+    return result
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/odak/learn_lensless/index.html b/odak/learn_lensless/index.html new file mode 100644 index 00000000..880c2ba8 --- /dev/null +++ b/odak/learn_lensless/index.html @@ -0,0 +1,3031 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + odak.learn.lensless - Odak + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + +
+ + + + + + +
+ + + + +
+ +
+ + + + +
+
+ + + +
+
+
+ + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

odak.learn.lensless

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ spec_track + + +

+ + +
+

+ Bases: Module

+ + +

The learned holography model used in the paper, Ziyang Chen and Mustafa Dogan and Josef Spjut and Kaan Akşit. "SpecTrack: Learned Multi-Rotation Tracking via Speckle Imaging." In SIGGRAPH Asia 2024 Posters (SA Posters '24).

+ + +

Parameters:

+
    +
  • + reduction + (str, default: + 'sum' +) + – +
    +
        Reduction used for torch.nn.MSELoss and torch.nn.L1Loss. The default is 'sum'.
    +
    +
    +
  • +
  • + device + – +
    +
        Device to run the model on. Default is CPU.
    +
    +
    +
  • +
+ + + + + + +
+ Source code in odak/learn/lensless/models.py +
class spec_track(nn.Module):
+    """
+    The learned holography model used in the paper, Ziyang Chen and Mustafa Dogan and Josef Spjut and Kaan Akşit. "SpecTrack: Learned Multi-Rotation Tracking via Speckle Imaging." In SIGGRAPH Asia 2024 Posters (SA Posters '24).
+
+    Parameters
+    ----------
+    reduction : str
+                Reduction used for torch.nn.MSELoss and torch.nn.L1Loss. The default is 'sum'.
+    device    : torch.device
+                Device to run the model on. Default is CPU.
+    """
+    def __init__(
+                 self,
+                 reduction = 'sum',
+                 device = torch.device('cpu')
+                ):
+        super(spec_track, self).__init__()
+        self.device = device
+        self.init_layers()
+        self.reduction = reduction
+        self.l2 = torch.nn.MSELoss(reduction = self.reduction)
+        self.l1 = torch.nn.L1Loss(reduction = self.reduction)
+        self.train_history = []
+        self.validation_history = []
+
+
+    def init_layers(self):
+        """
+        Initialize the layers of the network.
+        """
+        # Convolutional layers with batch normalization and pooling
+        self.network = nn.Sequential(OrderedDict([
+            ('conv1', nn.Conv2d(5, 32, kernel_size=3, padding=1)),
+            ('bn1', nn.BatchNorm2d(32)),
+            ('relu1', nn.ReLU()),
+            ('pool1', nn.MaxPool2d(kernel_size=3)),
+
+            ('conv2', nn.Conv2d(32, 64, kernel_size=5, padding=1)),
+            ('bn2', nn.BatchNorm2d(64)),
+            ('relu2', nn.ReLU()),
+            ('pool2', nn.MaxPool2d(kernel_size=3)),
+
+            ('conv3', nn.Conv2d(64, 128, kernel_size=7, padding=1)),
+            ('bn3', nn.BatchNorm2d(128)),
+            ('relu3', nn.ReLU()),
+            ('pool3', nn.MaxPool2d(kernel_size=3)),
+
+            ('flatten', nn.Flatten()),
+
+            ('fc1', nn.Linear(6400, 2048)),
+            ('fc_bn1', nn.BatchNorm1d(2048)),
+            ('relu_fc1', nn.ReLU()),
+
+            ('fc2', nn.Linear(2048, 1024)),
+            ('fc_bn2', nn.BatchNorm1d(1024)),
+            ('relu_fc2', nn.ReLU()),
+
+            ('fc3', nn.Linear(1024, 512)),
+            ('fc_bn3', nn.BatchNorm1d(512)),
+            ('relu_fc3', nn.ReLU()),
+
+            ('fc4', nn.Linear(512, 128)),
+            ('fc_bn4', nn.BatchNorm1d(128)),
+            ('relu_fc4', nn.ReLU()),
+
+            ('fc5', nn.Linear(128, 3))
+        ])).to(self.device)
+
+
+    def forward(self, x):
+        """
+        Forward pass of the network.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor.
+
+        Returns
+        -------
+        torch.Tensor
+            Output tensor.
+        """
+        return self.network(x)
+
+
+    def evaluate(self, input_data, ground_truth, weights = [100., 1.]):
+        """
+        Evaluate the model's performance.
+
+        Parameters
+        ----------
+        input_data    : torch.Tensor
+                        Predicted data from the model.
+        ground_truth  : torch.Tensor
+                        Ground truth data.
+        weights       : list
+                        Weights for L2 and L1 losses. Default is [100., 1.].
+
+        Returns
+        -------
+        torch.Tensor
+            Combined weighted loss.
+        """
+        loss = weights[0] * self.l2(input_data, ground_truth) + weights[1] * self.l1(input_data, ground_truth)
+        return loss
+
+
+    def fit(self, trainloader, testloader, number_of_epochs=100, learning_rate=1e-5, weight_decay=1e-5, directory='./output'):
+        """
+        Train the model.
+
+        Parameters
+        ----------
+        trainloader      : torch.utils.data.DataLoader
+                           Training data loader.
+        testloader       : torch.utils.data.DataLoader
+                           Testing data loader.
+        number_of_epochs : int
+                           Number of epochs to train for. Default is 100.
+        learning_rate    : float
+                           Learning rate for the optimizer. Default is 1e-5.
+        weight_decay     : float
+                           Weight decay for the optimizer. Default is 1e-5.
+        directory        : str
+                           Directory to save the model weights. Default is './output'.
+        """
+        makedirs(directory, exist_ok=True)
+        makedirs(join(directory, "log"), exist_ok=True)
+
+        self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate, weight_decay=weight_decay)
+        best_val_loss = float('inf')
+
+        for epoch in range(number_of_epochs):
+            # Training phase
+            self.train()
+            train_loss = 0.0
+            train_batches = 0
+            train_pbar = tqdm(trainloader, desc=f"Epoch {epoch+1}/{number_of_epochs} [Train]", leave=False, dynamic_ncols=True)
+
+            for batch, labels in train_pbar:
+                self.optimizer.zero_grad()
+                batch, labels = batch.to(self.device), labels.to(self.device)
+                predicts = torch.squeeze(self.forward(batch))
+                loss = self.evaluate(predicts, labels)
+                loss.backward()
+                self.optimizer.step()
+
+                train_loss += loss.item()
+                train_batches += 1
+                train_pbar.set_postfix({'Loss': f"{loss.item():.4f}"})
+
+            avg_train_loss = train_loss / train_batches
+            self.train_history.append(avg_train_loss)
+
+            # Validation phase
+            self.eval()
+            val_loss = 0.0
+            val_batches = 0
+            val_pbar = tqdm(testloader, desc=f"Epoch {epoch+1}/{number_of_epochs} [Val]", leave=False, dynamic_ncols=True)
+
+            with torch.no_grad():
+                for batch, labels in val_pbar:
+                    batch, labels = batch.to(self.device), labels.to(self.device)
+                    predicts = torch.squeeze(self.forward(batch), dim=1)
+                    loss = self.evaluate(predicts, labels)
+
+                    val_loss += loss.item()
+                    val_batches += 1
+                    val_pbar.set_postfix({'Loss': f"{loss.item():.4f}"})
+
+            avg_val_loss = val_loss / val_batches
+            self.validation_history.append(avg_val_loss)
+
+            # Print epoch summary
+            print(f"Epoch {epoch+1}/{number_of_epochs} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
+
+            # Save best model
+            if avg_val_loss < best_val_loss:
+                best_val_loss = avg_val_loss
+                self.save_weights(join(directory, f"best_model_epoch_{epoch+1}.pt"))
+                print(f"Best model saved at epoch {epoch+1}")
+
+        # Save training history
+        torch.save(self.train_history, join(directory, "log", "train_log.pt"))
+        torch.save(self.validation_history, join(directory, "log", "validation_log.pt"))
+        print("Training completed. History saved.")
+
+
+    def save_weights(self, filename = './weights.pt'):
+        """
+        Save the current weights of the network to a file.
+
+        Parameters
+        ----------
+        filename : str
+                   Path to save the weights. Default is './weights.pt'.
+        """
+        torch.save(self.network.state_dict(), os.path.expanduser(filename))
+
+
+    def load_weights(self, filename = './weights.pt'):
+        """
+        Load weights for the network from a file.
+
+        Parameters
+        ----------
+        filename : str
+                   Path to load the weights from. Default is './weights.pt'.
+        """
+        self.network.load_state_dict(torch.load(os.path.expanduser(filename), weights_only = True))
+        self.network.eval()
+
+
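A brief usage sketch: the import assumes spec_track is exported at the odak.learn.lensless level as this page suggests, and the dummy data below merely stands in for real speckle recordings. The 180 x 320 frame size is chosen here only so that the flattened feature size matches the fully connected layer in init_layers; the frame size used in the actual SpecTrack setup may differ.

import torch
from odak.learn.lensless import spec_track

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = spec_track(device = device)

# Dummy five-channel speckle frames and three-element rotation labels.
frames = torch.rand(32, 5, 180, 320)
labels = torch.rand(32, 3)
dataset = torch.utils.data.TensorDataset(frames, labels)
trainloader = torch.utils.data.DataLoader(dataset, batch_size = 8, shuffle = True)
testloader = torch.utils.data.DataLoader(dataset, batch_size = 8)

model.fit(
          trainloader,
          testloader,
          number_of_epochs = 5,
          learning_rate = 1e-5,
          directory = './output'
         )
model.save_weights('./output/weights.pt')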
+ + + +
+ + + + + + + + + +
+ + +

+ evaluate(input_data, ground_truth, weights=[100.0, 1.0]) + +

+ + +
+ +

Evaluate the model's performance.

+ + +

Parameters:

+
    +
  • + input_data + – +
    +
            Predicted data from the model.
    +
    +
    +
  • +
  • + ground_truth + – +
    +
            Ground truth data.
    +
    +
    +
  • +
  • + weights + – +
    +
            Weights for L2 and L1 losses. Default is [100., 1.].
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • + Tensor + – +
    +

    Combined weighted loss.

    +
    +
  • +
+ +
+ Source code in odak/learn/lensless/models.py +
def evaluate(self, input_data, ground_truth, weights = [100., 1.]):
+    """
+    Evaluate the model's performance.
+
+    Parameters
+    ----------
+    input_data    : torch.Tensor
+                    Predicted data from the model.
+    ground_truth  : torch.Tensor
+                    Ground truth data.
+    weights       : list
+                    Weights for L2 and L1 losses. Default is [100., 1.].
+
+    Returns
+    -------
+    torch.Tensor
+        Combined weighted loss.
+    """
+    loss = weights[0] * self.l2(input_data, ground_truth) + weights[1] * self.l1(input_data, ground_truth)
+    return loss
+
+
+
+ +
+ +
+ + +

+ fit(trainloader, testloader, number_of_epochs=100, learning_rate=1e-05, weight_decay=1e-05, directory='./output') + +

+ + +
+ +

Train the model.

+ + +

Parameters:

+
    +
  • + trainloader + – +
    +
               Training data loader.
    +
    +
    +
  • +
  • + testloader + – +
    +
               Testing data loader.
    +
    +
    +
  • +
  • + number_of_epochs + (int, default: + 100 +) + – +
    +
               Number of epochs to train for. Default is 100.
    +
    +
    +
  • +
  • + learning_rate + – +
    +
               Learning rate for the optimizer. Default is 1e-5.
    +
    +
    +
  • +
  • + weight_decay + – +
    +
               Weight decay for the optimizer. Default is 1e-5.
    +
    +
    +
  • +
  • + directory + – +
    +
               Directory to save the model weights. Default is './output'.
    +
    +
    +
  • +
+ +
+ Source code in odak/learn/lensless/models.py +
def fit(self, trainloader, testloader, number_of_epochs=100, learning_rate=1e-5, weight_decay=1e-5, directory='./output'):
+    """
+    Train the model.
+
+    Parameters
+    ----------
+    trainloader      : torch.utils.data.DataLoader
+                       Training data loader.
+    testloader       : torch.utils.data.DataLoader
+                       Testing data loader.
+    number_of_epochs : int
+                       Number of epochs to train for. Default is 100.
+    learning_rate    : float
+                       Learning rate for the optimizer. Default is 1e-5.
+    weight_decay     : float
+                       Weight decay for the optimizer. Default is 1e-5.
+    directory        : str
+                       Directory to save the model weights. Default is './output'.
+    """
+    makedirs(directory, exist_ok=True)
+    makedirs(join(directory, "log"), exist_ok=True)
+
+    self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate, weight_decay=weight_decay)
+    best_val_loss = float('inf')
+
+    for epoch in range(number_of_epochs):
+        # Training phase
+        self.train()
+        train_loss = 0.0
+        train_batches = 0
+        train_pbar = tqdm(trainloader, desc=f"Epoch {epoch+1}/{number_of_epochs} [Train]", leave=False, dynamic_ncols=True)
+
+        for batch, labels in train_pbar:
+            self.optimizer.zero_grad()
+            batch, labels = batch.to(self.device), labels.to(self.device)
+            predicts = torch.squeeze(self.forward(batch))
+            loss = self.evaluate(predicts, labels)
+            loss.backward()
+            self.optimizer.step()
+
+            train_loss += loss.item()
+            train_batches += 1
+            train_pbar.set_postfix({'Loss': f"{loss.item():.4f}"})
+
+        avg_train_loss = train_loss / train_batches
+        self.train_history.append(avg_train_loss)
+
+        # Validation phase
+        self.eval()
+        val_loss = 0.0
+        val_batches = 0
+        val_pbar = tqdm(testloader, desc=f"Epoch {epoch+1}/{number_of_epochs} [Val]", leave=False, dynamic_ncols=True)
+
+        with torch.no_grad():
+            for batch, labels in val_pbar:
+                batch, labels = batch.to(self.device), labels.to(self.device)
+                predicts = torch.squeeze(self.forward(batch), dim=1)
+                loss = self.evaluate(predicts, labels)
+
+                val_loss += loss.item()
+                val_batches += 1
+                val_pbar.set_postfix({'Loss': f"{loss.item():.4f}"})
+
+        avg_val_loss = val_loss / val_batches
+        self.validation_history.append(avg_val_loss)
+
+        # Print epoch summary
+        print(f"Epoch {epoch+1}/{number_of_epochs} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
+
+        # Save best model
+        if avg_val_loss < best_val_loss:
+            best_val_loss = avg_val_loss
+            self.save_weights(join(directory, f"best_model_epoch_{epoch+1}.pt"))
+            print(f"Best model saved at epoch {epoch+1}")
+
+    # Save training history
+    torch.save(self.train_history, join(directory, "log", "train_log.pt"))
+    torch.save(self.validation_history, join(directory, "log", "validation_log.pt"))
+    print("Training completed. History saved.")
+
+
+
+ +
+ +
+ + +

+ forward(x) + +

+ + +
+ +

Forward pass of the network.

+ + +

Parameters:

+
    +
  • + x + (Tensor) + – +
    +

    Input tensor.

    +
    +
  • +
+ + +

Returns:

+
    +
  • + Tensor + – +
    +

    Output tensor.

    +
    +
  • +
+ +
+ Source code in odak/learn/lensless/models.py +
def forward(self, x):
+    """
+    Forward pass of the network.
+
+    Parameters
+    ----------
+    x : torch.Tensor
+        Input tensor.
+
+    Returns
+    -------
+    torch.Tensor
+        Output tensor.
+    """
+    return self.network(x)
+
+
+
+ +
+ +
+ + +

+ init_layers() + +

+ + +
+ +

Initialize the layers of the network.

+ +
+ Source code in odak/learn/lensless/models.py +
def init_layers(self):
+    """
+    Initialize the layers of the network.
+    """
+    # Convolutional layers with batch normalization and pooling
+    self.network = nn.Sequential(OrderedDict([
+        ('conv1', nn.Conv2d(5, 32, kernel_size=3, padding=1)),
+        ('bn1', nn.BatchNorm2d(32)),
+        ('relu1', nn.ReLU()),
+        ('pool1', nn.MaxPool2d(kernel_size=3)),
+
+        ('conv2', nn.Conv2d(32, 64, kernel_size=5, padding=1)),
+        ('bn2', nn.BatchNorm2d(64)),
+        ('relu2', nn.ReLU()),
+        ('pool2', nn.MaxPool2d(kernel_size=3)),
+
+        ('conv3', nn.Conv2d(64, 128, kernel_size=7, padding=1)),
+        ('bn3', nn.BatchNorm2d(128)),
+        ('relu3', nn.ReLU()),
+        ('pool3', nn.MaxPool2d(kernel_size=3)),
+
+        ('flatten', nn.Flatten()),
+
+        ('fc1', nn.Linear(6400, 2048)),
+        ('fc_bn1', nn.BatchNorm1d(2048)),
+        ('relu_fc1', nn.ReLU()),
+
+        ('fc2', nn.Linear(2048, 1024)),
+        ('fc_bn2', nn.BatchNorm1d(1024)),
+        ('relu_fc2', nn.ReLU()),
+
+        ('fc3', nn.Linear(1024, 512)),
+        ('fc_bn3', nn.BatchNorm1d(512)),
+        ('relu_fc3', nn.ReLU()),
+
+        ('fc4', nn.Linear(512, 128)),
+        ('fc_bn4', nn.BatchNorm1d(128)),
+        ('relu_fc4', nn.ReLU()),
+
+        ('fc5', nn.Linear(128, 3))
+    ])).to(self.device)
+
+
+
+ +
+ +
+ + +

+ load_weights(filename='./weights.pt') + +

+ + +
+ +

Load weights for the network from a file.

+ + +

Parameters:

+
    +
  • + filename + (str, default: + './weights.pt' +) + – +
    +
       Path to load the weights from. Default is './weights.pt'.
    +
    +
    +
  • +
+ +
+ Source code in odak/learn/lensless/models.py +
def load_weights(self, filename = './weights.pt'):
+    """
+    Load weights for the network from a file.
+
+    Parameters
+    ----------
+    filename : str
+               Path to load the weights from. Default is './weights.pt'.
+    """
+    self.network.load_state_dict(torch.load(os.path.expanduser(filename), weights_only = True))
+    self.network.eval()
+
+
+
+ +
+ +
+ + +

+ save_weights(filename='./weights.pt') + +

+ + +
+ +

Save the current weights of the network to a file.

+ + +

Parameters:

+
    +
  • + filename + (str, default: + './weights.pt' +) + – +
    +
       Path to save the weights. Default is './weights.pt'.
    +
    +
    +
  • +
+ +
+ Source code in odak/learn/lensless/models.py +
def save_weights(self, filename = './weights.pt'):
+    """
+    Save the current weights of the network to a file.
+
+    Parameters
+    ----------
+    filename : str
+               Path to save the weights. Default is './weights.pt'.
+    """
+    torch.save(self.network.state_dict(), os.path.expanduser(filename))
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/odak/learn_models/index.html b/odak/learn_models/index.html new file mode 100644 index 00000000..1a64a629 --- /dev/null +++ b/odak/learn_models/index.html @@ -0,0 +1,34584 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + odak.learn.models - Odak + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + +
+ + + + + + +
+ + + + +
+ +
+ + + + +
+
+ + + +
+
+
+ + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

odak.learn.models

+ +
+ + + + +
+ +

odak.learn.models

+

Provides necessary definitions for components used in machine learning and deep learning.

+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ channel_gate + + +

+ + +
+

+ Bases: Module

+ + +

Channel attention module with various pooling strategies. +This class is heavily inspired by https://github.com/Jongchan/attention-module/commit/e4ee180f1335c09db14d39a65d97c8ca3d1f7b16 (MIT License).

+ + + + + + +
+ Source code in odak/learn/models/components.py +
class channel_gate(torch.nn.Module):
+    """
+    Channel attention module with various pooling strategies.
+    This class is heavily inspired https://github.com/Jongchan/attention-module/commit/e4ee180f1335c09db14d39a65d97c8ca3d1f7b16 (MIT License).
+    """
+    def __init__(
+                 self, 
+                 gate_channels, 
+                 reduction_ratio = 16, 
+                 pool_types = ['avg', 'max']
+                ):
+        """
+        Initializes the channel gate module.
+
+        Parameters
+        ----------
+        gate_channels   : int
+                          Number of channels of the input feature map.
+        reduction_ratio : int
+                          Reduction ratio for the intermediate layer.
+        pool_types      : list
+                          List of pooling operations to apply.
+        """
+        super().__init__()
+        self.gate_channels = gate_channels
+        hidden_channels = gate_channels // reduction_ratio
+        if hidden_channels == 0:
+            hidden_channels = 1
+        self.mlp = torch.nn.Sequential(
+                                       convolutional_block_attention.Flatten(),
+                                       torch.nn.Linear(gate_channels, hidden_channels),
+                                       torch.nn.ReLU(),
+                                       torch.nn.Linear(hidden_channels, gate_channels)
+                                      )
+        self.pool_types = pool_types
+
+
+    def forward(self, x):
+        """
+        Forward pass of the ChannelGate module.
+
+        Applies channel-wise attention to the input tensor.
+
+        Parameters
+        ----------
+        x            : torch.tensor
+                       Input tensor to the ChannelGate module.
+
+        Returns
+        -------
+        output       : torch.tensor
+                       Output tensor after applying channel attention.
+        """
+        channel_att_sum = None
+        for pool_type in self.pool_types:
+            if pool_type == 'avg':
+                pool = torch.nn.functional.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
+            elif pool_type == 'max':
+                pool = torch.nn.functional.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
+            channel_att_raw = self.mlp(pool)
+            channel_att_sum = channel_att_raw if channel_att_sum is None else channel_att_sum + channel_att_raw
+        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
+        output = x * scale
+        return output
+
+
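A quick usage sketch (tensor shapes are arbitrary; the import assumes channel_gate is exported at the odak.learn.models level as this page suggests, otherwise import it from odak.learn.models.components):

import torch
from odak.learn.models import channel_gate

gate = channel_gate(gate_channels = 64, reduction_ratio = 16, pool_types = ['avg', 'max'])
x = torch.rand(4, 64, 32, 32) # a batch of 64-channel feature maps
y = gate(x)                   # same shape as x, rescaled channel by channel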
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(gate_channels, reduction_ratio=16, pool_types=['avg', 'max']) + +

+ + +
+ +

Initializes the channel gate module.

+ + +

Parameters:

+
    +
  • + gate_channels + – +
    +
              Number of channels of the input feature map.
    +
    +
    +
  • +
  • + reduction_ratio + (int, default: + 16 +) + – +
    +
              Reduction ratio for the intermediate layer.
    +
    +
    +
  • +
  • + pool_types + – +
    +
              List of pooling operations to apply.
    +
    +
    +
  • +
+ +
+ Source code in odak/learn/models/components.py +
def __init__(
+             self, 
+             gate_channels, 
+             reduction_ratio = 16, 
+             pool_types = ['avg', 'max']
+            ):
+    """
+    Initializes the channel gate module.
+
+    Parameters
+    ----------
+    gate_channels   : int
+                      Number of channels of the input feature map.
+    reduction_ratio : int
+                      Reduction ratio for the intermediate layer.
+    pool_types      : list
+                      List of pooling operations to apply.
+    """
+    super().__init__()
+    self.gate_channels = gate_channels
+    hidden_channels = gate_channels // reduction_ratio
+    if hidden_channels == 0:
+        hidden_channels = 1
+    self.mlp = torch.nn.Sequential(
+                                   convolutional_block_attention.Flatten(),
+                                   torch.nn.Linear(gate_channels, hidden_channels),
+                                   torch.nn.ReLU(),
+                                   torch.nn.Linear(hidden_channels, gate_channels)
+                                  )
+    self.pool_types = pool_types
+
+
+
+ +
+ +
+ + +

+ forward(x) + +

+ + +
+ +

Forward pass of the ChannelGate module.

+

Applies channel-wise attention to the input tensor.

+ + +

Parameters:

+
    +
  • + x + – +
    +
           Input tensor to the ChannelGate module.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +output ( tensor +) – +
    +

    Output tensor after applying channel attention.

    +
    +
  • +
+ +
+ Source code in odak/learn/models/components.py +
def forward(self, x):
+    """
+    Forward pass of the ChannelGate module.
+
+    Applies channel-wise attention to the input tensor.
+
+    Parameters
+    ----------
+    x            : torch.tensor
+                   Input tensor to the ChannelGate module.
+
+    Returns
+    -------
+    output       : torch.tensor
+                   Output tensor after applying channel attention.
+    """
+    channel_att_sum = None
+    for pool_type in self.pool_types:
+        if pool_type == 'avg':
+            pool = torch.nn.functional.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
+        elif pool_type == 'max':
+            pool = torch.nn.functional.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
+        channel_att_raw = self.mlp(pool)
+        channel_att_sum = channel_att_raw if channel_att_sum is None else channel_att_sum + channel_att_raw
+    scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
+    output = x * scale
+    return output
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ convolution_layer + + +

+ + +
+

+ Bases: Module

+ + +

A convolution layer.

+ + + + + + +
+ Source code in odak/learn/models/components.py +
class convolution_layer(torch.nn.Module):
+    """
+    A convolution layer.
+    """
+    def __init__(
+                 self,
+                 input_channels = 2,
+                 output_channels = 2,
+                 kernel_size = 3,
+                 bias = False,
+                 stride = 1,
+                 normalization = True,
+                 activation = torch.nn.ReLU()
+                ):
+        """
+        A convolutional layer class.
+
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        output_channels : int
+                          Number of output channels.
+        kernel_size     : int
+                          Kernel size.
+        bias            : bool
+                          Set to True to let convolutional layers have bias term.
+        normalization   : bool
+                          If True, adds a Batch Normalization layer after the convolutional layer.
+        activation      : torch.nn
+                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+        """
+        super().__init__()
+        layers = [
+            torch.nn.Conv2d(
+                            input_channels,
+                            output_channels,
+                            kernel_size = kernel_size,
+                            stride = stride,
+                            padding = kernel_size // 2,
+                            bias = bias
+                           )
+        ]
+        if normalization:
+            layers.append(torch.nn.BatchNorm2d(output_channels))
+        if activation:
+            layers.append(activation)
+        self.model = torch.nn.Sequential(*layers)
+
+
+    def forward(self, x):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x             : torch.tensor
+                        Input data.
+
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Estimated output.
+        """
+        result = self.model(x)
+        return result
+
+
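A short, hypothetical usage sketch for convolution_layer (illustrative shapes; same import path as above):

+import torch
+from odak.learn.models.components import convolution_layer
+
+layer = convolution_layer(
+                          input_channels = 3,
+                          output_channels = 16,
+                          kernel_size = 3,
+                          normalization = True,
+                          activation = torch.nn.ReLU()
+                         )
+x = torch.randn(1, 3, 128, 128)
+y = layer(x)                              # stride = 1 and padding = kernel_size // 2 keep the 128 x 128 resolution
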
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(input_channels=2, output_channels=2, kernel_size=3, bias=False, stride=1, normalization=True, activation=torch.nn.ReLU()) + +

A convolutional layer class.

Parameters:

  • input_channels (int, default: 2) – Number of input channels.
  • output_channels (int, default: 2) – Number of output channels.
  • kernel_size (int, default: 3) – Kernel size.
  • bias (bool, default: False) – Set to True to let convolutional layers have bias term.
  • normalization (bool, default: True) – If True, adds a Batch Normalization layer after the convolutional layer.
  • activation (torch.nn, default: torch.nn.ReLU()) – Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             input_channels = 2,
+             output_channels = 2,
+             kernel_size = 3,
+             bias = False,
+             stride = 1,
+             normalization = True,
+             activation = torch.nn.ReLU()
+            ):
+    """
+    A convolutional layer class.
+
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    output_channels : int
+                      Number of output channels.
+    kernel_size     : int
+                      Kernel size.
+    bias            : bool
+                      Set to True to let convolutional layers have bias term.
+    normalization   : bool
+                      If True, adds a Batch Normalization layer after the convolutional layer.
+    activation      : torch.nn
+                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+    """
+    super().__init__()
+    layers = [
+        torch.nn.Conv2d(
+                        input_channels,
+                        output_channels,
+                        kernel_size = kernel_size,
+                        stride = stride,
+                        padding = kernel_size // 2,
+                        bias = bias
+                       )
+    ]
+    if normalization:
+        layers.append(torch.nn.BatchNorm2d(output_channels))
+    if activation:
+        layers.append(activation)
+    self.model = torch.nn.Sequential(*layers)
+
+
+
+ +
+ +
+ + +

+ forward(x) + +

Forward model.

Parameters:

  • x (torch.tensor) – Input data.

Returns:

  • result (torch.tensor) – Estimated output.
+ Source code in odak/learn/models/components.py +
def forward(self, x):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x             : torch.tensor
+                    Input data.
+
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Estimated output.
+    """
+    result = self.model(x)
+    return result
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ convolutional_block_attention + + +

+ + +
+

+ Bases: Module

+ + +

Convolutional Block Attention Module (CBAM) class. This class is heavily inspired by https://github.com/Jongchan/attention-module/commit/e4ee180f1335c09db14d39a65d97c8ca3d1f7b16 (MIT License).

+ + + + + + +
+ Source code in odak/learn/models/components.py +
class convolutional_block_attention(torch.nn.Module):
+    """
+    Convolutional Block Attention Module (CBAM) class. 
+    This class is heavily inspired by https://github.com/Jongchan/attention-module/commit/e4ee180f1335c09db14d39a65d97c8ca3d1f7b16 (MIT License).
+    """
+    def __init__(
+                 self, 
+                 gate_channels, 
+                 reduction_ratio = 16, 
+                 pool_types = ['avg', 'max'], 
+                 no_spatial = False
+                ):
+        """
+        Initializes the convolutional block attention module.
+
+        Parameters
+        ----------
+        gate_channels   : int
+                          Number of channels of the input feature map.
+        reduction_ratio : int
+                          Reduction ratio for the channel attention.
+        pool_types      : list
+                          List of pooling operations to apply for channel attention.
+        no_spatial      : bool
+                          If True, spatial attention is not applied.
+        """
+        super(convolutional_block_attention, self).__init__()
+        self.channel_gate = channel_gate(gate_channels, reduction_ratio, pool_types)
+        self.no_spatial = no_spatial
+        if not no_spatial:
+            self.spatial_gate = spatial_gate()
+
+
+    class Flatten(torch.nn.Module):
+        """
+        Flattens the input tensor to a 2D matrix.
+        """
+        def forward(self, x):
+            return x.view(x.size(0), -1)
+
+
+    def forward(self, x):
+        """
+        Forward pass of the convolutional block attention module.
+
+        Parameters
+        ----------
+        x            : torch.tensor
+                       Input tensor to the CBAM module.
+
+        Returns
+        -------
+        x_out        : torch.tensor
+                       Output tensor after applying channel and spatial attention.
+        """
+        x_out = self.channel_gate(x)
+        if not self.no_spatial:
+            x_out = self.spatial_gate(x_out)
+        return x_out
+
+
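A minimal sketch of how this CBAM block could be applied to a feature map (illustrative values, not taken from the odak sources):

+import torch
+from odak.learn.models.components import convolutional_block_attention
+
+cbam = convolutional_block_attention(gate_channels = 64, reduction_ratio = 16)
+x = torch.randn(2, 64, 32, 32)
+y = cbam(x)                               # channel attention followed by spatial attention
+assert y.shape == x.shape
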
+ + + +
+ + + + + + + + +
+ + + +

+ Flatten + + +

+ + +
+

+ Bases: Module

+ + +

Flattens the input tensor to a 2D matrix.

+ + + + + + +
+ Source code in odak/learn/models/components.py +
class Flatten(torch.nn.Module):
+    """
+    Flattens the input tensor to a 2D matrix.
+    """
+    def forward(self, x):
+        return x.view(x.size(0), -1)
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ + +
+ + +

+ __init__(gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False) + +

Initializes the convolutional block attention module.

Parameters:

  • gate_channels (int) – Number of channels of the input feature map.
  • reduction_ratio (int, default: 16) – Reduction ratio for the channel attention.
  • pool_types (list, default: ['avg', 'max']) – List of pooling operations to apply for channel attention.
  • no_spatial (bool, default: False) – If True, spatial attention is not applied.
+ Source code in odak/learn/models/components.py +
def __init__(
+             self, 
+             gate_channels, 
+             reduction_ratio = 16, 
+             pool_types = ['avg', 'max'], 
+             no_spatial = False
+            ):
+    """
+    Initializes the convolutional block attention module.
+
+    Parameters
+    ----------
+    gate_channels   : int
+                      Number of channels of the input feature map.
+    reduction_ratio : int
+                      Reduction ratio for the channel attention.
+    pool_types      : list
+                      List of pooling operations to apply for channel attention.
+    no_spatial      : bool
+                      If True, spatial attention is not applied.
+    """
+    super(convolutional_block_attention, self).__init__()
+    self.channel_gate = channel_gate(gate_channels, reduction_ratio, pool_types)
+    self.no_spatial = no_spatial
+    if not no_spatial:
+        self.spatial_gate = spatial_gate()
+
+
+
+ +
+ +
+ + +

+ forward(x) + +

Forward pass of the convolutional block attention module.

Parameters:

  • x (torch.tensor) – Input tensor to the CBAM module.

Returns:

  • x_out (torch.tensor) – Output tensor after applying channel and spatial attention.
+ Source code in odak/learn/models/components.py +
def forward(self, x):
+    """
+    Forward pass of the convolutional block attention module.
+
+    Parameters
+    ----------
+    x            : torch.tensor
+                   Input tensor to the CBAM module.
+
+    Returns
+    -------
+    x_out        : torch.tensor
+                   Output tensor after applying channel and spatial attention.
+    """
+    x_out = self.channel_gate(x)
+    if not self.no_spatial:
+        x_out = self.spatial_gate(x_out)
+    return x_out
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ double_convolution + + +

+ + +
+

+ Bases: Module

+ + +

A double convolution layer.

+ + + + + + +
+ Source code in odak/learn/models/components.py +
class double_convolution(torch.nn.Module):
+    """
+    A double convolution layer.
+    """
+    def __init__(
+                 self,
+                 input_channels = 2,
+                 mid_channels = None,
+                 output_channels = 2,
+                 kernel_size = 3, 
+                 bias = False,
+                 normalization = True,
+                 activation = torch.nn.ReLU()
+                ):
+        """
+        Double convolution model.
+
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        mid_channels    : int
+                          Number of channels in the hidden layer between two convolutions.
+        output_channels : int
+                          Number of output channels.
+        kernel_size     : int
+                          Kernel size.
+        bias            : bool 
+                          Set to True to let convolutional layers have bias term.
+        normalization   : bool
+                          If True, adds a Batch Normalization layer after the convolutional layer.
+        activation      : torch.nn
+                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+        """
+        super().__init__()
+        if isinstance(mid_channels, type(None)):
+            mid_channels = output_channels
+        self.activation = activation
+        self.model = torch.nn.Sequential(
+                                         convolution_layer(
+                                                           input_channels = input_channels,
+                                                           output_channels = mid_channels,
+                                                           kernel_size = kernel_size,
+                                                           bias = bias,
+                                                           normalization = normalization,
+                                                           activation = self.activation
+                                                          ),
+                                         convolution_layer(
+                                                           input_channels = mid_channels,
+                                                           output_channels = output_channels,
+                                                           kernel_size = kernel_size,
+                                                           bias = bias,
+                                                           normalization = normalization,
+                                                           activation = self.activation
+                                                          )
+                                        )
+
+
+    def forward(self, x):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x             : torch.tensor
+                        Input data.
+
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Estimated output.      
+        """
+        result = self.model(x)
+        return result
+
+
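A minimal usage sketch with illustrative channel counts:

+import torch
+from odak.learn.models.components import double_convolution
+
+block = double_convolution(
+                           input_channels = 8,
+                           mid_channels = 16,
+                           output_channels = 16,
+                           kernel_size = 3
+                          )
+x = torch.randn(1, 8, 64, 64)
+y = block(x)                              # two convolution_layer passes -> torch.Size([1, 16, 64, 64])
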
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(input_channels=2, mid_channels=None, output_channels=2, kernel_size=3, bias=False, normalization=True, activation=torch.nn.ReLU()) + +

Double convolution model.

Parameters:

  • input_channels (int, default: 2) – Number of input channels.
  • mid_channels (int, default: None) – Number of channels in the hidden layer between two convolutions.
  • output_channels (int, default: 2) – Number of output channels.
  • kernel_size (int, default: 3) – Kernel size.
  • bias (bool, default: False) – Set to True to let convolutional layers have bias term.
  • normalization (bool, default: True) – If True, adds a Batch Normalization layer after the convolutional layer.
  • activation (torch.nn, default: torch.nn.ReLU()) – Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             input_channels = 2,
+             mid_channels = None,
+             output_channels = 2,
+             kernel_size = 3, 
+             bias = False,
+             normalization = True,
+             activation = torch.nn.ReLU()
+            ):
+    """
+    Double convolution model.
+
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    mid_channels    : int
+                      Number of channels in the hidden layer between two convolutions.
+    output_channels : int
+                      Number of output channels.
+    kernel_size     : int
+                      Kernel size.
+    bias            : bool 
+                      Set to True to let convolutional layers have bias term.
+    normalization   : bool
+                      If True, adds a Batch Normalization layer after the convolutional layer.
+    activation      : torch.nn
+                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+    """
+    super().__init__()
+    if isinstance(mid_channels, type(None)):
+        mid_channels = output_channels
+    self.activation = activation
+    self.model = torch.nn.Sequential(
+                                     convolution_layer(
+                                                       input_channels = input_channels,
+                                                       output_channels = mid_channels,
+                                                       kernel_size = kernel_size,
+                                                       bias = bias,
+                                                       normalization = normalization,
+                                                       activation = self.activation
+                                                      ),
+                                     convolution_layer(
+                                                       input_channels = mid_channels,
+                                                       output_channels = output_channels,
+                                                       kernel_size = kernel_size,
+                                                       bias = bias,
+                                                       normalization = normalization,
+                                                       activation = self.activation
+                                                      )
+                                    )
+
+
+
+ +
+ +
+ + +

+ forward(x) + +

Forward model.

Parameters:

  • x (torch.tensor) – Input data.

Returns:

  • result (torch.tensor) – Estimated output.
+ Source code in odak/learn/models/components.py +
def forward(self, x):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x             : torch.tensor
+                    Input data.
+
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Estimated output.      
+    """
+    result = self.model(x)
+    return result
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ downsample_layer + + +

+ + +
+

+ Bases: Module

+ + +

A downscaling component followed by a double convolution.

+ + + + + + +
+ Source code in odak/learn/models/components.py +
class downsample_layer(torch.nn.Module):
+    """
+    A downscaling component followed by a double convolution.
+    """
+    def __init__(
+                 self,
+                 input_channels,
+                 output_channels,
+                 kernel_size = 3,
+                 bias = False,
+                 activation = torch.nn.ReLU()
+                ):
+        """
+        A downscaling component with a double convolution.
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        output_channels : int
+                          Number of output channels.
+        kernel_size     : int
+                          Kernel size.
+        bias            : bool 
+                          Set to True to let convolutional layers have bias term.
+        activation      : torch.nn
+                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+        """
+        super().__init__()
+        self.maxpool_conv = torch.nn.Sequential(
+                                                torch.nn.MaxPool2d(2),
+                                                double_convolution(
+                                                                   input_channels = input_channels,
+                                                                   mid_channels = output_channels,
+                                                                   output_channels = output_channels,
+                                                                   kernel_size = kernel_size,
+                                                                   bias = bias,
+                                                                   activation = activation
+                                                                  )
+                                               )
+
+
+    def forward(self, x):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x              : torch.tensor
+                         First input data.
+
+
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Estimated output.      
+        """
+        result = self.maxpool_conv(x)
+        return result
+
+
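A minimal sketch showing the expected halving of the spatial resolution (shapes are illustrative):

+import torch
+from odak.learn.models.components import downsample_layer
+
+down = downsample_layer(input_channels = 16, output_channels = 32)
+x = torch.randn(1, 16, 64, 64)
+y = down(x)                               # MaxPool2d(2) then double_convolution -> torch.Size([1, 32, 32, 32])
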
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(input_channels, output_channels, kernel_size=3, bias=False, activation=torch.nn.ReLU()) + +

A downscaling component with a double convolution.

Parameters:

  • input_channels (int) – Number of input channels.
  • output_channels (int) – Number of output channels.
  • kernel_size (int, default: 3) – Kernel size.
  • bias (bool, default: False) – Set to True to let convolutional layers have bias term.
  • activation (torch.nn, default: torch.nn.ReLU()) – Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             input_channels,
+             output_channels,
+             kernel_size = 3,
+             bias = False,
+             activation = torch.nn.ReLU()
+            ):
+    """
+    A downscaling component with a double convolution.
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    output_channels : int
+                      Number of output channels.
+    kernel_size     : int
+                      Kernel size.
+    bias            : bool 
+                      Set to True to let convolutional layers have bias term.
+    activation      : torch.nn
+                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+    """
+    super().__init__()
+    self.maxpool_conv = torch.nn.Sequential(
+                                            torch.nn.MaxPool2d(2),
+                                            double_convolution(
+                                                               input_channels = input_channels,
+                                                               mid_channels = output_channels,
+                                                               output_channels = output_channels,
+                                                               kernel_size = kernel_size,
+                                                               bias = bias,
+                                                               activation = activation
+                                                              )
+                                           )
+
+
+
+ +
+ +
+ + +

+ forward(x) + +

Forward model.

Parameters:

  • x (torch.tensor) – First input data.

Returns:

  • result (torch.tensor) – Estimated output.
+ Source code in odak/learn/models/components.py +
def forward(self, x):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x              : torch.tensor
+                     First input data.
+
+
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Estimated output.      
+    """
+    result = self.maxpool_conv(x)
+    return result
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ global_feature_module + + +

+ + +
+

+ Bases: Module

+ + +

A global feature layer that processes global features from input channels and +applies them to another input tensor via learned transformations.

+ + + + + + +
+ Source code in odak/learn/models/components.py +
class global_feature_module(torch.nn.Module):
+    """
+    A global feature layer that processes global features from input channels and
+    applies them to another input tensor via learned transformations.
+    """
+    def __init__(
+                 self,
+                 input_channels,
+                 mid_channels,
+                 output_channels,
+                 kernel_size,
+                 bias = False,
+                 normalization = False,
+                 activation = torch.nn.ReLU()
+                ):
+        """
+        A global feature layer.
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        mid_channels  : int
+                          Number of mid channels.
+        output_channels : int
+                          Number of output channels.
+        kernel_size     : int
+                          Kernel size.
+        bias            : bool
+                          Set to True to let convolutional layers have bias term.
+        normalization   : bool
+                          If True, adds a Batch Normalization layer after the convolutional layer.
+        activation      : torch.nn
+                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+        """
+        super().__init__()
+        self.transformations_1 = global_transformations(input_channels, output_channels)
+        self.global_features_1 = double_convolution(
+                                                    input_channels = input_channels,
+                                                    mid_channels = mid_channels,
+                                                    output_channels = output_channels,
+                                                    kernel_size = kernel_size,
+                                                    bias = bias,
+                                                    normalization = normalization,
+                                                    activation = activation
+                                                   )
+        self.global_features_2 = double_convolution(
+                                                    input_channels = input_channels,
+                                                    mid_channels = mid_channels,
+                                                    output_channels = output_channels,
+                                                    kernel_size = kernel_size,
+                                                    bias = bias,
+                                                    normalization = normalization,
+                                                    activation = activation
+                                                   )
+        self.transformations_2 = global_transformations(input_channels, output_channels)
+
+
+    def forward(self, x1, x2):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x1             : torch.tensor
+                         First input data.
+        x2             : torch.tensor
+                         Second input data.
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Estimated output.
+        """
+        global_tensor_1 = self.transformations_1(x1, x2)
+        y1 = self.global_features_1(global_tensor_1)
+        y2 = self.global_features_2(y1)
+        global_tensor_2 = self.transformations_2(y1, y2)
+        return global_tensor_2
+
+
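A minimal sketch; keeping the input, mid and output channel counts equal keeps the internal shapes compatible end to end (values are illustrative):

+import torch
+from odak.learn.models.components import global_feature_module
+
+module = global_feature_module(
+                               input_channels = 32,
+                               mid_channels = 32,
+                               output_channels = 32,
+                               kernel_size = 3
+                              )
+x1 = torch.randn(1, 32, 64, 64)           # tensor to be modulated
+x2 = torch.randn(1, 32, 64, 64)           # tensor providing the global statistics
+y = module(x1, x2)                        # same shape as x1
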
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(input_channels, mid_channels, output_channels, kernel_size, bias=False, normalization=False, activation=torch.nn.ReLU()) + +

A global feature layer.

Parameters:

  • input_channels (int) – Number of input channels.
  • mid_channels (int) – Number of mid channels.
  • output_channels (int) – Number of output channels.
  • kernel_size (int) – Kernel size.
  • bias (bool, default: False) – Set to True to let convolutional layers have bias term.
  • normalization (bool, default: False) – If True, adds a Batch Normalization layer after the convolutional layer.
  • activation (torch.nn, default: torch.nn.ReLU()) – Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             input_channels,
+             mid_channels,
+             output_channels,
+             kernel_size,
+             bias = False,
+             normalization = False,
+             activation = torch.nn.ReLU()
+            ):
+    """
+    A global feature layer.
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    mid_channels  : int
+                      Number of mid channels.
+    output_channels : int
+                      Number of output channels.
+    kernel_size     : int
+                      Kernel size.
+    bias            : bool
+                      Set to True to let convolutional layers have bias term.
+    normalization   : bool
+                      If True, adds a Batch Normalization layer after the convolutional layer.
+    activation      : torch.nn
+                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+    """
+    super().__init__()
+    self.transformations_1 = global_transformations(input_channels, output_channels)
+    self.global_features_1 = double_convolution(
+                                                input_channels = input_channels,
+                                                mid_channels = mid_channels,
+                                                output_channels = output_channels,
+                                                kernel_size = kernel_size,
+                                                bias = bias,
+                                                normalization = normalization,
+                                                activation = activation
+                                               )
+    self.global_features_2 = double_convolution(
+                                                input_channels = input_channels,
+                                                mid_channels = mid_channels,
+                                                output_channels = output_channels,
+                                                kernel_size = kernel_size,
+                                                bias = bias,
+                                                normalization = normalization,
+                                                activation = activation
+                                               )
+    self.transformations_2 = global_transformations(input_channels, output_channels)
+
+
+
+ +
+ +
+ + +

+ forward(x1, x2) + +

Forward model.

Parameters:

  • x1 (torch.tensor) – First input data.
  • x2 (torch.tensor) – Second input data.

Returns:

  • result (torch.tensor) – Estimated output.
+ Source code in odak/learn/models/components.py +
def forward(self, x1, x2):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x1             : torch.tensor
+                     First input data.
+    x2             : torch.tensor
+                     Second input data.
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Estimated output.
+    """
+    global_tensor_1 = self.transformations_1(x1, x2)
+    y1 = self.global_features_1(global_tensor_1)
+    y2 = self.global_features_2(y1)
+    global_tensor_2 = self.transformations_2(y1, y2)
+    return global_tensor_2
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ global_transformations + + +

+ + +
+

+ Bases: Module

+ + +

A global feature layer that processes global features from input channels and +applies learned transformations to another input tensor.

+

This implementation is adapted from RSGUnet: +https://github.com/MTLab/rsgunet_image_enhance.

+

Reference: +J. Huang, P. Zhu, M. Geng et al. "Range Scaling Global U-Net for Perceptual Image Enhancement on Mobile Devices."

+ + + + + + +
+ Source code in odak/learn/models/components.py +
class global_transformations(torch.nn.Module):
+    """
+    A global feature layer that processes global features from input channels and
+    applies learned transformations to another input tensor.
+
+    This implementation is adapted from RSGUnet:
+    https://github.com/MTLab/rsgunet_image_enhance.
+
+    Reference:
+    J. Huang, P. Zhu, M. Geng et al. "Range Scaling Global U-Net for Perceptual Image Enhancement on Mobile Devices."
+    """
+    def __init__(
+                 self,
+                 input_channels,
+                 output_channels
+                ):
+        """
+        A global feature layer.
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        output_channels : int
+                          Number of output channels.
+        """
+        super().__init__()
+        self.global_feature_1 = torch.nn.Sequential(
+            torch.nn.Linear(input_channels, output_channels),
+            torch.nn.LeakyReLU(0.2, inplace = True),
+        )
+        self.global_feature_2 = torch.nn.Sequential(
+            torch.nn.Linear(output_channels, output_channels),
+            torch.nn.LeakyReLU(0.2, inplace = True)
+        )
+
+
+    def forward(self, x1, x2):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x1             : torch.tensor
+                         First input data.
+        x2             : torch.tensor
+                         Second input data.
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Estimated output.
+        """
+        y = torch.mean(x2, dim = (2, 3))
+        y1 = self.global_feature_1(y)
+        y2 = self.global_feature_2(y1)
+        y1 = y1.unsqueeze(2).unsqueeze(3)
+        y2 = y2.unsqueeze(2).unsqueeze(3)
+        result = x1 * y1 + y2
+        return result
+
+
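A minimal sketch of the scale-and-shift behaviour described above (illustrative shapes; equal channel counts keep x1 compatible with the learned globals):

+import torch
+from odak.learn.models.components import global_transformations
+
+transform = global_transformations(input_channels = 32, output_channels = 32)
+x1 = torch.randn(1, 32, 64, 64)           # gets scaled and shifted per channel
+x2 = torch.randn(1, 32, 64, 64)           # averaged over its spatial dimensions
+y = transform(x1, x2)                     # y = x1 * y1 + y2 with y1, y2 broadcast over H and W
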
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(input_channels, output_channels) + +

A global feature layer.

Parameters:

  • input_channels (int) – Number of input channels.
  • output_channels (int) – Number of output channels.
+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             input_channels,
+             output_channels
+            ):
+    """
+    A global feature layer.
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    output_channels : int
+                      Number of output channels.
+    """
+    super().__init__()
+    self.global_feature_1 = torch.nn.Sequential(
+        torch.nn.Linear(input_channels, output_channels),
+        torch.nn.LeakyReLU(0.2, inplace = True),
+    )
+    self.global_feature_2 = torch.nn.Sequential(
+        torch.nn.Linear(output_channels, output_channels),
+        torch.nn.LeakyReLU(0.2, inplace = True)
+    )
+
+
+
+ +
+ +
+ + +

+ forward(x1, x2) + +

Forward model.

Parameters:

  • x1 (torch.tensor) – First input data.
  • x2 (torch.tensor) – Second input data.

Returns:

  • result (torch.tensor) – Estimated output.
+ Source code in odak/learn/models/components.py +
def forward(self, x1, x2):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x1             : torch.tensor
+                     First input data.
+    x2             : torch.tensor
+                     Second input data.
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Estimated output.
+    """
+    y = torch.mean(x2, dim = (2, 3))
+    y1 = self.global_feature_1(y)
+    y2 = self.global_feature_2(y1)
+    y1 = y1.unsqueeze(2).unsqueeze(3)
+    y2 = y2.unsqueeze(2).unsqueeze(3)
+    result = x1 * y1 + y2
+    return result
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ multi_layer_perceptron + + +

+ + +
+

+ Bases: Module

+ + +

A multi-layer perceptron model.

+ + + + + + +
+ Source code in odak/learn/models/models.py +
class multi_layer_perceptron(torch.nn.Module):
+    """
+    A multi-layer perceptron model.
+    """
+
+    def __init__(self,
+                 dimensions,
+                 activation = torch.nn.ReLU(),
+                 bias = False,
+                 model_type = 'conventional',
+                 siren_multiplier = 1.,
+                 input_multiplier = None
+                ):
+        """
+        Parameters
+        ----------
+        dimensions        : list
+                            List of integers representing the dimensions of each layer (e.g., [2, 10, 1], where the first layer has two channels and last one has one channel.).
+        activation        : torch.nn
+                            Nonlinear activation function.
+                            Default is `torch.nn.ReLU()`.
+        bias              : bool
+                            If set to True, linear layers will include biases.
+        siren_multiplier  : float
+                            When using `SIREN` model type, this parameter functions as a hyperparameter.
+                            The original SIREN work uses 30.
+                            You can bypass this parameter by providing inputs that are not normalized and larger than one.
+        input_multiplier  : float
+                            Initial value of the input multiplier before the very first layer.
+        model_type        : str
+                            Model type: `conventional`, `swish`, `SIREN`, `FILM SIREN`, `Gaussian`.
+                            `conventional` refers to a standard multi layer perceptron.
+                            For `SIREN,` see: Sitzmann, Vincent, et al. "Implicit neural representations with periodic activation functions." Advances in neural information processing systems 33 (2020): 7462-7473.
+                            For `Swish,` see: Ramachandran, Prajit, Barret Zoph, and Quoc V. Le. "Searching for activation functions." arXiv preprint arXiv:1710.05941 (2017). 
+                            For `FILM SIREN,` see: Chan, Eric R., et al. "pi-gan: Periodic implicit generative adversarial networks for 3d-aware image synthesis." Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2021.
+                            For `Gaussian,` see: Ramasinghe, Sameera, and Simon Lucey. "Beyond periodicity: Towards a unifying framework for activations in coordinate-mlps." In European Conference on Computer Vision, pp. 142-158. Cham: Springer Nature Switzerland, 2022.
+        """
+        super(multi_layer_perceptron, self).__init__()
+        self.activation = activation
+        self.bias = bias
+        self.model_type = model_type
+        self.layers = torch.nn.ModuleList()
+        self.siren_multiplier = siren_multiplier
+        self.dimensions = dimensions
+        for i in range(len(self.dimensions) - 1):
+            self.layers.append(torch.nn.Linear(self.dimensions[i], self.dimensions[i + 1], bias = self.bias))
+        if not isinstance(input_multiplier, type(None)):
+            self.input_multiplier = torch.nn.ParameterList()
+            self.input_multiplier.append(torch.nn.Parameter(torch.ones(1, self.dimensions[0]) * input_multiplier))
+        if self.model_type == 'FILM SIREN':
+            self.alpha = torch.nn.ParameterList()
+            for j in self.dimensions[1:-1]:
+                self.alpha.append(torch.nn.Parameter(torch.randn(2, 1, j)))
+        if self.model_type == 'Gaussian':
+            self.alpha = torch.nn.ParameterList()
+            for j in self.dimensions[1:-1]:
+                self.alpha.append(torch.nn.Parameter(torch.randn(1, 1, j)))
+
+
+    def forward(self, x):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x             : torch.tensor
+                        Input data.
+
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Estimated output.      
+        """
+        if hasattr(self, 'input_multiplier'):
+            result = x * self.input_multiplier[0]
+        else:
+            result = x
+        for layer_id, layer in enumerate(self.layers[:-1]):
+            result = layer(result)
+            if self.model_type == 'conventional':
+                result = self.activation(result)
+            elif self.model_type == 'swish':
+                result = swish(result)
+            elif self.model_type == 'SIREN':
+                result = torch.sin(result * self.siren_multiplier)
+            elif self.model_type == 'FILM SIREN':
+                result = torch.sin(self.alpha[layer_id][0] * result + self.alpha[layer_id][1])
+            elif self.model_type == 'Gaussian': 
+                result = gaussian(result, self.alpha[layer_id][0])
+        result = self.layers[-1](result)
+        return result
+
+
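A minimal sketch of fitting 2D coordinates with the `SIREN` model type (hyperparameters are illustrative; the import path follows the source location odak/learn/models/models.py shown above):

+import torch
+from odak.learn.models.models import multi_layer_perceptron
+
+model = multi_layer_perceptron(
+                               dimensions = [2, 64, 64, 1],
+                               model_type = 'SIREN',
+                               siren_multiplier = 30.
+                              )
+coordinates = torch.rand(1000, 2)         # normalized (x, y) samples
+values = model(coordinates)               # -> torch.Size([1000, 1])
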
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(dimensions, activation=torch.nn.ReLU(), bias=False, model_type='conventional', siren_multiplier=1.0, input_multiplier=None) + +

Parameters:

  • dimensions (list) – List of integers representing the dimensions of each layer (e.g., [2, 10, 1], where the first layer has two channels and the last one has one channel).
  • activation (torch.nn, default: torch.nn.ReLU()) – Nonlinear activation function. Default is `torch.nn.ReLU()`.
  • bias (bool, default: False) – If set to True, linear layers will include biases.
  • siren_multiplier (float, default: 1.0) – When using the `SIREN` model type, this parameter functions as a hyperparameter. The original SIREN work uses 30. You can bypass this parameter by providing inputs that are not normalized and larger than one.
  • input_multiplier (float, default: None) – Initial value of the input multiplier before the very first layer.
  • model_type (str, default: 'conventional') – Model type: `conventional`, `swish`, `SIREN`, `FILM SIREN`, `Gaussian`. `conventional` refers to a standard multi layer perceptron. For `SIREN`, see: Sitzmann, Vincent, et al. "Implicit neural representations with periodic activation functions." Advances in Neural Information Processing Systems 33 (2020): 7462-7473. For `Swish`, see: Ramachandran, Prajit, Barret Zoph, and Quoc V. Le. "Searching for activation functions." arXiv preprint arXiv:1710.05941 (2017). For `FILM SIREN`, see: Chan, Eric R., et al. "pi-GAN: Periodic implicit generative adversarial networks for 3D-aware image synthesis." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2021. For `Gaussian`, see: Ramasinghe, Sameera, and Simon Lucey. "Beyond periodicity: Towards a unifying framework for activations in coordinate-MLPs." European Conference on Computer Vision, pp. 142-158. Cham: Springer Nature Switzerland, 2022.
+ Source code in odak/learn/models/models.py +
def __init__(self,
+             dimensions,
+             activation = torch.nn.ReLU(),
+             bias = False,
+             model_type = 'conventional',
+             siren_multiplier = 1.,
+             input_multiplier = None
+            ):
+    """
+    Parameters
+    ----------
+    dimensions        : list
+                        List of integers representing the dimensions of each layer (e.g., [2, 10, 1], where the first layer has two channels and last one has one channel.).
+    activation        : torch.nn
+                        Nonlinear activation function.
+                        Default is `torch.nn.ReLU()`.
+    bias              : bool
+                        If set to True, linear layers will include biases.
+    siren_multiplier  : float
+                        When using `SIREN` model type, this parameter functions as a hyperparameter.
+                        The original SIREN work uses 30.
+                        You can bypass this parameter by providing inputs that are not normalized and larger than one.
+    input_multiplier  : float
+                        Initial value of the input multiplier before the very first layer.
+    model_type        : str
+                        Model type: `conventional`, `swish`, `SIREN`, `FILM SIREN`, `Gaussian`.
+                        `conventional` refers to a standard multi layer perceptron.
+                        For `SIREN,` see: Sitzmann, Vincent, et al. "Implicit neural representations with periodic activation functions." Advances in neural information processing systems 33 (2020): 7462-7473.
+                        For `Swish,` see: Ramachandran, Prajit, Barret Zoph, and Quoc V. Le. "Searching for activation functions." arXiv preprint arXiv:1710.05941 (2017). 
+                        For `FILM SIREN,` see: Chan, Eric R., et al. "pi-gan: Periodic implicit generative adversarial networks for 3d-aware image synthesis." Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2021.
+                        For `Gaussian,` see: Ramasinghe, Sameera, and Simon Lucey. "Beyond periodicity: Towards a unifying framework for activations in coordinate-mlps." In European Conference on Computer Vision, pp. 142-158. Cham: Springer Nature Switzerland, 2022.
+    """
+    super(multi_layer_perceptron, self).__init__()
+    self.activation = activation
+    self.bias = bias
+    self.model_type = model_type
+    self.layers = torch.nn.ModuleList()
+    self.siren_multiplier = siren_multiplier
+    self.dimensions = dimensions
+    for i in range(len(self.dimensions) - 1):
+        self.layers.append(torch.nn.Linear(self.dimensions[i], self.dimensions[i + 1], bias = self.bias))
+    if not isinstance(input_multiplier, type(None)):
+        self.input_multiplier = torch.nn.ParameterList()
+        self.input_multiplier.append(torch.nn.Parameter(torch.ones(1, self.dimensions[0]) * input_multiplier))
+    if self.model_type == 'FILM SIREN':
+        self.alpha = torch.nn.ParameterList()
+        for j in self.dimensions[1:-1]:
+            self.alpha.append(torch.nn.Parameter(torch.randn(2, 1, j)))
+    if self.model_type == 'Gaussian':
+        self.alpha = torch.nn.ParameterList()
+        for j in self.dimensions[1:-1]:
+            self.alpha.append(torch.nn.Parameter(torch.randn(1, 1, j)))
+
+
+
+ +
+ +
+ + +

+ forward(x) + +

Forward model.

Parameters:

  • x (torch.tensor) – Input data.

Returns:

  • result (torch.tensor) – Estimated output.
+ Source code in odak/learn/models/models.py +
def forward(self, x):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x             : torch.tensor
+                    Input data.
+
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Estimated output.      
+    """
+    if hasattr(self, 'input_multiplier'):
+        result = x * self.input_multiplier[0]
+    else:
+        result = x
+    for layer_id, layer in enumerate(self.layers[:-1]):
+        result = layer(result)
+        if self.model_type == 'conventional':
+            result = self.activation(result)
+        elif self.model_type == 'swish':
+            result = swish(result)
+        elif self.model_type == 'SIREN':
+            result = torch.sin(result * self.siren_multiplier)
+        elif self.model_type == 'FILM SIREN':
+            result = torch.sin(self.alpha[layer_id][0] * result + self.alpha[layer_id][1])
+        elif self.model_type == 'Gaussian': 
+            result = gaussian(result, self.alpha[layer_id][0])
+    result = self.layers[-1](result)
+    return result
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ non_local_layer + + +

+ + +
+

+ Bases: Module

+ + +

Self-Attention Layer [zi = Wzyi + xi] (non-local block : ref https://arxiv.org/abs/1711.07971)

+ + + + + + +
+ Source code in odak/learn/models/components.py +
class non_local_layer(torch.nn.Module):
+    """
+    Self-Attention Layer [zi = Wzyi + xi] (non-local block : ref https://arxiv.org/abs/1711.07971)
+    """
+    def __init__(
+                 self,
+                 input_channels = 1024,
+                 bottleneck_channels = 512,
+                 kernel_size = 1,
+                 bias = False,
+                ):
+        """
+
+        Parameters
+        ----------
+        input_channels      : int
+                              Number of input channels.
+        bottleneck_channels : int
+                              Number of middle channels.
+        kernel_size         : int
+                              Kernel size.
+        bias                : bool 
+                              Set to True to let convolutional layers have bias term.
+        """
+        super(non_local_layer, self).__init__()
+        self.input_channels = input_channels
+        self.bottleneck_channels = bottleneck_channels
+        self.g = torch.nn.Conv2d(
+                                 self.input_channels, 
+                                 self.bottleneck_channels,
+                                 kernel_size = kernel_size,
+                                 padding = kernel_size // 2,
+                                 bias = bias
+                                )
+        self.W_z = torch.nn.Sequential(
+                                       torch.nn.Conv2d(
+                                                       self.bottleneck_channels,
+                                                       self.input_channels, 
+                                                       kernel_size = kernel_size,
+                                                       bias = bias,
+                                                       padding = kernel_size // 2
+                                                      ),
+                                       torch.nn.BatchNorm2d(self.input_channels)
+                                      )
+        torch.nn.init.constant_(self.W_z[1].weight, 0)   
+        torch.nn.init.constant_(self.W_z[1].bias, 0)
+
+
+    def forward(self, x):
+        """
+        Forward model [zi = Wzyi + xi]
+
+        Parameters
+        ----------
+        x               : torch.tensor
+                          First input data.                       
+
+
+        Returns
+        ----------
+        z               : torch.tensor
+                          Estimated output.
+        """
+        batch_size, channels, height, width = x.size()
+        theta = x.view(batch_size, channels, -1).permute(0, 2, 1)
+        phi = x.view(batch_size, channels, -1).permute(0, 2, 1)
+        g = self.g(x).view(batch_size, self.bottleneck_channels, -1).permute(0, 2, 1)
+        attn = torch.bmm(theta, phi.transpose(1, 2)) / (height * width)
+        attn = torch.nn.functional.softmax(attn, dim=-1)
+        y = torch.bmm(attn, g).permute(0, 2, 1).contiguous().view(batch_size, self.bottleneck_channels, height, width)
+        W_y = self.W_z(y)
+        z = W_y + x
+        return z
+
+
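A minimal sketch; note the attention matrix is (H·W) × (H·W), so a small feature map is assumed here (illustrative values):

+import torch
+from odak.learn.models.components import non_local_layer
+
+block = non_local_layer(input_channels = 1024, bottleneck_channels = 512)
+x = torch.randn(1, 1024, 16, 16)
+y = block(x)                              # residual output z = W_z(y) + x, same shape as x
+assert y.shape == x.shape
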
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(input_channels=1024, bottleneck_channels=512, kernel_size=1, bias=False) + +

Parameters:

  • input_channels (int, default: 1024) – Number of input channels.
  • bottleneck_channels (int, default: 512) – Number of middle channels.
  • kernel_size (int, default: 1) – Kernel size.
  • bias (bool, default: False) – Set to True to let convolutional layers have bias term.
+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             input_channels = 1024,
+             bottleneck_channels = 512,
+             kernel_size = 1,
+             bias = False,
+            ):
+    """
+
+    Parameters
+    ----------
+    input_channels      : int
+                          Number of input channels.
+    bottleneck_channels : int
+                          Number of middle channels.
+    kernel_size         : int
+                          Kernel size.
+    bias                : bool 
+                          Set to True to let convolutional layers have bias term.
+    """
+    super(non_local_layer, self).__init__()
+    self.input_channels = input_channels
+    self.bottleneck_channels = bottleneck_channels
+    self.g = torch.nn.Conv2d(
+                             self.input_channels, 
+                             self.bottleneck_channels,
+                             kernel_size = kernel_size,
+                             padding = kernel_size // 2,
+                             bias = bias
+                            )
+    self.W_z = torch.nn.Sequential(
+                                   torch.nn.Conv2d(
+                                                   self.bottleneck_channels,
+                                                   self.input_channels, 
+                                                   kernel_size = kernel_size,
+                                                   bias = bias,
+                                                   padding = kernel_size // 2
+                                                  ),
+                                   torch.nn.BatchNorm2d(self.input_channels)
+                                  )
+    torch.nn.init.constant_(self.W_z[1].weight, 0)   
+    torch.nn.init.constant_(self.W_z[1].bias, 0)
+
+
+
+ +
+ +
+ + +

+ forward(x) + +

Forward model [z_i = W_z y_i + x_i].

Parameters:

  • x (torch.tensor) – First input data.

Returns:

  • z (torch.tensor) – Estimated output.
+ Source code in odak/learn/models/components.py +
def forward(self, x):
+    """
+    Forward model [zi = Wzyi + xi]
+
+    Parameters
+    ----------
+    x               : torch.tensor
+                      First input data.                       
+
+
+    Returns
+    ----------
+    z               : torch.tensor
+                      Estimated output.
+    """
+    batch_size, channels, height, width = x.size()
+    theta = x.view(batch_size, channels, -1).permute(0, 2, 1)
+    phi = x.view(batch_size, channels, -1).permute(0, 2, 1)
+    g = self.g(x).view(batch_size, self.bottleneck_channels, -1).permute(0, 2, 1)
+    attn = torch.bmm(theta, phi.transpose(1, 2)) / (height * width)
+    attn = torch.nn.functional.softmax(attn, dim=-1)
+    y = torch.bmm(attn, g).permute(0, 2, 1).contiguous().view(batch_size, self.bottleneck_channels, height, width)
+    W_y = self.W_z(y)
+    z = W_y + x
+    return z
+
+
+
normalization

Bases: Module

A normalization layer.
+ Source code in odak/learn/models/components.py +
class normalization(torch.nn.Module):
+    """
+    A normalization layer.
+    """
+    def __init__(
+                 self,
+                 dim = 1,
+                ):
+        """
+        Normalization layer.
+
+
+        Parameters
+        ----------
+        dim             : int
+                          Dimension (axis) to normalize.
+        """
+        super().__init__()
+        self.k = torch.nn.Parameter(torch.ones(1, dim, 1, 1))
+
+
+    def forward(self, x):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x             : torch.tensor
+                        Input data.
+
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Estimated output.      
+        """
+        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
+        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
+        mean = torch.mean(x, dim = 1, keepdim = True)
+        result =  (x - mean) * (var + eps).rsqrt() * self.k
+        return result 
+
+
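Example use of normalization, as a minimal sketch (import path and tensor shape are assumptions); dim must match the channel count of the input.

import torch
from odak.learn.models.components import normalization # assumed import path

layer = normalization(dim = 8) # one learnable scale per channel
x = torch.randn(4, 8, 32, 32)
y = layer(x) # normalized along the channel axis, same shape as x
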
__init__(dim=1)

Normalization layer.

Parameters:

  • dim (int, default: 1) – Dimension (axis) to normalize.
+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             dim = 1,
+            ):
+    """
+    Normalization layer.
+
+
+    Parameters
+    ----------
+    dim             : int
+                      Dimension (axis) to normalize.
+    """
+    super().__init__()
+    self.k = torch.nn.Parameter(torch.ones(1, dim, 1, 1))
+
+
+
forward(x)

Forward model.

Parameters:

  • x (torch.tensor) – Input data.

Returns:

  • result (torch.tensor) – Estimated output.
+ Source code in odak/learn/models/components.py +
def forward(self, x):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x             : torch.tensor
+                    Input data.
+
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Estimated output.      
+    """
+    eps = 1e-5 if x.dtype == torch.float32 else 1e-3
+    var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
+    mean = torch.mean(x, dim = 1, keepdim = True)
+    result =  (x - mean) * (var + eps).rsqrt() * self.k
+    return result 
+
+
+
positional_encoder

Bases: Module

A positional encoder module.
+ Source code in odak/learn/models/components.py +
class positional_encoder(torch.nn.Module):
+    """
+    A positional encoder module.
+    """
+
+    def __init__(self, L):
+        """
+        A positional encoder module.
+
+        Parameters
+        ----------
+        L                   : int
+                              Positional encoding level.
+        """
+        super(positional_encoder, self).__init__()
+        self.L = L
+
+
+    def forward(self, x):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x               : torch.tensor
+                          Input data.
+
+        Returns
+        ----------
+        result          : torch.tensor
+                          Result of the forward operation
+        """
+        B, C = x.shape
+        x = x.view(B, C, 1)
+        results = [x]
+        for i in range(1, self.L + 1):
+            freq = (2 ** i) * math.pi
+            cos_x = torch.cos(freq * x)
+            sin_x = torch.sin(freq * x)
+            results.append(cos_x)
+            results.append(sin_x)
+        results = torch.cat(results, dim=2)
+        results = results.permute(0, 2, 1)
+        results = results.reshape(B, -1)
+        return results
+
+
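Example use of positional_encoder, as a minimal sketch (import path and input values are assumptions); each input coordinate is kept and augmented with L pairs of cosine and sine features.

import torch
from odak.learn.models.components import positional_encoder # assumed import path

encoder = positional_encoder(L = 6)
x = torch.rand(10, 3) # ten hypothetical 3D coordinates
y = encoder(x) # shape: (10, 3 * (2 * 6 + 1)) = (10, 39)
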
__init__(L)

A positional encoder module.

Parameters:

  • L (int) – Positional encoding level.
+ Source code in odak/learn/models/components.py +
def __init__(self, L):
+    """
+    A positional encoder module.
+
+    Parameters
+    ----------
+    L                   : int
+                          Positional encoding level.
+    """
+    super(positional_encoder, self).__init__()
+    self.L = L
+
+
+
forward(x)

Forward model.

Parameters:

  • x (torch.tensor) – Input data.

Returns:

  • result (torch.tensor) – Result of the forward operation.
+ Source code in odak/learn/models/components.py +
def forward(self, x):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x               : torch.tensor
+                      Input data.
+
+    Returns
+    ----------
+    result          : torch.tensor
+                      Result of the forward operation
+    """
+    B, C = x.shape
+    x = x.view(B, C, 1)
+    results = [x]
+    for i in range(1, self.L + 1):
+        freq = (2 ** i) * math.pi
+        cos_x = torch.cos(freq * x)
+        sin_x = torch.sin(freq * x)
+        results.append(cos_x)
+        results.append(sin_x)
+    results = torch.cat(results, dim=2)
+    results = results.permute(0, 2, 1)
+    results = results.reshape(B, -1)
+    return results
+
+
+
residual_attention_layer

Bases: Module

A residual block with an attention layer.
+ Source code in odak/learn/models/components.py +
class residual_attention_layer(torch.nn.Module):
+    """
+    A residual block with an attention layer.
+    """
+    def __init__(
+                 self,
+                 input_channels = 2,
+                 output_channels = 2,
+                 kernel_size = 1,
+                 bias = False,
+                 activation = torch.nn.ReLU()
+                ):
+        """
+        An attention layer class.
+
+
+        Parameters
+        ----------
+        input_channels  : int or optional
+                          Number of input channels.
+        output_channels : int or optional
+                          Number of middle channels.
+        kernel_size     : int or optional
+                          Kernel size.
+        bias            : bool or optional
+                          Set to True to let convolutional layers have bias term.
+        activation      : torch.nn or optional
+                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+        """
+        super().__init__()
+        self.activation = activation
+        self.convolution0 = torch.nn.Sequential(
+                                                torch.nn.Conv2d(
+                                                                input_channels,
+                                                                output_channels,
+                                                                kernel_size = kernel_size,
+                                                                padding = kernel_size // 2,
+                                                                bias = bias
+                                                               ),
+                                                torch.nn.BatchNorm2d(output_channels)
+                                               )
+        self.convolution1 = torch.nn.Sequential(
+                                                torch.nn.Conv2d(
+                                                                input_channels,
+                                                                output_channels,
+                                                                kernel_size = kernel_size,
+                                                                padding = kernel_size // 2,
+                                                                bias = bias
+                                                               ),
+                                                torch.nn.BatchNorm2d(output_channels)
+                                               )
+        self.final_layer = torch.nn.Sequential(
+                                               self.activation,
+                                               torch.nn.Conv2d(
+                                                               output_channels,
+                                                               output_channels,
+                                                               kernel_size = kernel_size,
+                                                               padding = kernel_size // 2,
+                                                               bias = bias
+                                                              )
+                                              )
+
+
+    def forward(self, x0, x1):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x0             : torch.tensor
+                         First input data.
+
+        x1             : torch.tensor
+                         Second input data.
+
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Estimated output.      
+        """
+        y0 = self.convolution0(x0)
+        y1 = self.convolution1(x1)
+        y2 = torch.add(y0, y1)
+        result = self.final_layer(y2) * x0
+        return result
+
+
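Example use of residual_attention_layer, as a minimal sketch (import path and shapes are assumptions); with the defaults, input and output channel counts match so the attention mask can scale x0 elementwise.

import torch
from odak.learn.models.components import residual_attention_layer # assumed import path

layer = residual_attention_layer(input_channels = 2, output_channels = 2, kernel_size = 1)
x0 = torch.randn(1, 2, 64, 64) # hypothetical feature maps
x1 = torch.randn(1, 2, 64, 64)
y = layer(x0, x1) # attention derived from x0 and x1, applied to x0
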
__init__(input_channels=2, output_channels=2, kernel_size=1, bias=False, activation=torch.nn.ReLU())

An attention layer class.

Parameters:

  • input_channels (int, default: 2) – Number of input channels.
  • output_channels (int, default: 2) – Number of middle channels.
  • kernel_size (int, default: 1) – Kernel size.
  • bias (bool, default: False) – Set to True to let convolutional layers have bias term.
  • activation (torch.nn, default: torch.nn.ReLU()) – Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             input_channels = 2,
+             output_channels = 2,
+             kernel_size = 1,
+             bias = False,
+             activation = torch.nn.ReLU()
+            ):
+    """
+    An attention layer class.
+
+
+    Parameters
+    ----------
+    input_channels  : int or optional
+                      Number of input channels.
+    output_channels : int or optional
+                      Number of middle channels.
+    kernel_size     : int or optional
+                      Kernel size.
+    bias            : bool or optional
+                      Set to True to let convolutional layers have bias term.
+    activation      : torch.nn or optional
+                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+    """
+    super().__init__()
+    self.activation = activation
+    self.convolution0 = torch.nn.Sequential(
+                                            torch.nn.Conv2d(
+                                                            input_channels,
+                                                            output_channels,
+                                                            kernel_size = kernel_size,
+                                                            padding = kernel_size // 2,
+                                                            bias = bias
+                                                           ),
+                                            torch.nn.BatchNorm2d(output_channels)
+                                           )
+    self.convolution1 = torch.nn.Sequential(
+                                            torch.nn.Conv2d(
+                                                            input_channels,
+                                                            output_channels,
+                                                            kernel_size = kernel_size,
+                                                            padding = kernel_size // 2,
+                                                            bias = bias
+                                                           ),
+                                            torch.nn.BatchNorm2d(output_channels)
+                                           )
+    self.final_layer = torch.nn.Sequential(
+                                           self.activation,
+                                           torch.nn.Conv2d(
+                                                           output_channels,
+                                                           output_channels,
+                                                           kernel_size = kernel_size,
+                                                           padding = kernel_size // 2,
+                                                           bias = bias
+                                                          )
+                                          )
+
+
+
forward(x0, x1)

Forward model.

Parameters:

  • x0 (torch.tensor) – First input data.
  • x1 (torch.tensor) – Second input data.

Returns:

  • result (torch.tensor) – Estimated output.
+ Source code in odak/learn/models/components.py +
def forward(self, x0, x1):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x0             : torch.tensor
+                     First input data.
+
+    x1             : torch.tensor
+                     Second input data.
+
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Estimated output.      
+    """
+    y0 = self.convolution0(x0)
+    y1 = self.convolution1(x1)
+    y2 = torch.add(y0, y1)
+    result = self.final_layer(y2) * x0
+    return result
+
+
+
residual_layer

Bases: Module

A residual layer.
+ Source code in odak/learn/models/components.py +
class residual_layer(torch.nn.Module):
+    """
+    A residual layer.
+    """
+    def __init__(
+                 self,
+                 input_channels = 2,
+                 mid_channels = 16,
+                 kernel_size = 3,
+                 bias = False,
+                 activation = torch.nn.ReLU()
+                ):
+        """
+        A convolutional layer class.
+
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        mid_channels    : int
+                          Number of middle channels.
+        kernel_size     : int
+                          Kernel size.
+        bias            : bool 
+                          Set to True to let convolutional layers have bias term.
+        activation      : torch.nn
+                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+        """
+        super().__init__()
+        self.activation = activation
+        self.convolution = double_convolution(
+                                              input_channels,
+                                              mid_channels = mid_channels,
+                                              output_channels = input_channels,
+                                              kernel_size = kernel_size,
+                                              bias = bias,
+                                              activation = activation
+                                             )
+
+
+    def forward(self, x):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x             : torch.tensor
+                        Input data.
+
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Estimated output.      
+        """
+        x0 = self.convolution(x)
+        return x + x0
+
+
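Example use of residual_layer, as a minimal sketch (import path and shapes are assumptions).

import torch
from odak.learn.models.components import residual_layer # assumed import path

layer = residual_layer(input_channels = 2, mid_channels = 16, kernel_size = 3)
x = torch.randn(1, 2, 64, 64)
y = layer(x) # double convolution output added back onto x, same shape as x
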
__init__(input_channels=2, mid_channels=16, kernel_size=3, bias=False, activation=torch.nn.ReLU())

A convolutional layer class.

Parameters:

  • input_channels (int, default: 2) – Number of input channels.
  • mid_channels (int, default: 16) – Number of middle channels.
  • kernel_size (int, default: 3) – Kernel size.
  • bias (bool, default: False) – Set to True to let convolutional layers have bias term.
  • activation (torch.nn, default: torch.nn.ReLU()) – Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             input_channels = 2,
+             mid_channels = 16,
+             kernel_size = 3,
+             bias = False,
+             activation = torch.nn.ReLU()
+            ):
+    """
+    A convolutional layer class.
+
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    mid_channels    : int
+                      Number of middle channels.
+    kernel_size     : int
+                      Kernel size.
+    bias            : bool 
+                      Set to True to let convolutional layers have bias term.
+    activation      : torch.nn
+                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+    """
+    super().__init__()
+    self.activation = activation
+    self.convolution = double_convolution(
+                                          input_channels,
+                                          mid_channels = mid_channels,
+                                          output_channels = input_channels,
+                                          kernel_size = kernel_size,
+                                          bias = bias,
+                                          activation = activation
+                                         )
+
+
+
forward(x)

Forward model.

Parameters:

  • x (torch.tensor) – Input data.

Returns:

  • result (torch.tensor) – Estimated output.
+ Source code in odak/learn/models/components.py +
def forward(self, x):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x             : torch.tensor
+                    Input data.
+
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Estimated output.      
+    """
+    x0 = self.convolution(x)
+    return x + x0
+
+
+
spatial_gate

Bases: Module

Spatial attention module that applies a convolution layer after channel pooling.
This class is heavily inspired by https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py.
+ Source code in odak/learn/models/components.py +
class spatial_gate(torch.nn.Module):
+    """
+    Spatial attention module that applies a convolution layer after channel pooling.
+    This class is heavily inspired by https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py.
+    """
+    def __init__(self):
+        """
+        Initializes the spatial gate module.
+        """
+        super().__init__()
+        kernel_size = 7
+        self.spatial = convolution_layer(2, 1, kernel_size, bias = False, activation = torch.nn.Identity())
+
+
+    def channel_pool(self, x):
+        """
+        Applies max and average pooling on the channels.
+
+        Parameters
+        ----------
+        x             : torch.tensor
+                        Input tensor.
+
+        Returns
+        -------
+        output        : torch.tensor
+                        Output tensor.
+        """
+        max_pool = torch.max(x, 1)[0].unsqueeze(1)
+        avg_pool = torch.mean(x, 1).unsqueeze(1)
+        output = torch.cat((max_pool, avg_pool), dim=1)
+        return output
+
+
+    def forward(self, x):
+        """
+        Forward pass of the SpatialGate module.
+
+        Applies spatial attention to the input tensor.
+
+        Parameters
+        ----------
+        x            : torch.tensor
+                       Input tensor to the SpatialGate module.
+
+        Returns
+        -------
+        scaled_x     : torch.tensor
+                       Output tensor after applying spatial attention.
+        """
+        x_compress = self.channel_pool(x)
+        x_out = self.spatial(x_compress)
+        scale = torch.sigmoid(x_out)
+        scaled_x = x * scale
+        return scaled_x
+
+
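Example use of spatial_gate, as a minimal sketch (import path and shapes are assumptions).

import torch
from odak.learn.models.components import spatial_gate # assumed import path

gate = spatial_gate()
x = torch.randn(1, 16, 32, 32) # hypothetical feature map
y = gate(x) # x scaled by a per-pixel sigmoid attention map
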
__init__()

Initializes the spatial gate module.
+ Source code in odak/learn/models/components.py +
def __init__(self):
+    """
+    Initializes the spatial gate module.
+    """
+    super().__init__()
+    kernel_size = 7
+    self.spatial = convolution_layer(2, 1, kernel_size, bias = False, activation = torch.nn.Identity())
+
+
+
channel_pool(x)

Applies max and average pooling on the channels.

Parameters:

  • x (torch.tensor) – Input tensor.

Returns:

  • output (torch.tensor) – Output tensor.
+ Source code in odak/learn/models/components.py +
def channel_pool(self, x):
+    """
+    Applies max and average pooling on the channels.
+
+    Parameters
+    ----------
+    x             : torch.tensor
+                    Input tensor.
+
+    Returns
+    -------
+    output        : torch.tensor
+                    Output tensor.
+    """
+    max_pool = torch.max(x, 1)[0].unsqueeze(1)
+    avg_pool = torch.mean(x, 1).unsqueeze(1)
+    output = torch.cat((max_pool, avg_pool), dim=1)
+    return output
+
+
+
forward(x)

Forward pass of the SpatialGate module.

Applies spatial attention to the input tensor.

Parameters:

  • x (torch.tensor) – Input tensor to the SpatialGate module.

Returns:

  • scaled_x (torch.tensor) – Output tensor after applying spatial attention.
+ Source code in odak/learn/models/components.py +
def forward(self, x):
+    """
+    Forward pass of the SpatialGate module.
+
+    Applies spatial attention to the input tensor.
+
+    Parameters
+    ----------
+    x            : torch.tensor
+                   Input tensor to the SpatialGate module.
+
+    Returns
+    -------
+    scaled_x     : torch.tensor
+                   Output tensor after applying spatial attention.
+    """
+    x_compress = self.channel_pool(x)
+    x_out = self.spatial(x_compress)
+    scale = torch.sigmoid(x_out)
+    scaled_x = x * scale
+    return scaled_x
+
+
+
spatially_adaptive_convolution

Bases: Module

A spatially adaptive convolution layer.

References

C. Zheng et al. "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions."
C. Xu et al. "Squeezesegv3: Spatially-adaptive Convolution for Efficient Point-Cloud Segmentation."
C. Zheng et al. "Windowing Decomposition Convolutional Neural Network for Image Enhancement."
+ Source code in odak/learn/models/components.py +
class spatially_adaptive_convolution(torch.nn.Module):
+    """
+    A spatially adaptive convolution layer.
+
+    References
+    ----------
+
+    C. Zheng et al. "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions."
+    C. Xu et al. "Squeezesegv3: Spatially-adaptive Convolution for Efficient Point-Cloud Segmentation."
+    C. Zheng et al. "Windowing Decomposition Convolutional Neural Network for Image Enhancement."
+    """
+    def __init__(
+                 self,
+                 input_channels = 2,
+                 output_channels = 2,
+                 kernel_size = 3,
+                 stride = 1,
+                 padding = 1,
+                 bias = False,
+                 activation = torch.nn.LeakyReLU(0.2, inplace = True)
+                ):
+        """
+        Initializes a spatially adaptive convolution layer.
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        output_channels : int
+                          Number of output channels.
+        kernel_size     : int
+                          Size of the convolution kernel.
+        stride          : int
+                          Stride of the convolution.
+        padding         : int
+                          Padding added to both sides of the input.
+        bias            : bool
+                          If True, includes a bias term in the convolution.
+        activation      : torch.nn.Module
+                          Activation function to apply. If None, no activation is applied.
+        """
+        super(spatially_adaptive_convolution, self).__init__()
+        self.kernel_size = kernel_size
+        self.input_channels = input_channels
+        self.output_channels = output_channels
+        self.stride = stride
+        self.padding = padding
+        self.standard_convolution = torch.nn.Conv2d(
+                                                    in_channels = input_channels,
+                                                    out_channels = self.output_channels,
+                                                    kernel_size = kernel_size,
+                                                    stride = stride,
+                                                    padding = padding,
+                                                    bias = bias
+                                                   )
+        self.weight = torch.nn.Parameter(data = self.standard_convolution.weight, requires_grad = True)
+        self.activation = activation
+
+
+    def forward(self, x, sv_kernel_feature):
+        """
+        Forward pass for the spatially adaptive convolution layer.
+
+        Parameters
+        ----------
+        x                  : torch.tensor
+                            Input data tensor.
+                            Dimension: (1, C, H, W)
+        sv_kernel_feature   : torch.tensor
+                            Spatially varying kernel features.
+                            Dimension: (1, C_i * kernel_size * kernel_size, H, W)
+
+        Returns
+        -------
+        sa_output          : torch.tensor
+                            Estimated output tensor.
+                            Dimension: (1, output_channels, H_out, W_out)
+        """
+        # Pad input and sv_kernel_feature if necessary
+        if sv_kernel_feature.size(-1) * self.stride != x.size(-1) or sv_kernel_feature.size(
+                -2) * self.stride != x.size(-2):
+            diffY = sv_kernel_feature.size(-2) % self.stride
+            diffX = sv_kernel_feature.size(-1) % self.stride
+            sv_kernel_feature = torch.nn.functional.pad(sv_kernel_feature, (diffX // 2, diffX - diffX // 2,
+                                                                            diffY // 2, diffY - diffY // 2))
+            diffY = x.size(-2) % self.stride
+            diffX = x.size(-1) % self.stride
+            x = torch.nn.functional.pad(x, (diffX // 2, diffX - diffX // 2,
+                                            diffY // 2, diffY - diffY // 2))
+
+        # Unfold the input tensor for matrix multiplication
+        input_feature = torch.nn.functional.unfold(
+                                                   x,
+                                                   kernel_size = (self.kernel_size, self.kernel_size),
+                                                   stride = self.stride,
+                                                   padding = self.padding
+                                                  )
+
+        # Resize sv_kernel_feature to match the input feature
+        sv_kernel = sv_kernel_feature.reshape(
+                                              1,
+                                              self.input_channels * self.kernel_size * self.kernel_size,
+                                              (x.size(-2) // self.stride) * (x.size(-1) // self.stride)
+                                             )
+
+        # Resize weight to match the input channels and kernel size
+        si_kernel = self.weight.reshape(
+                                        self.output_channels,
+                                        self.input_channels * self.kernel_size * self.kernel_size
+                                       )
+
+        # Apply spatially varying kernels
+        sv_feature = input_feature * sv_kernel
+
+        # Perform matrix multiplication
+        sa_output = torch.matmul(si_kernel, sv_feature).reshape(
+                                                                1, self.output_channels,
+                                                                (x.size(-2) // self.stride),
+                                                                (x.size(-1) // self.stride)
+                                                               )
+        return sa_output
+
+
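To make the unfold and matrix multiplication steps in forward easier to follow, here is a minimal standalone sketch of the same computation with hypothetical shapes; it only uses torch and does not call odak itself.

import torch

batch, c_in, c_out, k, h, w = 1, 2, 2, 3, 8, 8
x = torch.randn(batch, c_in, h, w)
sv_kernel = torch.randn(batch, c_in * k * k, h * w) # one flattened kernel per output pixel
si_kernel = torch.randn(c_out, c_in * k * k) # shared kernel, playing the role of self.weight
patches = torch.nn.functional.unfold(x, kernel_size = (k, k), padding = k // 2) # (1, c_in * k * k, h * w)
sv_feature = patches * sv_kernel # spatially varying modulation of each patch
output = torch.matmul(si_kernel, sv_feature).reshape(batch, c_out, h, w)
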
__init__(input_channels=2, output_channels=2, kernel_size=3, stride=1, padding=1, bias=False, activation=torch.nn.LeakyReLU(0.2, inplace=True))

Initializes a spatially adaptive convolution layer.

Parameters:

  • input_channels (int, default: 2) – Number of input channels.
  • output_channels (int, default: 2) – Number of output channels.
  • kernel_size (int, default: 3) – Size of the convolution kernel.
  • stride (int, default: 1) – Stride of the convolution.
  • padding (int, default: 1) – Padding added to both sides of the input.
  • bias (bool, default: False) – If True, includes a bias term in the convolution.
  • activation (torch.nn.Module, default: torch.nn.LeakyReLU(0.2, inplace=True)) – Activation function to apply. If None, no activation is applied.
+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             input_channels = 2,
+             output_channels = 2,
+             kernel_size = 3,
+             stride = 1,
+             padding = 1,
+             bias = False,
+             activation = torch.nn.LeakyReLU(0.2, inplace = True)
+            ):
+    """
+    Initializes a spatially adaptive convolution layer.
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    output_channels : int
+                      Number of output channels.
+    kernel_size     : int
+                      Size of the convolution kernel.
+    stride          : int
+                      Stride of the convolution.
+    padding         : int
+                      Padding added to both sides of the input.
+    bias            : bool
+                      If True, includes a bias term in the convolution.
+    activation      : torch.nn.Module
+                      Activation function to apply. If None, no activation is applied.
+    """
+    super(spatially_adaptive_convolution, self).__init__()
+    self.kernel_size = kernel_size
+    self.input_channels = input_channels
+    self.output_channels = output_channels
+    self.stride = stride
+    self.padding = padding
+    self.standard_convolution = torch.nn.Conv2d(
+                                                in_channels = input_channels,
+                                                out_channels = self.output_channels,
+                                                kernel_size = kernel_size,
+                                                stride = stride,
+                                                padding = padding,
+                                                bias = bias
+                                               )
+    self.weight = torch.nn.Parameter(data = self.standard_convolution.weight, requires_grad = True)
+    self.activation = activation
+
+
+
forward(x, sv_kernel_feature)

Forward pass for the spatially adaptive convolution layer.

Parameters:

  • x (torch.tensor) – Input data tensor. Dimension: (1, C, H, W).
  • sv_kernel_feature (torch.tensor) – Spatially varying kernel features. Dimension: (1, C_i * kernel_size * kernel_size, H, W).

Returns:

  • sa_output (torch.tensor) – Estimated output tensor. Dimension: (1, output_channels, H_out, W_out).
+ Source code in odak/learn/models/components.py +
def forward(self, x, sv_kernel_feature):
+    """
+    Forward pass for the spatially adaptive convolution layer.
+
+    Parameters
+    ----------
+    x                  : torch.tensor
+                        Input data tensor.
+                        Dimension: (1, C, H, W)
+    sv_kernel_feature   : torch.tensor
+                        Spatially varying kernel features.
+                        Dimension: (1, C_i * kernel_size * kernel_size, H, W)
+
+    Returns
+    -------
+    sa_output          : torch.tensor
+                        Estimated output tensor.
+                        Dimension: (1, output_channels, H_out, W_out)
+    """
+    # Pad input and sv_kernel_feature if necessary
+    if sv_kernel_feature.size(-1) * self.stride != x.size(-1) or sv_kernel_feature.size(
+            -2) * self.stride != x.size(-2):
+        diffY = sv_kernel_feature.size(-2) % self.stride
+        diffX = sv_kernel_feature.size(-1) % self.stride
+        sv_kernel_feature = torch.nn.functional.pad(sv_kernel_feature, (diffX // 2, diffX - diffX // 2,
+                                                                        diffY // 2, diffY - diffY // 2))
+        diffY = x.size(-2) % self.stride
+        diffX = x.size(-1) % self.stride
+        x = torch.nn.functional.pad(x, (diffX // 2, diffX - diffX // 2,
+                                        diffY // 2, diffY - diffY // 2))
+
+    # Unfold the input tensor for matrix multiplication
+    input_feature = torch.nn.functional.unfold(
+                                               x,
+                                               kernel_size = (self.kernel_size, self.kernel_size),
+                                               stride = self.stride,
+                                               padding = self.padding
+                                              )
+
+    # Resize sv_kernel_feature to match the input feature
+    sv_kernel = sv_kernel_feature.reshape(
+                                          1,
+                                          self.input_channels * self.kernel_size * self.kernel_size,
+                                          (x.size(-2) // self.stride) * (x.size(-1) // self.stride)
+                                         )
+
+    # Resize weight to match the input channels and kernel size
+    si_kernel = self.weight.reshape(
+                                    self.output_channels,
+                                    self.input_channels * self.kernel_size * self.kernel_size
+                                   )
+
+    # Apply spatially varying kernels
+    sv_feature = input_feature * sv_kernel
+
+    # Perform matrix multiplication
+    sa_output = torch.matmul(si_kernel, sv_feature).reshape(
+                                                            1, self.output_channels,
+                                                            (x.size(-2) // self.stride),
+                                                            (x.size(-1) // self.stride)
+                                                           )
+    return sa_output
+
+
+
spatially_adaptive_module

Bases: Module

A spatially adaptive module that combines learned spatially adaptive convolutions.

References

Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Akşit, "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions," SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24), December, 2024.
+ Source code in odak/learn/models/components.py +
class spatially_adaptive_module(torch.nn.Module):
+    """
+    A spatially adaptive module that combines learned spatially adaptive convolutions.
+
+    References
+    ----------
+
+    Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Akşit, "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions," SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24), December, 2024.
+    """
+    def __init__(
+                 self,
+                 input_channels = 2,
+                 output_channels = 2,
+                 kernel_size = 3,
+                 stride = 1,
+                 padding = 1,
+                 bias = False,
+                 activation = torch.nn.LeakyReLU(0.2, inplace = True)
+                ):
+        """
+        Initializes a spatially adaptive module.
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        output_channels : int
+                          Number of output channels.
+        kernel_size     : int
+                          Size of the convolution kernel.
+        stride          : int
+                          Stride of the convolution.
+        padding         : int
+                          Padding added to both sides of the input.
+        bias            : bool
+                          If True, includes a bias term in the convolution.
+        activation      : torch.nn
+                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+        """
+        super(spatially_adaptive_module, self).__init__()
+        self.kernel_size = kernel_size
+        self.input_channels = input_channels
+        self.output_channels = output_channels
+        self.stride = stride
+        self.padding = padding
+        self.weight_output_channels = self.output_channels - 1
+        self.standard_convolution = torch.nn.Conv2d(
+                                                    in_channels = input_channels,
+                                                    out_channels = self.weight_output_channels,
+                                                    kernel_size = kernel_size,
+                                                    stride = stride,
+                                                    padding = padding,
+                                                    bias = bias
+                                                   )
+        self.weight = torch.nn.Parameter(data = self.standard_convolution.weight, requires_grad = True)
+        self.activation = activation
+
+
+    def forward(self, x, sv_kernel_feature):
+        """
+        Forward pass for the spatially adaptive module.
+
+        Parameters
+        ----------
+        x                  : torch.tensor
+                            Input data tensor.
+                            Dimension: (1, C, H, W)
+        sv_kernel_feature   : torch.tensor
+                            Spatially varying kernel features.
+                            Dimension: (1, C_i * kernel_size * kernel_size, H, W)
+
+        Returns
+        -------
+        output             : torch.tensor
+                            Combined output tensor from standard and spatially adaptive convolutions.
+                            Dimension: (1, output_channels, H_out, W_out)
+        """
+        # Pad input and sv_kernel_feature if necessary
+        if sv_kernel_feature.size(-1) * self.stride != x.size(-1) or sv_kernel_feature.size(
+                -2) * self.stride != x.size(-2):
+            diffY = sv_kernel_feature.size(-2) % self.stride
+            diffX = sv_kernel_feature.size(-1) % self.stride
+            sv_kernel_feature = torch.nn.functional.pad(sv_kernel_feature, (diffX // 2, diffX - diffX // 2,
+                                                                            diffY // 2, diffY - diffY // 2))
+            diffY = x.size(-2) % self.stride
+            diffX = x.size(-1) % self.stride
+            x = torch.nn.functional.pad(x, (diffX // 2, diffX - diffX // 2,
+                                            diffY // 2, diffY - diffY // 2))
+
+        # Unfold the input tensor for matrix multiplication
+        input_feature = torch.nn.functional.unfold(
+                                                   x,
+                                                   kernel_size = (self.kernel_size, self.kernel_size),
+                                                   stride = self.stride,
+                                                   padding = self.padding
+                                                  )
+
+        # Resize sv_kernel_feature to match the input feature
+        sv_kernel = sv_kernel_feature.reshape(
+                                              1,
+                                              self.input_channels * self.kernel_size * self.kernel_size,
+                                              (x.size(-2) // self.stride) * (x.size(-1) // self.stride)
+                                             )
+
+        # Apply sv_kernel to the input_feature
+        sv_feature = input_feature * sv_kernel
+
+        # Original spatially varying convolution output
+        sv_output = torch.sum(sv_feature, dim = 1).reshape(
+                                                           1,
+                                                            1,
+                                                            (x.size(-2) // self.stride),
+                                                            (x.size(-1) // self.stride)
+                                                           )
+
+        # Reshape weight for spatially adaptive convolution
+        si_kernel = self.weight.reshape(
+                                        self.weight_output_channels,
+                                        self.input_channels * self.kernel_size * self.kernel_size
+                                       )
+
+        # Apply si_kernel on sv convolution output
+        sa_output = torch.matmul(si_kernel, sv_feature).reshape(
+                                                                1, self.weight_output_channels,
+                                                                (x.size(-2) // self.stride),
+                                                                (x.size(-1) // self.stride)
+                                                               )
+
+        # Combine the outputs and apply activation function
+        output = self.activation(torch.cat((sv_output, sa_output), dim = 1))
+        return output
+
+
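Example use of spatially_adaptive_module, as a minimal sketch (import path, resolution and kernel feature values are assumptions); the module expects one flattened kernel of size input_channels * kernel_size * kernel_size per pixel.

import torch
from odak.learn.models.components import spatially_adaptive_module # assumed import path

layer = spatially_adaptive_module(input_channels = 2, output_channels = 2, kernel_size = 3)
x = torch.randn(1, 2, 16, 16)
sv_kernel_feature = torch.randn(1, 2 * 3 * 3, 16, 16) # hypothetical spatially varying kernels
y = layer(x, sv_kernel_feature) # (1, 2, 16, 16)
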
__init__(input_channels=2, output_channels=2, kernel_size=3, stride=1, padding=1, bias=False, activation=torch.nn.LeakyReLU(0.2, inplace=True))

Initializes a spatially adaptive module.

Parameters:

  • input_channels (int, default: 2) – Number of input channels.
  • output_channels (int, default: 2) – Number of output channels.
  • kernel_size (int, default: 3) – Size of the convolution kernel.
  • stride (int, default: 1) – Stride of the convolution.
  • padding (int, default: 1) – Padding added to both sides of the input.
  • bias (bool, default: False) – If True, includes a bias term in the convolution.
  • activation (torch.nn, default: torch.nn.LeakyReLU(0.2, inplace=True)) – Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             input_channels = 2,
+             output_channels = 2,
+             kernel_size = 3,
+             stride = 1,
+             padding = 1,
+             bias = False,
+             activation = torch.nn.LeakyReLU(0.2, inplace = True)
+            ):
+    """
+    Initializes a spatially adaptive module.
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    output_channels : int
+                      Number of output channels.
+    kernel_size     : int
+                      Size of the convolution kernel.
+    stride          : int
+                      Stride of the convolution.
+    padding         : int
+                      Padding added to both sides of the input.
+    bias            : bool
+                      If True, includes a bias term in the convolution.
+    activation      : torch.nn
+                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+    """
+    super(spatially_adaptive_module, self).__init__()
+    self.kernel_size = kernel_size
+    self.input_channels = input_channels
+    self.output_channels = output_channels
+    self.stride = stride
+    self.padding = padding
+    self.weight_output_channels = self.output_channels - 1
+    self.standard_convolution = torch.nn.Conv2d(
+                                                in_channels = input_channels,
+                                                out_channels = self.weight_output_channels,
+                                                kernel_size = kernel_size,
+                                                stride = stride,
+                                                padding = padding,
+                                                bias = bias
+                                               )
+    self.weight = torch.nn.Parameter(data = self.standard_convolution.weight, requires_grad = True)
+    self.activation = activation
+
+
+
forward(x, sv_kernel_feature)

Forward pass for the spatially adaptive module.

Parameters:

  • x (torch.tensor) – Input data tensor. Dimension: (1, C, H, W).
  • sv_kernel_feature (torch.tensor) – Spatially varying kernel features. Dimension: (1, C_i * kernel_size * kernel_size, H, W).

Returns:

  • output (torch.tensor) – Combined output tensor from standard and spatially adaptive convolutions. Dimension: (1, output_channels, H_out, W_out).
+ Source code in odak/learn/models/components.py +
def forward(self, x, sv_kernel_feature):
+    """
+    Forward pass for the spatially adaptive module.
+
+    Parameters
+    ----------
+    x                  : torch.tensor
+                        Input data tensor.
+                        Dimension: (1, C, H, W)
+    sv_kernel_feature   : torch.tensor
+                        Spatially varying kernel features.
+                        Dimension: (1, C_i * kernel_size * kernel_size, H, W)
+
+    Returns
+    -------
+    output             : torch.tensor
+                        Combined output tensor from standard and spatially adaptive convolutions.
+                        Dimension: (1, output_channels, H_out, W_out)
+    """
+    # Pad input and sv_kernel_feature if necessary
+    if sv_kernel_feature.size(-1) * self.stride != x.size(-1) or sv_kernel_feature.size(
+            -2) * self.stride != x.size(-2):
+        diffY = sv_kernel_feature.size(-2) % self.stride
+        diffX = sv_kernel_feature.size(-1) % self.stride
+        sv_kernel_feature = torch.nn.functional.pad(sv_kernel_feature, (diffX // 2, diffX - diffX // 2,
+                                                                        diffY // 2, diffY - diffY // 2))
+        diffY = x.size(-2) % self.stride
+        diffX = x.size(-1) % self.stride
+        x = torch.nn.functional.pad(x, (diffX // 2, diffX - diffX // 2,
+                                        diffY // 2, diffY - diffY // 2))
+
+    # Unfold the input tensor for matrix multiplication
+    input_feature = torch.nn.functional.unfold(
+                                               x,
+                                               kernel_size = (self.kernel_size, self.kernel_size),
+                                               stride = self.stride,
+                                               padding = self.padding
+                                              )
+
+    # Resize sv_kernel_feature to match the input feature
+    sv_kernel = sv_kernel_feature.reshape(
+                                          1,
+                                          self.input_channels * self.kernel_size * self.kernel_size,
+                                          (x.size(-2) // self.stride) * (x.size(-1) // self.stride)
+                                         )
+
+    # Apply sv_kernel to the input_feature
+    sv_feature = input_feature * sv_kernel
+
+    # Original spatially varying convolution output
+    sv_output = torch.sum(sv_feature, dim = 1).reshape(
+                                                       1,
+                                                        1,
+                                                        (x.size(-2) // self.stride),
+                                                        (x.size(-1) // self.stride)
+                                                       )
+
+    # Reshape weight for spatially adaptive convolution
+    si_kernel = self.weight.reshape(
+                                    self.weight_output_channels,
+                                    self.input_channels * self.kernel_size * self.kernel_size
+                                   )
+
+    # Apply si_kernel on sv convolution output
+    sa_output = torch.matmul(si_kernel, sv_feature).reshape(
+                                                            1, self.weight_output_channels,
+                                                            (x.size(-2) // self.stride),
+                                                            (x.size(-1) // self.stride)
+                                                           )
+
+    # Combine the outputs and apply activation function
+    output = self.activation(torch.cat((sv_output, sa_output), dim = 1))
+    return output
+
+
+
spatially_adaptive_unet

Bases: Module

Spatially varying U-Net model based on spatially adaptive convolution.

References

Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Akşit, "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions," SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24), December, 2024.
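A minimal instantiation sketch; the import path follows the source location below and is an assumption, and the arguments simply repeat the documented defaults.

from odak.learn.models.models import spatially_adaptive_unet # assumed import path

model = spatially_adaptive_unet(
                                depth = 3,
                                dimensions = 8,
                                input_channels = 6,
                                out_channels = 6,
                                kernel_size = 3
                               )
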
+ Source code in odak/learn/models/models.py +
class spatially_adaptive_unet(torch.nn.Module):
+    """
+    Spatially varying U-Net model based on spatially adaptive convolution.
+
+    References
+    ----------
+
+    Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Akşit, "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions," SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24), December, 2024.
+    """
+    def __init__(
+                 self,
+                 depth=3,
+                 dimensions=8,
+                 input_channels=6,
+                 out_channels=6,
+                 kernel_size=3,
+                 bias=True,
+                 normalization=False,
+                 activation=torch.nn.LeakyReLU(0.2, inplace=True)
+                ):
+        """
+        U-Net model.
+
+        Parameters
+        ----------
+        depth          : int
+                         Number of upsampling and downsampling layers.
+        dimensions     : int
+                         Number of dimensions.
+        input_channels : int
+                         Number of input channels.
+        out_channels   : int
+                         Number of output channels.
+        bias           : bool
+                         Set to True to let convolutional layers learn a bias term.
+        normalization  : bool
+                         If True, adds a Batch Normalization layer after the convolutional layer.
+        activation     : torch.nn
+                         Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
+        """
+        super().__init__()
+        self.depth = depth
+        self.out_channels = out_channels
+        self.inc = convolution_layer(
+                                     input_channels=input_channels,
+                                     output_channels=dimensions,
+                                     kernel_size=kernel_size,
+                                     bias=bias,
+                                     normalization=normalization,
+                                     activation=activation
+                                    )
+
+        self.encoder = torch.nn.ModuleList()
+        for i in range(self.depth + 1):  # Downsampling layers
+            down_in_channels = dimensions * (2 ** i)
+            down_out_channels = 2 * down_in_channels
+            pooling_layer = torch.nn.AvgPool2d(2)
+            double_convolution_layer = double_convolution(
+                                                          input_channels=down_in_channels,
+                                                          mid_channels=down_in_channels,
+                                                          output_channels=down_in_channels,
+                                                          kernel_size=kernel_size,
+                                                          bias=bias,
+                                                          normalization=normalization,
+                                                          activation=activation
+                                                         )
+            sam = spatially_adaptive_module(
+                                            input_channels=down_in_channels,
+                                            output_channels=down_out_channels,
+                                            kernel_size=kernel_size,
+                                            bias=bias,
+                                            activation=activation
+                                           )
+            self.encoder.append(torch.nn.ModuleList([pooling_layer, double_convolution_layer, sam]))
+        self.global_feature_module = torch.nn.ModuleList()
+        double_convolution_layer = double_convolution(
+                                                      input_channels=dimensions * (2 ** (depth + 1)),
+                                                      mid_channels=dimensions * (2 ** (depth + 1)),
+                                                      output_channels=dimensions * (2 ** (depth + 1)),
+                                                      kernel_size=kernel_size,
+                                                      bias=bias,
+                                                      normalization=normalization,
+                                                      activation=activation
+                                                     )
+        global_feature_layer = global_feature_module(
+                                                     input_channels=dimensions * (2 ** (depth + 1)),
+                                                     mid_channels=dimensions * (2 ** (depth + 1)),
+                                                     output_channels=dimensions * (2 ** (depth + 1)),
+                                                     kernel_size=kernel_size,
+                                                     bias=bias,
+                                                     activation=torch.nn.LeakyReLU(0.2, inplace=True)
+                                                    )
+        self.global_feature_module.append(torch.nn.ModuleList([double_convolution_layer, global_feature_layer]))
+        self.decoder = torch.nn.ModuleList()
+        for i in range(depth, -1, -1):
+            up_in_channels = dimensions * (2 ** (i + 1))
+            up_mid_channels = up_in_channels // 2
+            if i == 0:
+                up_out_channels = self.out_channels
+                upsample_layer = upsample_convtranspose2d_layer(
+                                                                input_channels=up_in_channels,
+                                                                output_channels=up_mid_channels,
+                                                                kernel_size=2,
+                                                                stride=2,
+                                                                bias=bias,
+                                                               )
+                conv_layer = torch.nn.Sequential(
+                    convolution_layer(
+                                      input_channels=up_mid_channels,
+                                      output_channels=up_mid_channels,
+                                      kernel_size=kernel_size,
+                                      bias=bias,
+                                      normalization=normalization,
+                                      activation=activation,
+                                     ),
+                    convolution_layer(
+                                      input_channels=up_mid_channels,
+                                      output_channels=up_out_channels,
+                                      kernel_size=1,
+                                      bias=bias,
+                                      normalization=normalization,
+                                      activation=None,
+                                     )
+                )
+                self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))
+            else:
+                up_out_channels = up_in_channels // 2
+                upsample_layer = upsample_convtranspose2d_layer(
+                                                                input_channels=up_in_channels,
+                                                                output_channels=up_mid_channels,
+                                                                kernel_size=2,
+                                                                stride=2,
+                                                                bias=bias,
+                                                               )
+                conv_layer = double_convolution(
+                                                input_channels=up_mid_channels,
+                                                mid_channels=up_mid_channels,
+                                                output_channels=up_out_channels,
+                                                kernel_size=kernel_size,
+                                                bias=bias,
+                                                normalization=normalization,
+                                                activation=activation,
+                                               )
+                self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))
+
+
+    def forward(self, sv_kernel, field):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        sv_kernel : list of torch.tensor
+                    Learned spatially varying kernels.
+                    Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i),
+                    where C_i, H_i, and W_i represent the channel, height, and width
+                    of each feature at a certain scale.
+
+        field     : torch.tensor
+                    Input field data.
+                    Dimension: (1, 6, H, W)
+
+        Returns
+        -------
+        target_field : torch.tensor
+                       Estimated output.
+                       Dimension: (1, 6, H, W)
+        """
+        x = self.inc(field)
+        downsampling_outputs = [x]
+        for i, down_layer in enumerate(self.encoder):
+            x_down = down_layer[0](downsampling_outputs[-1])
+            downsampling_outputs.append(x_down)
+            sam_output = down_layer[2](x_down + down_layer[1](x_down), sv_kernel[self.depth - i])
+            downsampling_outputs.append(sam_output)
+        global_feature = self.global_feature_module[0][0](downsampling_outputs[-1])
+        global_feature = self.global_feature_module[0][1](downsampling_outputs[-1], global_feature)
+        downsampling_outputs.append(global_feature)
+        x_up = downsampling_outputs[-1]
+        for i, up_layer in enumerate(self.decoder):
+            x_up = up_layer[0](x_up, downsampling_outputs[2 * (self.depth - i)])
+            x_up = up_layer[1](x_up)
+        result = x_up
+        return result
+
+
__init__(depth=3, dimensions=8, input_channels=6, out_channels=6, kernel_size=3, bias=True, normalization=False, activation=torch.nn.LeakyReLU(0.2, inplace=True))

U-Net model.

Parameters:

  • depth – Number of upsampling and downsampling layers.
  • dimensions – Number of channels in the first convolution layer.
  • input_channels (int, default: 6) – Number of input channels.
  • out_channels – Number of output channels.
  • bias – Set to True to let convolutional layers learn a bias term.
  • normalization – If True, adds a Batch Normalization layer after the convolutional layer.
  • activation – Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).

Source code in odak/learn/models/models.py
def __init__(
+             self,
+             depth=3,
+             dimensions=8,
+             input_channels=6,
+             out_channels=6,
+             kernel_size=3,
+             bias=True,
+             normalization=False,
+             activation=torch.nn.LeakyReLU(0.2, inplace=True)
+            ):
+    """
+    U-Net model.
+
+    Parameters
+    ----------
+    depth          : int
+                     Number of upsampling and downsampling layers.
+    dimensions     : int
+                     Number of channels in the first convolution layer.
+    input_channels : int
+                     Number of input channels.
+    out_channels   : int
+                     Number of output channels.
+    bias           : bool
+                     Set to True to let convolutional layers learn a bias term.
+    normalization  : bool
+                     If True, adds a Batch Normalization layer after the convolutional layer.
+    activation     : torch.nn
+                     Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
+    """
+    super().__init__()
+    self.depth = depth
+    self.out_channels = out_channels
+    self.inc = convolution_layer(
+                                 input_channels=input_channels,
+                                 output_channels=dimensions,
+                                 kernel_size=kernel_size,
+                                 bias=bias,
+                                 normalization=normalization,
+                                 activation=activation
+                                )
+
+    self.encoder = torch.nn.ModuleList()
+    for i in range(self.depth + 1):  # Downsampling layers
+        down_in_channels = dimensions * (2 ** i)
+        down_out_channels = 2 * down_in_channels
+        pooling_layer = torch.nn.AvgPool2d(2)
+        double_convolution_layer = double_convolution(
+                                                      input_channels=down_in_channels,
+                                                      mid_channels=down_in_channels,
+                                                      output_channels=down_in_channels,
+                                                      kernel_size=kernel_size,
+                                                      bias=bias,
+                                                      normalization=normalization,
+                                                      activation=activation
+                                                     )
+        sam = spatially_adaptive_module(
+                                        input_channels=down_in_channels,
+                                        output_channels=down_out_channels,
+                                        kernel_size=kernel_size,
+                                        bias=bias,
+                                        activation=activation
+                                       )
+        self.encoder.append(torch.nn.ModuleList([pooling_layer, double_convolution_layer, sam]))
+    self.global_feature_module = torch.nn.ModuleList()
+    double_convolution_layer = double_convolution(
+                                                  input_channels=dimensions * (2 ** (depth + 1)),
+                                                  mid_channels=dimensions * (2 ** (depth + 1)),
+                                                  output_channels=dimensions * (2 ** (depth + 1)),
+                                                  kernel_size=kernel_size,
+                                                  bias=bias,
+                                                  normalization=normalization,
+                                                  activation=activation
+                                                 )
+    global_feature_layer = global_feature_module(
+                                                 input_channels=dimensions * (2 ** (depth + 1)),
+                                                 mid_channels=dimensions * (2 ** (depth + 1)),
+                                                 output_channels=dimensions * (2 ** (depth + 1)),
+                                                 kernel_size=kernel_size,
+                                                 bias=bias,
+                                                 activation=torch.nn.LeakyReLU(0.2, inplace=True)
+                                                )
+    self.global_feature_module.append(torch.nn.ModuleList([double_convolution_layer, global_feature_layer]))
+    self.decoder = torch.nn.ModuleList()
+    for i in range(depth, -1, -1):
+        up_in_channels = dimensions * (2 ** (i + 1))
+        up_mid_channels = up_in_channels // 2
+        if i == 0:
+            up_out_channels = self.out_channels
+            upsample_layer = upsample_convtranspose2d_layer(
+                                                            input_channels=up_in_channels,
+                                                            output_channels=up_mid_channels,
+                                                            kernel_size=2,
+                                                            stride=2,
+                                                            bias=bias,
+                                                           )
+            conv_layer = torch.nn.Sequential(
+                convolution_layer(
+                                  input_channels=up_mid_channels,
+                                  output_channels=up_mid_channels,
+                                  kernel_size=kernel_size,
+                                  bias=bias,
+                                  normalization=normalization,
+                                  activation=activation,
+                                 ),
+                convolution_layer(
+                                  input_channels=up_mid_channels,
+                                  output_channels=up_out_channels,
+                                  kernel_size=1,
+                                  bias=bias,
+                                  normalization=normalization,
+                                  activation=None,
+                                 )
+            )
+            self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))
+        else:
+            up_out_channels = up_in_channels // 2
+            upsample_layer = upsample_convtranspose2d_layer(
+                                                            input_channels=up_in_channels,
+                                                            output_channels=up_mid_channels,
+                                                            kernel_size=2,
+                                                            stride=2,
+                                                            bias=bias,
+                                                           )
+            conv_layer = double_convolution(
+                                            input_channels=up_mid_channels,
+                                            mid_channels=up_mid_channels,
+                                            output_channels=up_out_channels,
+                                            kernel_size=kernel_size,
+                                            bias=bias,
+                                            normalization=normalization,
+                                            activation=activation,
+                                           )
+            self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))
+
+
+
forward(sv_kernel, field)

Forward model.

Parameters:

  • sv_kernel (list of torch.tensor) – Learned spatially varying kernels. Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i), where C_i, H_i, and W_i represent the channel, height, and width of each feature at a certain scale.
  • field – Input field data. Dimension: (1, 6, H, W)

Returns:

  • target_field (tensor) – Estimated output. Dimension: (1, 6, H, W)

Source code in odak/learn/models/models.py
def forward(self, sv_kernel, field):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    sv_kernel : list of torch.tensor
+                Learned spatially varying kernels.
+                Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i),
+                where C_i, H_i, and W_i represent the channel, height, and width
+                of each feature at a certain scale.
+
+    field     : torch.tensor
+                Input field data.
+                Dimension: (1, 6, H, W)
+
+    Returns
+    -------
+    target_field : torch.tensor
+                   Estimated output.
+                   Dimension: (1, 6, H, W)
+    """
+    x = self.inc(field)
+    downsampling_outputs = [x]
+    for i, down_layer in enumerate(self.encoder):
+        x_down = down_layer[0](downsampling_outputs[-1])
+        downsampling_outputs.append(x_down)
+        sam_output = down_layer[2](x_down + down_layer[1](x_down), sv_kernel[self.depth - i])
+        downsampling_outputs.append(sam_output)
+    global_feature = self.global_feature_module[0][0](downsampling_outputs[-1])
+    global_feature = self.global_feature_module[0][1](downsampling_outputs[-1], global_feature)
+    downsampling_outputs.append(global_feature)
+    x_up = downsampling_outputs[-1]
+    for i, up_layer in enumerate(self.decoder):
+        x_up = up_layer[0](x_up, downsampling_outputs[2 * (self.depth - i)])
+        x_up = up_layer[1](x_up)
+    result = x_up
+    return result
+
+
+
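The encoder loop above pairs level i with sv_kernel[self.depth - i], so each list element needs dimensions * (2 ** i) * kernel_size * kernel_size channels at that level's spatial resolution. Below is a small sketch of the shapes implied by the default arguments (depth=3, dimensions=8, kernel_size=3) for a hypothetical 256 x 256 field; in practice these kernels come from the spatially_varying_kernel_generation_model documented below.

import torch

# Dummy kernels with the shapes the forward pass above expects (illustrative only).
H = W = 256
depth, dimensions, kernel_size = 3, 8, 3
sv_kernel = [
             torch.randn(1, dimensions * (2 ** i) * kernel_size ** 2, H // (2 ** (i + 1)), W // (2 ** (i + 1)))
             for i in range(depth, -1, -1)
            ]
for kernel in sv_kernel:
    print(kernel.shape)
# torch.Size([1, 576, 16, 16])
# torch.Size([1, 288, 32, 32])
# torch.Size([1, 144, 64, 64])
# torch.Size([1, 72, 128, 128])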
spatially_varying_kernel_generation_model

Bases: Module

Spatially_varying_kernel_generation_model revised from RSGUnet:
https://github.com/MTLab/rsgunet_image_enhance.

Refer to:
J. Huang, P. Zhu, M. Geng et al. Range Scaling Global U-Net for Perceptual Image Enhancement on Mobile Devices.

Source code in odak/learn/models/models.py
class spatially_varying_kernel_generation_model(torch.nn.Module):
+    """
+    Spatially_varying_kernel_generation_model revised from RSGUnet:
+    https://github.com/MTLab/rsgunet_image_enhance.
+
+    Refer to:
+    J. Huang, P. Zhu, M. Geng et al. Range Scaling Global U-Net for Perceptual Image Enhancement on Mobile Devices.
+    """
+
+    def __init__(
+                 self,
+                 depth = 3,
+                 dimensions = 8,
+                 input_channels = 7,
+                 kernel_size = 3,
+                 bias = True,
+                 normalization = False,
+                 activation = torch.nn.LeakyReLU(0.2, inplace = True)
+                ):
+        """
+        U-Net model.
+
+        Parameters
+        ----------
+        depth          : int
+                         Number of upsampling and downsampling layers.
+        dimensions     : int
+                         Number of channels in the first convolution layer.
+        input_channels : int
+                         Number of input channels.
+        bias           : bool
+                         Set to True to let convolutional layers learn a bias term.
+        normalization  : bool
+                         If True, adds a Batch Normalization layer after the convolutional layer.
+        activation     : torch.nn
+                         Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
+        """
+        super().__init__()
+        self.depth = depth
+        self.inc = convolution_layer(
+                                     input_channels = input_channels,
+                                     output_channels = dimensions,
+                                     kernel_size = kernel_size,
+                                     bias = bias,
+                                     normalization = normalization,
+                                     activation = activation
+                                    )
+        self.encoder = torch.nn.ModuleList()
+        for i in range(depth + 1):  # downsampling layers
+            if i == 0:
+                in_channels = dimensions * (2 ** i)
+                out_channels = dimensions * (2 ** i)
+            elif i == depth:
+                in_channels = dimensions * (2 ** (i - 1))
+                out_channels = dimensions * (2 ** (i - 1))
+            else:
+                in_channels = dimensions * (2 ** (i - 1))
+                out_channels = 2 * in_channels
+            pooling_layer = torch.nn.AvgPool2d(2)
+            double_convolution_layer = double_convolution(
+                                                          input_channels = in_channels,
+                                                          mid_channels = in_channels,
+                                                          output_channels = out_channels,
+                                                          kernel_size = kernel_size,
+                                                          bias = bias,
+                                                          normalization = normalization,
+                                                          activation = activation
+                                                         )
+            self.encoder.append(pooling_layer)
+            self.encoder.append(double_convolution_layer)
+        self.spatially_varying_feature = torch.nn.ModuleList()  # for kernel generation
+        for i in range(depth, -1, -1):
+            if i == 1:
+                svf_in_channels = dimensions + 2 ** (self.depth + i) + 1
+            else:
+                svf_in_channels = 2 ** (self.depth + i) + 1
+            svf_out_channels = (2 ** (self.depth + i)) * (kernel_size * kernel_size)
+            svf_mid_channels = dimensions * (2 ** (self.depth - 1))
+            spatially_varying_kernel_generation = torch.nn.ModuleList()
+            for j in range(i, -1, -1):
+                pooling_layer = torch.nn.AvgPool2d(2 ** (j + 1))
+                spatially_varying_kernel_generation.append(pooling_layer)
+            kernel_generation_block = torch.nn.Sequential(
+                torch.nn.Conv2d(
+                                in_channels = svf_in_channels,
+                                out_channels = svf_mid_channels,
+                                kernel_size = kernel_size,
+                                padding = kernel_size // 2,
+                                bias = bias
+                               ),
+                activation,
+                torch.nn.Conv2d(
+                                in_channels = svf_mid_channels,
+                                out_channels = svf_mid_channels,
+                                kernel_size = kernel_size,
+                                padding = kernel_size // 2,
+                                bias = bias
+                               ),
+                activation,
+                torch.nn.Conv2d(
+                                in_channels = svf_mid_channels,
+                                out_channels = svf_out_channels,
+                                kernel_size = kernel_size,
+                                padding = kernel_size // 2,
+                                bias = bias
+                               ),
+            )
+            spatially_varying_kernel_generation.append(kernel_generation_block)
+            self.spatially_varying_feature.append(spatially_varying_kernel_generation)
+        self.decoder = torch.nn.ModuleList()
+        global_feature_layer = global_feature_module(  # global feature layer
+                                                     input_channels = dimensions * (2 ** (depth - 1)),
+                                                     mid_channels = dimensions * (2 ** (depth - 1)),
+                                                     output_channels = dimensions * (2 ** (depth - 1)),
+                                                     kernel_size = kernel_size,
+                                                     bias = bias,
+                                                     activation = torch.nn.LeakyReLU(0.2, inplace = True)
+                                                    )
+        self.decoder.append(global_feature_layer)
+        for i in range(depth, 0, -1):
+            if i == 2:
+                up_in_channels = (dimensions // 2) * (2 ** i)
+                up_out_channels = up_in_channels
+                up_mid_channels = up_in_channels
+            elif i == 1:
+                up_in_channels = dimensions * 2
+                up_out_channels = dimensions
+                up_mid_channels = up_out_channels
+            else:
+                up_in_channels = (dimensions // 2) * (2 ** i)
+                up_out_channels = up_in_channels // 2
+                up_mid_channels = up_in_channels
+            upsample_layer = upsample_convtranspose2d_layer(
+                                                            input_channels = up_in_channels,
+                                                            output_channels = up_mid_channels,
+                                                            kernel_size = 2,
+                                                            stride = 2,
+                                                            bias = bias,
+                                                           )
+            conv_layer = double_convolution(
+                                            input_channels = up_mid_channels,
+                                            output_channels = up_out_channels,
+                                            kernel_size = kernel_size,
+                                            bias = bias,
+                                            normalization = normalization,
+                                            activation = activation,
+                                           )
+            self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))
+
+
+    def forward(self, focal_surface, field):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        focal_surface : torch.tensor
+                        Input focal surface data.
+                        Dimension: (1, 1, H, W)
+
+        field         : torch.tensor
+                        Input field data.
+                        Dimension: (1, 6, H, W)
+
+        Returns
+        -------
+        sv_kernel : list of torch.tensor
+                    Learned spatially varying kernels.
+                    Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i),
+                    where C_i, H_i, and W_i represent the channel, height, and width
+                    of each feature at a certain scale.
+        """
+        x = self.inc(torch.cat((focal_surface, field), dim = 1))
+        downsampling_outputs = [focal_surface]
+        downsampling_outputs.append(x)
+        for i, down_layer in enumerate(self.encoder):
+            x_down = down_layer(downsampling_outputs[-1])
+            downsampling_outputs.append(x_down)
+        sv_kernels = []
+        for i, (up_layer, svf_layer) in enumerate(zip(self.decoder, self.spatially_varying_feature)):
+            if i == 0:
+                global_feature = up_layer(downsampling_outputs[-2], downsampling_outputs[-1])
+                downsampling_outputs[-1] = global_feature
+                sv_feature = [global_feature, downsampling_outputs[0]]
+                for j in range(self.depth - i + 1):
+                    sv_feature[1] = svf_layer[self.depth - i](sv_feature[1])
+                    if j > 0:
+                        sv_feature.append(svf_layer[j](downsampling_outputs[2 * j]))
+                sv_feature = [sv_feature[0], sv_feature[1], sv_feature[4], sv_feature[2],
+                              sv_feature[3]]
+                sv_kernel = svf_layer[-1](torch.cat(sv_feature, dim = 1))
+                sv_kernels.append(sv_kernel)
+            else:
+                x_up = up_layer[0](downsampling_outputs[-1],
+                                   downsampling_outputs[2 * (self.depth + 1 - i) + 1])
+                x_up = up_layer[1](x_up)
+                downsampling_outputs[-1] = x_up
+                sv_feature = [x_up, downsampling_outputs[0]]
+                for j in range(self.depth - i + 1):
+                    sv_feature[1] = svf_layer[self.depth - i](sv_feature[1])
+                    if j > 0:
+                        sv_feature.append(svf_layer[j](downsampling_outputs[2 * j]))
+                if i == 1:
+                    sv_feature = [sv_feature[0], sv_feature[1], sv_feature[3], sv_feature[2]]
+                sv_kernel = svf_layer[-1](torch.cat(sv_feature, dim = 1))
+                sv_kernels.append(sv_kernel)
+        return sv_kernels
+
+
__init__(depth=3, dimensions=8, input_channels=7, kernel_size=3, bias=True, normalization=False, activation=torch.nn.LeakyReLU(0.2, inplace=True))

U-Net model.

Parameters:

  • depth – Number of upsampling and downsampling layers.
  • dimensions – Number of channels in the first convolution layer.
  • input_channels (int, default: 7) – Number of input channels.
  • bias – Set to True to let convolutional layers learn a bias term.
  • normalization – If True, adds a Batch Normalization layer after the convolutional layer.
  • activation – Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).

Source code in odak/learn/models/models.py
def __init__(
+             self,
+             depth = 3,
+             dimensions = 8,
+             input_channels = 7,
+             kernel_size = 3,
+             bias = True,
+             normalization = False,
+             activation = torch.nn.LeakyReLU(0.2, inplace = True)
+            ):
+    """
+    U-Net model.
+
+    Parameters
+    ----------
+    depth          : int
+                     Number of upsampling and downsampling layers.
+    dimensions     : int
+                     Number of channels in the first convolution layer.
+    input_channels : int
+                     Number of input channels.
+    bias           : bool
+                     Set to True to let convolutional layers learn a bias term.
+    normalization  : bool
+                     If True, adds a Batch Normalization layer after the convolutional layer.
+    activation     : torch.nn
+                     Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
+    """
+    super().__init__()
+    self.depth = depth
+    self.inc = convolution_layer(
+                                 input_channels = input_channels,
+                                 output_channels = dimensions,
+                                 kernel_size = kernel_size,
+                                 bias = bias,
+                                 normalization = normalization,
+                                 activation = activation
+                                )
+    self.encoder = torch.nn.ModuleList()
+    for i in range(depth + 1):  # downsampling layers
+        if i == 0:
+            in_channels = dimensions * (2 ** i)
+            out_channels = dimensions * (2 ** i)
+        elif i == depth:
+            in_channels = dimensions * (2 ** (i - 1))
+            out_channels = dimensions * (2 ** (i - 1))
+        else:
+            in_channels = dimensions * (2 ** (i - 1))
+            out_channels = 2 * in_channels
+        pooling_layer = torch.nn.AvgPool2d(2)
+        double_convolution_layer = double_convolution(
+                                                      input_channels = in_channels,
+                                                      mid_channels = in_channels,
+                                                      output_channels = out_channels,
+                                                      kernel_size = kernel_size,
+                                                      bias = bias,
+                                                      normalization = normalization,
+                                                      activation = activation
+                                                     )
+        self.encoder.append(pooling_layer)
+        self.encoder.append(double_convolution_layer)
+    self.spatially_varying_feature = torch.nn.ModuleList()  # for kernel generation
+    for i in range(depth, -1, -1):
+        if i == 1:
+            svf_in_channels = dimensions + 2 ** (self.depth + i) + 1
+        else:
+            svf_in_channels = 2 ** (self.depth + i) + 1
+        svf_out_channels = (2 ** (self.depth + i)) * (kernel_size * kernel_size)
+        svf_mid_channels = dimensions * (2 ** (self.depth - 1))
+        spatially_varying_kernel_generation = torch.nn.ModuleList()
+        for j in range(i, -1, -1):
+            pooling_layer = torch.nn.AvgPool2d(2 ** (j + 1))
+            spatially_varying_kernel_generation.append(pooling_layer)
+        kernel_generation_block = torch.nn.Sequential(
+            torch.nn.Conv2d(
+                            in_channels = svf_in_channels,
+                            out_channels = svf_mid_channels,
+                            kernel_size = kernel_size,
+                            padding = kernel_size // 2,
+                            bias = bias
+                           ),
+            activation,
+            torch.nn.Conv2d(
+                            in_channels = svf_mid_channels,
+                            out_channels = svf_mid_channels,
+                            kernel_size = kernel_size,
+                            padding = kernel_size // 2,
+                            bias = bias
+                           ),
+            activation,
+            torch.nn.Conv2d(
+                            in_channels = svf_mid_channels,
+                            out_channels = svf_out_channels,
+                            kernel_size = kernel_size,
+                            padding = kernel_size // 2,
+                            bias = bias
+                           ),
+        )
+        spatially_varying_kernel_generation.append(kernel_generation_block)
+        self.spatially_varying_feature.append(spatially_varying_kernel_generation)
+    self.decoder = torch.nn.ModuleList()
+    global_feature_layer = global_feature_module(  # global feature layer
+                                                 input_channels = dimensions * (2 ** (depth - 1)),
+                                                 mid_channels = dimensions * (2 ** (depth - 1)),
+                                                 output_channels = dimensions * (2 ** (depth - 1)),
+                                                 kernel_size = kernel_size,
+                                                 bias = bias,
+                                                 activation = torch.nn.LeakyReLU(0.2, inplace = True)
+                                                )
+    self.decoder.append(global_feature_layer)
+    for i in range(depth, 0, -1):
+        if i == 2:
+            up_in_channels = (dimensions // 2) * (2 ** i)
+            up_out_channels = up_in_channels
+            up_mid_channels = up_in_channels
+        elif i == 1:
+            up_in_channels = dimensions * 2
+            up_out_channels = dimensions
+            up_mid_channels = up_out_channels
+        else:
+            up_in_channels = (dimensions // 2) * (2 ** i)
+            up_out_channels = up_in_channels // 2
+            up_mid_channels = up_in_channels
+        upsample_layer = upsample_convtranspose2d_layer(
+                                                        input_channels = up_in_channels,
+                                                        output_channels = up_mid_channels,
+                                                        kernel_size = 2,
+                                                        stride = 2,
+                                                        bias = bias,
+                                                       )
+        conv_layer = double_convolution(
+                                        input_channels = up_mid_channels,
+                                        output_channels = up_out_channels,
+                                        kernel_size = kernel_size,
+                                        bias = bias,
+                                        normalization = normalization,
+                                        activation = activation,
+                                       )
+        self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))
+
+
forward(focal_surface, field)

Forward model.

Parameters:

  • focal_surface (tensor) – Input focal surface data. Dimension: (1, 1, H, W)
  • field – Input field data. Dimension: (1, 6, H, W)

Returns:

  • sv_kernel (list of torch.tensor) – Learned spatially varying kernels. Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i), where C_i, H_i, and W_i represent the channel, height, and width of each feature at a certain scale.

Source code in odak/learn/models/models.py
def forward(self, focal_surface, field):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    focal_surface : torch.tensor
+                    Input focal surface data.
+                    Dimension: (1, 1, H, W)
+
+    field         : torch.tensor
+                    Input field data.
+                    Dimension: (1, 6, H, W)
+
+    Returns
+    -------
+    sv_kernel : list of torch.tensor
+                Learned spatially varying kernels.
+                Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i),
+                where C_i, H_i, and W_i represent the channel, height, and width
+                of each feature at a certain scale.
+    """
+    x = self.inc(torch.cat((focal_surface, field), dim = 1))
+    downsampling_outputs = [focal_surface]
+    downsampling_outputs.append(x)
+    for i, down_layer in enumerate(self.encoder):
+        x_down = down_layer(downsampling_outputs[-1])
+        downsampling_outputs.append(x_down)
+    sv_kernels = []
+    for i, (up_layer, svf_layer) in enumerate(zip(self.decoder, self.spatially_varying_feature)):
+        if i == 0:
+            global_feature = up_layer(downsampling_outputs[-2], downsampling_outputs[-1])
+            downsampling_outputs[-1] = global_feature
+            sv_feature = [global_feature, downsampling_outputs[0]]
+            for j in range(self.depth - i + 1):
+                sv_feature[1] = svf_layer[self.depth - i](sv_feature[1])
+                if j > 0:
+                    sv_feature.append(svf_layer[j](downsampling_outputs[2 * j]))
+            sv_feature = [sv_feature[0], sv_feature[1], sv_feature[4], sv_feature[2],
+                          sv_feature[3]]
+            sv_kernel = svf_layer[-1](torch.cat(sv_feature, dim = 1))
+            sv_kernels.append(sv_kernel)
+        else:
+            x_up = up_layer[0](downsampling_outputs[-1],
+                               downsampling_outputs[2 * (self.depth + 1 - i) + 1])
+            x_up = up_layer[1](x_up)
+            downsampling_outputs[-1] = x_up
+            sv_feature = [x_up, downsampling_outputs[0]]
+            for j in range(self.depth - i + 1):
+                sv_feature[1] = svf_layer[self.depth - i](sv_feature[1])
+                if j > 0:
+                    sv_feature.append(svf_layer[j](downsampling_outputs[2 * j]))
+            if i == 1:
+                sv_feature = [sv_feature[0], sv_feature[1], sv_feature[3], sv_feature[2]]
+            sv_kernel = svf_layer[-1](torch.cat(sv_feature, dim = 1))
+            sv_kernels.append(sv_kernel)
+    return sv_kernels
+
+
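A minimal, untested usage sketch for this model with its default arguments: the focal surface and field sizes below are hypothetical (height and width should be divisible by 2 ** (depth + 1) so the pooling stages line up), and both the import path and the spatially_adaptive_unet name used for the spatially adaptive U-Net documented earlier on this page are assumptions.

import torch
from odak.learn.models import spatially_varying_kernel_generation_model, spatially_adaptive_unet  # import path and second class name assumed

focal_surface = torch.randn(1, 1, 256, 256)  # (1, 1, H, W)
field = torch.randn(1, 6, 256, 256)          # (1, 6, H, W)

# Generate the spatially varying kernels from the focal surface and the field.
kernel_generator = spatially_varying_kernel_generation_model(depth = 3, dimensions = 8, input_channels = 7)
sv_kernel = kernel_generator(focal_surface, field)  # list of (1, C_i * 9, H_i, W_i) tensors

# Feed the kernels and the field to the spatially adaptive U-Net.
model = spatially_adaptive_unet(depth = 3, dimensions = 8, input_channels = 6, out_channels = 6)
target_field = model(sv_kernel, field)  # (1, 6, 256, 256)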
unet

Bases: Module

A U-Net model, heavily inspired from https://github.com/milesial/Pytorch-UNet/tree/master/unet and more can be read from Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional networks for biomedical image segmentation." Medical Image Computing and Computer-Assisted Intervention–MICCAI 2015: 18th International Conference, Munich, Germany, October 5-9, 2015, Proceedings, Part III 18. Springer International Publishing, 2015.

Source code in odak/learn/models/models.py
class unet(torch.nn.Module):
+    """
+    A U-Net model, heavily inspired from `https://github.com/milesial/Pytorch-UNet/tree/master/unet` and more can be read from Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional networks for biomedical image segmentation." Medical Image Computing and Computer-Assisted Intervention–MICCAI 2015: 18th International Conference, Munich, Germany, October 5-9, 2015, Proceedings, Part III 18. Springer International Publishing, 2015.
+    """
+
+    def __init__(
+                 self, 
+                 depth = 4,
+                 dimensions = 64, 
+                 input_channels = 2, 
+                 output_channels = 1, 
+                 bilinear = False,
+                 kernel_size = 3,
+                 bias = False,
+                 activation = torch.nn.ReLU(inplace = True),
+                ):
+        """
+        U-Net model.
+
+        Parameters
+        ----------
+        depth             : int
+                            Number of upsampling and downsampling layers.
+        dimensions        : int
+                            Number of channels in the first convolution layer.
+        input_channels    : int
+                            Number of input channels.
+        output_channels   : int
+                            Number of output channels.
+        bilinear          : bool
+                            Uses bilinear upsampling in upsampling layers when set True.
+        bias              : bool
+                            Set True to let convolutional layers learn a bias term.
+        activation        : torch.nn
+                            Non-linear activation layer to be used (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
+        """
+        super(unet, self).__init__()
+        self.inc = double_convolution(
+                                      input_channels = input_channels,
+                                      mid_channels = dimensions,
+                                      output_channels = dimensions,
+                                      kernel_size = kernel_size,
+                                      bias = bias,
+                                      activation = activation
+                                     )      
+
+        self.downsampling_layers = torch.nn.ModuleList()
+        self.upsampling_layers = torch.nn.ModuleList()
+        for i in range(depth): # downsampling layers
+            in_channels = dimensions * (2 ** i)
+            out_channels = dimensions * (2 ** (i + 1))
+            down_layer = downsample_layer(in_channels,
+                                            out_channels,
+                                            kernel_size=kernel_size,
+                                            bias=bias,
+                                            activation=activation
+                                            )
+            self.downsampling_layers.append(down_layer)      
+
+        for i in range(depth - 1, -1, -1):  # upsampling layers
+            up_in_channels = dimensions * (2 ** (i + 1))  
+            up_out_channels = dimensions * (2 ** i) 
+            up_layer = upsample_layer(up_in_channels, up_out_channels, kernel_size=kernel_size, bias=bias, activation=activation, bilinear=bilinear)
+            self.upsampling_layers.append(up_layer)
+        self.outc = torch.nn.Conv2d(
+                                    dimensions, 
+                                    output_channels,
+                                    kernel_size = kernel_size,
+                                    padding = kernel_size // 2,
+                                    bias = bias
+                                   )
+
+
+    def forward(self, x):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x             : torch.tensor
+                        Input data.
+
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Estimated output.      
+        """
+        downsampling_outputs = [self.inc(x)]
+        for down_layer in self.downsampling_layers:
+            x_down = down_layer(downsampling_outputs[-1])
+            downsampling_outputs.append(x_down)
+        x_up = downsampling_outputs[-1]
+        for i, up_layer in enumerate((self.upsampling_layers)):
+            x_up = up_layer(x_up, downsampling_outputs[-(i + 2)])       
+        result = self.outc(x_up)
+        return result
+
+
__init__(depth=4, dimensions=64, input_channels=2, output_channels=1, bilinear=False, kernel_size=3, bias=False, activation=torch.nn.ReLU(inplace=True))

U-Net model.

Parameters:

  • depth – Number of upsampling and downsampling layers.
  • dimensions – Number of channels in the first convolution layer.
  • input_channels – Number of input channels.
  • output_channels – Number of output channels.
  • bilinear – Uses bilinear upsampling in upsampling layers when set True.
  • bias – Set True to let convolutional layers learn a bias term.
  • activation – Non-linear activation layer to be used (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).

Source code in odak/learn/models/models.py
def __init__(
+             self, 
+             depth = 4,
+             dimensions = 64, 
+             input_channels = 2, 
+             output_channels = 1, 
+             bilinear = False,
+             kernel_size = 3,
+             bias = False,
+             activation = torch.nn.ReLU(inplace = True),
+            ):
+    """
+    U-Net model.
+
+    Parameters
+    ----------
+    depth             : int
+                        Number of upsampling and downsampling layers.
+    dimensions        : int
+                        Number of channels in the first convolution layer.
+    input_channels    : int
+                        Number of input channels.
+    output_channels   : int
+                        Number of output channels.
+    bilinear          : bool
+                        Uses bilinear upsampling in upsampling layers when set True.
+    bias              : bool
+                        Set True to let convolutional layers learn a bias term.
+    activation        : torch.nn
+                        Non-linear activation layer to be used (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
+    """
+    super(unet, self).__init__()
+    self.inc = double_convolution(
+                                  input_channels = input_channels,
+                                  mid_channels = dimensions,
+                                  output_channels = dimensions,
+                                  kernel_size = kernel_size,
+                                  bias = bias,
+                                  activation = activation
+                                 )      
+
+    self.downsampling_layers = torch.nn.ModuleList()
+    self.upsampling_layers = torch.nn.ModuleList()
+    for i in range(depth): # downsampling layers
+        in_channels = dimensions * (2 ** i)
+        out_channels = dimensions * (2 ** (i + 1))
+        down_layer = downsample_layer(in_channels,
+                                        out_channels,
+                                        kernel_size=kernel_size,
+                                        bias=bias,
+                                        activation=activation
+                                        )
+        self.downsampling_layers.append(down_layer)      
+
+    for i in range(depth - 1, -1, -1):  # upsampling layers
+        up_in_channels = dimensions * (2 ** (i + 1))  
+        up_out_channels = dimensions * (2 ** i) 
+        up_layer = upsample_layer(up_in_channels, up_out_channels, kernel_size=kernel_size, bias=bias, activation=activation, bilinear=bilinear)
+        self.upsampling_layers.append(up_layer)
+    self.outc = torch.nn.Conv2d(
+                                dimensions, 
+                                output_channels,
+                                kernel_size = kernel_size,
+                                padding = kernel_size // 2,
+                                bias = bias
+                               )
+
+
+
forward(x)

Forward model.

Parameters:

  • x – Input data.

Returns:

  • result (tensor) – Estimated output.

Source code in odak/learn/models/models.py
def forward(self, x):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x             : torch.tensor
+                    Input data.
+
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Estimated output.      
+    """
+    downsampling_outputs = [self.inc(x)]
+    for down_layer in self.downsampling_layers:
+        x_down = down_layer(downsampling_outputs[-1])
+        downsampling_outputs.append(x_down)
+    x_up = downsampling_outputs[-1]
+    for i, up_layer in enumerate((self.upsampling_layers)):
+        x_up = up_layer(x_up, downsampling_outputs[-(i + 2)])       
+    result = self.outc(x_up)
+    return result
+
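A minimal usage sketch for unet with its default arguments; the import path is assumed from the source location above, and the input size is only constrained to be divisible by 2 ** depth so the skip connections align.

import torch
from odak.learn.models import unet  # import path assumed

model = unet(depth = 4, dimensions = 64, input_channels = 2, output_channels = 1)
x = torch.randn(1, 2, 256, 256)  # (1, input_channels, H, W) with H and W divisible by 2 ** depth
y = model(x)
print(y.shape)  # torch.Size([1, 1, 256, 256])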
upsample_convtranspose2d_layer

Bases: Module

An upsampling convtranspose2d layer.

Source code in odak/learn/models/components.py
class upsample_convtranspose2d_layer(torch.nn.Module):
+    """
+    An upsampling convtranspose2d layer.
+    """
+    def __init__(
+                 self,
+                 input_channels,
+                 output_channels,
+                 kernel_size = 2,
+                 stride = 2,
+                 bias = False,
+                ):
+        """
+        An upsampling component using a transposed convolution.
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        output_channels : int
+                          Number of output channels.
+        kernel_size     : int
+                          Kernel size.
+        bias            : bool
+                          Set to True to let convolutional layers have bias term.
+        """
+        super().__init__()
+        self.up = torch.nn.ConvTranspose2d(
+                                           in_channels = input_channels,
+                                           out_channels = output_channels,
+                                           bias = bias,
+                                           kernel_size = kernel_size,
+                                           stride = stride
+                                          )
+
+    def forward(self, x1, x2):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x1             : torch.tensor
+                         First input data.
+        x2             : torch.tensor
+                         Second input data.
+
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Result of the forward operation
+        """
+        x1 = self.up(x1)
+        diffY = x2.size()[2] - x1.size()[2]
+        diffX = x2.size()[3] - x1.size()[3]
+        x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,
+                                          diffY // 2, diffY - diffY // 2])
+        result = x1 + x2
+        return result
+
+
__init__(input_channels, output_channels, kernel_size=2, stride=2, bias=False)

An upsampling component using a transposed convolution.

Parameters:

  • input_channels – Number of input channels.
  • output_channels (int) – Number of output channels.
  • kernel_size – Kernel size.
  • bias – Set to True to let convolutional layers have bias term.

Source code in odak/learn/models/components.py
def __init__(
+             self,
+             input_channels,
+             output_channels,
+             kernel_size = 2,
+             stride = 2,
+             bias = False,
+            ):
+    """
+    An upsampling component using a transposed convolution.
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    output_channels : int
+                      Number of output channels.
+    kernel_size     : int
+                      Kernel size.
+    bias            : bool
+                      Set to True to let convolutional layers have bias term.
+    """
+    super().__init__()
+    self.up = torch.nn.ConvTranspose2d(
+                                       in_channels = input_channels,
+                                       out_channels = output_channels,
+                                       bias = bias,
+                                       kernel_size = kernel_size,
+                                       stride = stride
+                                      )
+
+
forward(x1, x2)

Forward model.

Parameters:

  • x1 – First input data.
  • x2 – Second input data.

Returns:

  • result (tensor) – Result of the forward operation.

Source code in odak/learn/models/components.py
def forward(self, x1, x2):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x1             : torch.tensor
+                     First input data.
+    x2             : torch.tensor
+                     Second input data.
+
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Result of the forward operation
+    """
+    x1 = self.up(x1)
+    diffY = x2.size()[2] - x1.size()[2]
+    diffX = x2.size()[3] - x1.size()[3]
+    x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,
+                                      diffY // 2, diffY - diffY // 2])
+    result = x1 + x2
+    return result
+
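A small sketch of how forward combines its two inputs: x1 is upsampled by the transposed convolution, padded to x2's spatial size, and summed with x2, so output_channels must match the channel count of x2. Channel counts and sizes below are illustrative, and the import path is assumed from the source location above.

import torch
from odak.learn.models.components import upsample_convtranspose2d_layer  # import path assumed

up = upsample_convtranspose2d_layer(input_channels = 64, output_channels = 32, kernel_size = 2, stride = 2, bias = True)
x1 = torch.randn(1, 64, 16, 16)  # low-resolution features
x2 = torch.randn(1, 32, 32, 32)  # skip connection at the target resolution
y = up(x1, x2)
print(y.shape)  # torch.Size([1, 32, 32, 32])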
upsample_layer

Bases: Module

An upsampling convolutional layer.

Source code in odak/learn/models/components.py
class upsample_layer(torch.nn.Module):
+    """
+    An upsampling convolutional layer.
+    """
+    def __init__(
+                 self,
+                 input_channels,
+                 output_channels,
+                 kernel_size = 3,
+                 bias = False,
+                 activation = torch.nn.ReLU(),
+                 bilinear = True
+                ):
+        """
+        An upscaling component with a double convolution.
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        output_channels : int
+                          Number of output channels.
+        kernel_size     : int
+                          Kernel size.
+        bias            : bool 
+                          Set to True to let convolutional layers have bias term.
+        activation      : torch.nn
+                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+        bilinear        : bool
+                          If set to True, bilinear sampling is used.
+        """
+        super(upsample_layer, self).__init__()
+        if bilinear:
+            self.up = torch.nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = True)
+            self.conv = double_convolution(
+                                           input_channels = input_channels + output_channels,
+                                           mid_channels = input_channels // 2,
+                                           output_channels = output_channels,
+                                           kernel_size = kernel_size,
+                                           bias = bias,
+                                           activation = activation
+                                          )
+        else:
+            self.up = torch.nn.ConvTranspose2d(input_channels , input_channels // 2, kernel_size = 2, stride = 2)
+            self.conv = double_convolution(
+                                           input_channels = input_channels,
+                                           mid_channels = output_channels,
+                                           output_channels = output_channels,
+                                           kernel_size = kernel_size,
+                                           bias = bias,
+                                           activation = activation
+                                          )
+
+
+    def forward(self, x1, x2):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x1             : torch.tensor
+                         First input data.
+        x2             : torch.tensor
+                         Second input data.
+
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Result of the forward operation
+        """ 
+        x1 = self.up(x1)
+        diffY = x2.size()[2] - x1.size()[2]
+        diffX = x2.size()[3] - x1.size()[3]
+        x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,
+                                          diffY // 2, diffY - diffY // 2])
+        x = torch.cat([x2, x1], dim = 1)
+        result = self.conv(x)
+        return result
+
+
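A minimal usage sketch, assuming upsample_layer is importable from odak.learn.models.components as the source path above suggests; the channel counts are illustrative and chosen so the concatenated skip connection matches the double convolution:

import torch
from odak.learn.models.components import upsample_layer

up = upsample_layer(input_channels = 64, output_channels = 32)  # bilinear = True by default
x1 = torch.randn(1, 64, 16, 16)  # low-resolution decoder feature
x2 = torch.randn(1, 32, 32, 32)  # skip connection from the encoder
y = up(x1, x2)                   # x1 is upsampled, padded and concatenated with x2
print(y.shape)                   # torch.Size([1, 32, 32, 32])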
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(input_channels, output_channels, kernel_size=3, bias=False, activation=torch.nn.ReLU(), bilinear=True) + +

+ + +
+ +

An upsampling component with a double convolution.

+ + +

Parameters:

  • input_channels (int) – Number of input channels.
  • output_channels (int) – Number of output channels.
  • kernel_size (int, default: 3) – Kernel size.
  • bias (bool, default: False) – Set to True to let convolutional layers have bias term.
  • activation (torch.nn, default: torch.nn.ReLU()) – Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
  • bilinear (bool, default: True) – If set to True, bilinear sampling is used.

+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             input_channels,
+             output_channels,
+             kernel_size = 3,
+             bias = False,
+             activation = torch.nn.ReLU(),
+             bilinear = True
+            ):
+    """
+    An upsampling component with a double convolution.
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    output_channels : int
+                      Number of output channels.
+    kernel_size     : int
+                      Kernel size.
+    bias            : bool 
+                      Set to True to let convolutional layers have bias term.
+    activation      : torch.nn
+                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+    bilinear        : bool
+                      If set to True, bilinear sampling is used.
+    """
+    super(upsample_layer, self).__init__()
+    if bilinear:
+        self.up = torch.nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = True)
+        self.conv = double_convolution(
+                                       input_channels = input_channels + output_channels,
+                                       mid_channels = input_channels // 2,
+                                       output_channels = output_channels,
+                                       kernel_size = kernel_size,
+                                       bias = bias,
+                                       activation = activation
+                                      )
+    else:
+        self.up = torch.nn.ConvTranspose2d(input_channels , input_channels // 2, kernel_size = 2, stride = 2)
+        self.conv = double_convolution(
+                                       input_channels = input_channels,
+                                       mid_channels = output_channels,
+                                       output_channels = output_channels,
+                                       kernel_size = kernel_size,
+                                       bias = bias,
+                                       activation = activation
+                                      )
+
+
+
+ +
+ +
+ + +

+ forward(x1, x2) + +

+ + +
+ +

Forward model.

+ + +

Parameters:

  • x1 (torch.tensor) – First input data.
  • x2 (torch.tensor) – Second input data.

Returns:

  • result (torch.tensor) – Result of the forward operation.

+ Source code in odak/learn/models/components.py +
def forward(self, x1, x2):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x1             : torch.tensor
+                     First input data.
+    x2             : torch.tensor
+                     Second input data.
+
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Result of the forward operation
+    """ 
+    x1 = self.up(x1)
+    diffY = x2.size()[2] - x1.size()[2]
+    diffX = x2.size()[3] - x1.size()[3]
+    x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,
+                                      diffY // 2, diffY - diffY // 2])
+    x = torch.cat([x2, x1], dim = 1)
+    result = self.conv(x)
+    return result
+
+
+
+ +
+ + + +
+ +
+ +
+ + +
+ + +

+ gaussian(x, multiplier=1.0) + +

+ + +
+ +

A Gaussian non-linear activation. For more details: Ramasinghe, Sameera, and Simon Lucey. "Beyond periodicity: Towards a unifying framework for activations in coordinate-mlps." In European Conference on Computer Vision, pp. 142-158. Cham: Springer Nature Switzerland, 2022.

+ + +

Parameters:

  • x (float or torch.tensor) – Input data.
  • multiplier (float or torch.tensor, default: 1.0) – Multiplier.

Returns:

  • result (float or torch.tensor) – Output data.

+ Source code in odak/learn/models/components.py +
def gaussian(x, multiplier = 1.):
+    """
+    A Gaussian non-linear activation.
+    For more details: Ramasinghe, Sameera, and Simon Lucey. "Beyond periodicity: Towards a unifying framework for activations in coordinate-mlps." In European Conference on Computer Vision, pp. 142-158. Cham: Springer Nature Switzerland, 2022.
+
+    Parameters
+    ----------
+    x            : float or torch.tensor
+                   Input data.
+    multiplier   : float or torch.tensor
+                   Multiplier.
+
+    Returns
+    -------
+    result       : float or torch.tensor
+                   Output data.
+    """
+    result = torch.exp(- (multiplier * x) ** 2)
+    return result
+
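For example, assuming the function is importable from odak.learn.models.components:

import torch
from odak.learn.models.components import gaussian

x = torch.linspace(-3., 3., 7)
y = gaussian(x, multiplier = 2.)  # elementwise exp(-(2 * x) ** 2), equal to 1 at x = 0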
+
+
+ +
+ +
+ + +

+ swish(x) + +

+ + +
+ +

A swish non-linear activation. For more details: https://en.wikipedia.org/wiki/Swish_function

+ + +

Parameters:

  • x (float or torch.tensor) – Input.

Returns:

  • out (float or torch.tensor) – Output.

+ Source code in odak/learn/models/components.py +
def swish(x):
+    """
+    A swish non-linear activation.
+    For more details: https://en.wikipedia.org/wiki/Swish_function
+
+    Parameters
+    -----------
+    x              : float or torch.tensor
+                     Input.
+
+    Returns
+    -------
+    out            : float or torch.tensor
+                     Output.
+    """
+    out = x * torch.sigmoid(x)
+    return out
+
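For example, again assuming the odak.learn.models.components import path:

import torch
from odak.learn.models.components import swish

x = torch.tensor([-2., 0., 2.])
y = swish(x)  # x * sigmoid(x), approximately tensor([-0.2384, 0.0000, 1.7616])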
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ channel_gate + + +

+ + +
+

+ Bases: Module

+ + +

Channel attention module with various pooling strategies. This class is heavily inspired by https://github.com/Jongchan/attention-module/commit/e4ee180f1335c09db14d39a65d97c8ca3d1f7b16 (MIT License).

+ + + + + + +
+ Source code in odak/learn/models/components.py +
class channel_gate(torch.nn.Module):
+    """
+    Channel attention module with various pooling strategies.
+    This class is heavily inspired https://github.com/Jongchan/attention-module/commit/e4ee180f1335c09db14d39a65d97c8ca3d1f7b16 (MIT License).
+    """
+    def __init__(
+                 self, 
+                 gate_channels, 
+                 reduction_ratio = 16, 
+                 pool_types = ['avg', 'max']
+                ):
+        """
+        Initializes the channel gate module.
+
+        Parameters
+        ----------
+        gate_channels   : int
+                          Number of channels of the input feature map.
+        reduction_ratio : int
+                          Reduction ratio for the intermediate layer.
+        pool_types      : list
+                          List of pooling operations to apply.
+        """
+        super().__init__()
+        self.gate_channels = gate_channels
+        hidden_channels = gate_channels // reduction_ratio
+        if hidden_channels == 0:
+            hidden_channels = 1
+        self.mlp = torch.nn.Sequential(
+                                       convolutional_block_attention.Flatten(),
+                                       torch.nn.Linear(gate_channels, hidden_channels),
+                                       torch.nn.ReLU(),
+                                       torch.nn.Linear(hidden_channels, gate_channels)
+                                      )
+        self.pool_types = pool_types
+
+
+    def forward(self, x):
+        """
+        Forward pass of the ChannelGate module.
+
+        Applies channel-wise attention to the input tensor.
+
+        Parameters
+        ----------
+        x            : torch.tensor
+                       Input tensor to the ChannelGate module.
+
+        Returns
+        -------
+        output       : torch.tensor
+                       Output tensor after applying channel attention.
+        """
+        channel_att_sum = None
+        for pool_type in self.pool_types:
+            if pool_type == 'avg':
+                pool = torch.nn.functional.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
+            elif pool_type == 'max':
+                pool = torch.nn.functional.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
+            channel_att_raw = self.mlp(pool)
+            channel_att_sum = channel_att_raw if channel_att_sum is None else channel_att_sum + channel_att_raw
+        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
+        output = x * scale
+        return output
+
+
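A short usage sketch with illustrative sizes, assuming the class is importable from odak.learn.models.components:

import torch
from odak.learn.models.components import channel_gate

gate = channel_gate(gate_channels = 32, reduction_ratio = 16)
x = torch.randn(4, 32, 16, 16)
y = gate(x)     # per-channel attention weights in (0, 1) rescale x
print(y.shape)  # torch.Size([4, 32, 16, 16])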
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(gate_channels, reduction_ratio=16, pool_types=['avg', 'max']) + +

+ + +
+ +

Initializes the channel gate module.

+ + +

Parameters:

  • gate_channels (int) – Number of channels of the input feature map.
  • reduction_ratio (int, default: 16) – Reduction ratio for the intermediate layer.
  • pool_types (list, default: ['avg', 'max']) – List of pooling operations to apply.

+ Source code in odak/learn/models/components.py +
def __init__(
+             self, 
+             gate_channels, 
+             reduction_ratio = 16, 
+             pool_types = ['avg', 'max']
+            ):
+    """
+    Initializes the channel gate module.
+
+    Parameters
+    ----------
+    gate_channels   : int
+                      Number of channels of the input feature map.
+    reduction_ratio : int
+                      Reduction ratio for the intermediate layer.
+    pool_types      : list
+                      List of pooling operations to apply.
+    """
+    super().__init__()
+    self.gate_channels = gate_channels
+    hidden_channels = gate_channels // reduction_ratio
+    if hidden_channels == 0:
+        hidden_channels = 1
+    self.mlp = torch.nn.Sequential(
+                                   convolutional_block_attention.Flatten(),
+                                   torch.nn.Linear(gate_channels, hidden_channels),
+                                   torch.nn.ReLU(),
+                                   torch.nn.Linear(hidden_channels, gate_channels)
+                                  )
+    self.pool_types = pool_types
+
+
+
+ +
+ +
+ + +

+ forward(x) + +

+ + +
+ +

Forward pass of the ChannelGate module.

+

Applies channel-wise attention to the input tensor.

+ + +

Parameters:

  • x (torch.tensor) – Input tensor to the ChannelGate module.

Returns:

  • output (torch.tensor) – Output tensor after applying channel attention.

+ Source code in odak/learn/models/components.py +
def forward(self, x):
+    """
+    Forward pass of the ChannelGate module.
+
+    Applies channel-wise attention to the input tensor.
+
+    Parameters
+    ----------
+    x            : torch.tensor
+                   Input tensor to the ChannelGate module.
+
+    Returns
+    -------
+    output       : torch.tensor
+                   Output tensor after applying channel attention.
+    """
+    channel_att_sum = None
+    for pool_type in self.pool_types:
+        if pool_type == 'avg':
+            pool = torch.nn.functional.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
+        elif pool_type == 'max':
+            pool = torch.nn.functional.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
+        channel_att_raw = self.mlp(pool)
+        channel_att_sum = channel_att_raw if channel_att_sum is None else channel_att_sum + channel_att_raw
+    scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
+    output = x * scale
+    return output
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ convolution_layer + + +

+ + +
+

+ Bases: Module

+ + +

A convolution layer.

+ + + + + + +
+ Source code in odak/learn/models/components.py +
class convolution_layer(torch.nn.Module):
+    """
+    A convolution layer.
+    """
+    def __init__(
+                 self,
+                 input_channels = 2,
+                 output_channels = 2,
+                 kernel_size = 3,
+                 bias = False,
+                 stride = 1,
+                 normalization = True,
+                 activation = torch.nn.ReLU()
+                ):
+        """
+        A convolutional layer class.
+
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        output_channels : int
+                          Number of output channels.
+        kernel_size     : int
+                          Kernel size.
+        bias            : bool
+                          Set to True to let convolutional layers have bias term.
+        normalization   : bool
+                          If True, adds a Batch Normalization layer after the convolutional layer.
+        activation      : torch.nn
+                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+        """
+        super().__init__()
+        layers = [
+            torch.nn.Conv2d(
+                            input_channels,
+                            output_channels,
+                            kernel_size = kernel_size,
+                            stride = stride,
+                            padding = kernel_size // 2,
+                            bias = bias
+                           )
+        ]
+        if normalization:
+            layers.append(torch.nn.BatchNorm2d(output_channels))
+        if activation:
+            layers.append(activation)
+        self.model = torch.nn.Sequential(*layers)
+
+
+    def forward(self, x):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x             : torch.tensor
+                        Input data.
+
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Estimated output.
+        """
+        result = self.model(x)
+        return result
+
+
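A brief usage sketch with illustrative channel counts, assuming the odak.learn.models.components import path:

import torch
from odak.learn.models.components import convolution_layer

layer = convolution_layer(input_channels = 3, output_channels = 8, kernel_size = 3)
x = torch.randn(2, 3, 64, 64)
y = layer(x)    # Conv2d + BatchNorm2d + ReLU; padding of kernel_size // 2 keeps the spatial size
print(y.shape)  # torch.Size([2, 8, 64, 64])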
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(input_channels=2, output_channels=2, kernel_size=3, bias=False, stride=1, normalization=True, activation=torch.nn.ReLU()) + +

+ + +
+ +

A convolutional layer class.

+ + +

Parameters:

  • input_channels (int, default: 2) – Number of input channels.
  • output_channels (int, default: 2) – Number of output channels.
  • kernel_size (int, default: 3) – Kernel size.
  • bias (bool, default: False) – Set to True to let convolutional layers have bias term.
  • stride (int, default: 1) – Stride of the convolution.
  • normalization (bool, default: True) – If True, adds a Batch Normalization layer after the convolutional layer.
  • activation (torch.nn, default: torch.nn.ReLU()) – Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().

+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             input_channels = 2,
+             output_channels = 2,
+             kernel_size = 3,
+             bias = False,
+             stride = 1,
+             normalization = True,
+             activation = torch.nn.ReLU()
+            ):
+    """
+    A convolutional layer class.
+
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    output_channels : int
+                      Number of output channels.
+    kernel_size     : int
+                      Kernel size.
+    bias            : bool
+                      Set to True to let convolutional layers have bias term.
+    normalization   : bool
+                      If True, adds a Batch Normalization layer after the convolutional layer.
+    activation      : torch.nn
+                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+    """
+    super().__init__()
+    layers = [
+        torch.nn.Conv2d(
+                        input_channels,
+                        output_channels,
+                        kernel_size = kernel_size,
+                        stride = stride,
+                        padding = kernel_size // 2,
+                        bias = bias
+                       )
+    ]
+    if normalization:
+        layers.append(torch.nn.BatchNorm2d(output_channels))
+    if activation:
+        layers.append(activation)
+    self.model = torch.nn.Sequential(*layers)
+
+
+
+ +
+ +
+ + +

+ forward(x) + +

+ + +
+ +

Forward model.

+ + +

Parameters:

  • x (torch.tensor) – Input data.

Returns:

  • result (torch.tensor) – Estimated output.

+ Source code in odak/learn/models/components.py +
def forward(self, x):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x             : torch.tensor
+                    Input data.
+
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Estimated output.
+    """
+    result = self.model(x)
+    return result
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ convolutional_block_attention + + +

+ + +
+

+ Bases: Module

+ + +

Convolutional Block Attention Module (CBAM) class. This class is heavily inspired by https://github.com/Jongchan/attention-module/commit/e4ee180f1335c09db14d39a65d97c8ca3d1f7b16 (MIT License).

+ + + + + + +
+ Source code in odak/learn/models/components.py +
class convolutional_block_attention(torch.nn.Module):
+    """
+    Convolutional Block Attention Module (CBAM) class. 
+    This class is heavily inspired https://github.com/Jongchan/attention-module/commit/e4ee180f1335c09db14d39a65d97c8ca3d1f7b16 (MIT License).
+    """
+    def __init__(
+                 self, 
+                 gate_channels, 
+                 reduction_ratio = 16, 
+                 pool_types = ['avg', 'max'], 
+                 no_spatial = False
+                ):
+        """
+        Initializes the convolutional block attention module.
+
+        Parameters
+        ----------
+        gate_channels   : int
+                          Number of channels of the input feature map.
+        reduction_ratio : int
+                          Reduction ratio for the channel attention.
+        pool_types      : list
+                          List of pooling operations to apply for channel attention.
+        no_spatial      : bool
+                          If True, spatial attention is not applied.
+        """
+        super(convolutional_block_attention, self).__init__()
+        self.channel_gate = channel_gate(gate_channels, reduction_ratio, pool_types)
+        self.no_spatial = no_spatial
+        if not no_spatial:
+            self.spatial_gate = spatial_gate()
+
+
+    class Flatten(torch.nn.Module):
+        """
+        Flattens the input tensor to a 2D matrix.
+        """
+        def forward(self, x):
+            return x.view(x.size(0), -1)
+
+
+    def forward(self, x):
+        """
+        Forward pass of the convolutional block attention module.
+
+        Parameters
+        ----------
+        x            : torch.tensor
+                       Input tensor to the CBAM module.
+
+        Returns
+        -------
+        x_out        : torch.tensor
+                       Output tensor after applying channel and spatial attention.
+        """
+        x_out = self.channel_gate(x)
+        if not self.no_spatial:
+            x_out = self.spatial_gate(x_out)
+        return x_out
+
+
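A usage sketch with illustrative sizes, assuming the class is importable from odak.learn.models.components and that spatial_gate (defined elsewhere in the same file) preserves the input shape as in standard CBAM:

import torch
from odak.learn.models.components import convolutional_block_attention

cbam = convolutional_block_attention(gate_channels = 64, reduction_ratio = 16)
x = torch.randn(2, 64, 32, 32)
y = cbam(x)     # channel attention followed by spatial attention
print(y.shape)  # torch.Size([2, 64, 32, 32])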
+ + + +
+ + + + + + + + +
+ + + +

+ Flatten + + +

+ + +
+

+ Bases: Module

+ + +

Flattens the input tensor to a 2D matrix.

+ + + + + + +
+ Source code in odak/learn/models/components.py +
class Flatten(torch.nn.Module):
+    """
+    Flattens the input tensor to a 2D matrix.
+    """
+    def forward(self, x):
+        return x.view(x.size(0), -1)
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ + +
+ + +

+ __init__(gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False) + +

+ + +
+ +

Initializes the convolutional block attention module.

+ + +

Parameters:

  • gate_channels (int) – Number of channels of the input feature map.
  • reduction_ratio (int, default: 16) – Reduction ratio for the channel attention.
  • pool_types (list, default: ['avg', 'max']) – List of pooling operations to apply for channel attention.
  • no_spatial (bool, default: False) – If True, spatial attention is not applied.

+ Source code in odak/learn/models/components.py +
def __init__(
+             self, 
+             gate_channels, 
+             reduction_ratio = 16, 
+             pool_types = ['avg', 'max'], 
+             no_spatial = False
+            ):
+    """
+    Initializes the convolutional block attention module.
+
+    Parameters
+    ----------
+    gate_channels   : int
+                      Number of channels of the input feature map.
+    reduction_ratio : int
+                      Reduction ratio for the channel attention.
+    pool_types      : list
+                      List of pooling operations to apply for channel attention.
+    no_spatial      : bool
+                      If True, spatial attention is not applied.
+    """
+    super(convolutional_block_attention, self).__init__()
+    self.channel_gate = channel_gate(gate_channels, reduction_ratio, pool_types)
+    self.no_spatial = no_spatial
+    if not no_spatial:
+        self.spatial_gate = spatial_gate()
+
+
+
+ +
+ +
+ + +

+ forward(x) + +

+ + +
+ +

Forward pass of the convolutional block attention module.

+ + +

Parameters:

  • x (torch.tensor) – Input tensor to the CBAM module.

Returns:

  • x_out (torch.tensor) – Output tensor after applying channel and spatial attention.

+ Source code in odak/learn/models/components.py +
def forward(self, x):
+    """
+    Forward pass of the convolutional block attention module.
+
+    Parameters
+    ----------
+    x            : torch.tensor
+                   Input tensor to the CBAM module.
+
+    Returns
+    -------
+    x_out        : torch.tensor
+                   Output tensor after applying channel and spatial attention.
+    """
+    x_out = self.channel_gate(x)
+    if not self.no_spatial:
+        x_out = self.spatial_gate(x_out)
+    return x_out
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ double_convolution + + +

+ + +
+

+ Bases: Module

+ + +

A double convolution layer.

+ + + + + + +
+ Source code in odak/learn/models/components.py +
class double_convolution(torch.nn.Module):
+    """
+    A double convolution layer.
+    """
+    def __init__(
+                 self,
+                 input_channels = 2,
+                 mid_channels = None,
+                 output_channels = 2,
+                 kernel_size = 3, 
+                 bias = False,
+                 normalization = True,
+                 activation = torch.nn.ReLU()
+                ):
+        """
+        Double convolution model.
+
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        mid_channels    : int
+                          Number of channels in the hidden layer between two convolutions.
+        output_channels : int
+                          Number of output channels.
+        kernel_size     : int
+                          Kernel size.
+        bias            : bool 
+                          Set to True to let convolutional layers have bias term.
+        normalization   : bool
+                          If True, adds a Batch Normalization layer after the convolutional layer.
+        activation      : torch.nn
+                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+        """
+        super().__init__()
+        if isinstance(mid_channels, type(None)):
+            mid_channels = output_channels
+        self.activation = activation
+        self.model = torch.nn.Sequential(
+                                         convolution_layer(
+                                                           input_channels = input_channels,
+                                                           output_channels = mid_channels,
+                                                           kernel_size = kernel_size,
+                                                           bias = bias,
+                                                           normalization = normalization,
+                                                           activation = self.activation
+                                                          ),
+                                         convolution_layer(
+                                                           input_channels = mid_channels,
+                                                           output_channels = output_channels,
+                                                           kernel_size = kernel_size,
+                                                           bias = bias,
+                                                           normalization = normalization,
+                                                           activation = self.activation
+                                                          )
+                                        )
+
+
+    def forward(self, x):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x             : torch.tensor
+                        Input data.
+
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Estimated output.      
+        """
+        result = self.model(x)
+        return result
+
+
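A usage sketch with illustrative channel counts, assuming the odak.learn.models.components import path:

import torch
from odak.learn.models.components import double_convolution

block = double_convolution(input_channels = 3, mid_channels = 16, output_channels = 8)
x = torch.randn(1, 3, 64, 64)
y = block(x)    # two convolution_layer blocks applied back to back
print(y.shape)  # torch.Size([1, 8, 64, 64])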
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(input_channels=2, mid_channels=None, output_channels=2, kernel_size=3, bias=False, normalization=True, activation=torch.nn.ReLU()) + +

+ + +
+ +

Double convolution model.

+ + +

Parameters:

  • input_channels (int, default: 2) – Number of input channels.
  • mid_channels (int, default: None) – Number of channels in the hidden layer between two convolutions.
  • output_channels (int, default: 2) – Number of output channels.
  • kernel_size (int, default: 3) – Kernel size.
  • bias (bool, default: False) – Set to True to let convolutional layers have bias term.
  • normalization (bool, default: True) – If True, adds a Batch Normalization layer after the convolutional layer.
  • activation (torch.nn, default: torch.nn.ReLU()) – Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().

+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             input_channels = 2,
+             mid_channels = None,
+             output_channels = 2,
+             kernel_size = 3, 
+             bias = False,
+             normalization = True,
+             activation = torch.nn.ReLU()
+            ):
+    """
+    Double convolution model.
+
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    mid_channels    : int
+                      Number of channels in the hidden layer between two convolutions.
+    output_channels : int
+                      Number of output channels.
+    kernel_size     : int
+                      Kernel size.
+    bias            : bool 
+                      Set to True to let convolutional layers have bias term.
+    normalization   : bool
+                      If True, adds a Batch Normalization layer after the convolutional layer.
+    activation      : torch.nn
+                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+    """
+    super().__init__()
+    if isinstance(mid_channels, type(None)):
+        mid_channels = output_channels
+    self.activation = activation
+    self.model = torch.nn.Sequential(
+                                     convolution_layer(
+                                                       input_channels = input_channels,
+                                                       output_channels = mid_channels,
+                                                       kernel_size = kernel_size,
+                                                       bias = bias,
+                                                       normalization = normalization,
+                                                       activation = self.activation
+                                                      ),
+                                     convolution_layer(
+                                                       input_channels = mid_channels,
+                                                       output_channels = output_channels,
+                                                       kernel_size = kernel_size,
+                                                       bias = bias,
+                                                       normalization = normalization,
+                                                       activation = self.activation
+                                                      )
+                                    )
+
+
+
+ +
+ +
+ + +

+ forward(x) + +

+ + +
+ +

Forward model.

+ + +

Parameters:

  • x (torch.tensor) – Input data.

Returns:

  • result (torch.tensor) – Estimated output.

+ Source code in odak/learn/models/components.py +
def forward(self, x):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x             : torch.tensor
+                    Input data.
+
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Estimated output.      
+    """
+    result = self.model(x)
+    return result
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ downsample_layer + + +

+ + +
+

+ Bases: Module

+ + +

A downscaling component followed by a double convolution.

+ + + + + + +
+ Source code in odak/learn/models/components.py +
class downsample_layer(torch.nn.Module):
+    """
+    A downscaling component followed by a double convolution.
+    """
+    def __init__(
+                 self,
+                 input_channels,
+                 output_channels,
+                 kernel_size = 3,
+                 bias = False,
+                 activation = torch.nn.ReLU()
+                ):
+        """
+        A downscaling component with a double convolution.
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        output_channels : int
+                          Number of output channels.
+        kernel_size     : int
+                          Kernel size.
+        bias            : bool 
+                          Set to True to let convolutional layers have bias term.
+        activation      : torch.nn
+                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+        """
+        super().__init__()
+        self.maxpool_conv = torch.nn.Sequential(
+                                                torch.nn.MaxPool2d(2),
+                                                double_convolution(
+                                                                   input_channels = input_channels,
+                                                                   mid_channels = output_channels,
+                                                                   output_channels = output_channels,
+                                                                   kernel_size = kernel_size,
+                                                                   bias = bias,
+                                                                   activation = activation
+                                                                  )
+                                               )
+
+
+    def forward(self, x):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x              : torch.tensor
+                         First input data.
+
+
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Estimated output.      
+        """
+        result = self.maxpool_conv(x)
+        return result
+
+
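A usage sketch with illustrative sizes, assuming the odak.learn.models.components import path:

import torch
from odak.learn.models.components import downsample_layer

down = downsample_layer(input_channels = 32, output_channels = 64)
x = torch.randn(1, 32, 64, 64)
y = down(x)     # 2x2 max pooling halves the resolution, then a double convolution follows
print(y.shape)  # torch.Size([1, 64, 32, 32])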
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(input_channels, output_channels, kernel_size=3, bias=False, activation=torch.nn.ReLU()) + +

+ + +
+ +

A downscaling component with a double convolution.

+ + +

Parameters:

  • input_channels (int) – Number of input channels.
  • output_channels (int) – Number of output channels.
  • kernel_size (int, default: 3) – Kernel size.
  • bias (bool, default: False) – Set to True to let convolutional layers have bias term.
  • activation (torch.nn, default: torch.nn.ReLU()) – Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().

+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             input_channels,
+             output_channels,
+             kernel_size = 3,
+             bias = False,
+             activation = torch.nn.ReLU()
+            ):
+    """
+    A downscaling component with a double convolution.
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    output_channels : int
+                      Number of output channels.
+    kernel_size     : int
+                      Kernel size.
+    bias            : bool 
+                      Set to True to let convolutional layers have bias term.
+    activation      : torch.nn
+                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+    """
+    super().__init__()
+    self.maxpool_conv = torch.nn.Sequential(
+                                            torch.nn.MaxPool2d(2),
+                                            double_convolution(
+                                                               input_channels = input_channels,
+                                                               mid_channels = output_channels,
+                                                               output_channels = output_channels,
+                                                               kernel_size = kernel_size,
+                                                               bias = bias,
+                                                               activation = activation
+                                                              )
+                                           )
+
+
+
+ +
+ +
+ + +

+ forward(x) + +

+ + +
+ +

Forward model.

+ + +

Parameters:

  • x (torch.tensor) – First input data.

Returns:

  • result (torch.tensor) – Estimated output.

+ Source code in odak/learn/models/components.py +
def forward(self, x):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x              : torch.tensor
+                     First input data.
+
+
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Estimated output.      
+    """
+    result = self.maxpool_conv(x)
+    return result
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ global_feature_module + + +

+ + +
+

+ Bases: Module

+ + +

A global feature layer that processes global features from input channels and applies them to another input tensor via learned transformations.

+ + + + + + +
+ Source code in odak/learn/models/components.py +
class global_feature_module(torch.nn.Module):
+    """
+    A global feature layer that processes global features from input channels and
+    applies them to another input tensor via learned transformations.
+    """
+    def __init__(
+                 self,
+                 input_channels,
+                 mid_channels,
+                 output_channels,
+                 kernel_size,
+                 bias = False,
+                 normalization = False,
+                 activation = torch.nn.ReLU()
+                ):
+        """
+        A global feature layer.
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        mid_channels  : int
+                          Number of mid channels.
+        output_channels : int
+                          Number of output channels.
+        kernel_size     : int
+                          Kernel size.
+        bias            : bool
+                          Set to True to let convolutional layers have bias term.
+        normalization   : bool
+                          If True, adds a Batch Normalization layer after the convolutional layer.
+        activation      : torch.nn
+                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+        """
+        super().__init__()
+        self.transformations_1 = global_transformations(input_channels, output_channels)
+        self.global_features_1 = double_convolution(
+                                                    input_channels = input_channels,
+                                                    mid_channels = mid_channels,
+                                                    output_channels = output_channels,
+                                                    kernel_size = kernel_size,
+                                                    bias = bias,
+                                                    normalization = normalization,
+                                                    activation = activation
+                                                   )
+        self.global_features_2 = double_convolution(
+                                                    input_channels = input_channels,
+                                                    mid_channels = mid_channels,
+                                                    output_channels = output_channels,
+                                                    kernel_size = kernel_size,
+                                                    bias = bias,
+                                                    normalization = normalization,
+                                                    activation = activation
+                                                   )
+        self.transformations_2 = global_transformations(input_channels, output_channels)
+
+
+    def forward(self, x1, x2):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x1             : torch.tensor
+                         First input data.
+        x2             : torch.tensor
+                         Second input data.
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Estimated output.
+        """
+        global_tensor_1 = self.transformations_1(x1, x2)
+        y1 = self.global_features_1(global_tensor_1)
+        y2 = self.global_features_2(y1)
+        global_tensor_2 = self.transformations_2(y1, y2)
+        return global_tensor_2
+
+
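A usage sketch, assuming the odak.learn.models.components import path; input_channels and output_channels are set equal here so that the intermediate tensors match what the internal double convolutions expect:

import torch
from odak.learn.models.components import global_feature_module

gfm = global_feature_module(
                            input_channels = 32,
                            mid_channels = 32,
                            output_channels = 32,
                            kernel_size = 3
                           )
x1 = torch.randn(1, 32, 32, 32)  # tensor to be modulated
x2 = torch.randn(1, 32, 32, 32)  # tensor providing the global statistics
y = gfm(x1, x2)
print(y.shape)                   # torch.Size([1, 32, 32, 32])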
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(input_channels, mid_channels, output_channels, kernel_size, bias=False, normalization=False, activation=torch.nn.ReLU()) + +

+ + +
+ +

A global feature layer.

+ + +

Parameters:

  • input_channels (int) – Number of input channels.
  • mid_channels (int) – Number of mid channels.
  • output_channels (int) – Number of output channels.
  • kernel_size (int) – Kernel size.
  • bias (bool, default: False) – Set to True to let convolutional layers have bias term.
  • normalization (bool, default: False) – If True, adds a Batch Normalization layer after the convolutional layer.
  • activation (torch.nn, default: torch.nn.ReLU()) – Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().

+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             input_channels,
+             mid_channels,
+             output_channels,
+             kernel_size,
+             bias = False,
+             normalization = False,
+             activation = torch.nn.ReLU()
+            ):
+    """
+    A global feature layer.
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    mid_channels  : int
+                      Number of mid channels.
+    output_channels : int
+                      Number of output channels.
+    kernel_size     : int
+                      Kernel size.
+    bias            : bool
+                      Set to True to let convolutional layers have bias term.
+    normalization   : bool
+                      If True, adds a Batch Normalization layer after the convolutional layer.
+    activation      : torch.nn
+                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+    """
+    super().__init__()
+    self.transformations_1 = global_transformations(input_channels, output_channels)
+    self.global_features_1 = double_convolution(
+                                                input_channels = input_channels,
+                                                mid_channels = mid_channels,
+                                                output_channels = output_channels,
+                                                kernel_size = kernel_size,
+                                                bias = bias,
+                                                normalization = normalization,
+                                                activation = activation
+                                               )
+    self.global_features_2 = double_convolution(
+                                                input_channels = input_channels,
+                                                mid_channels = mid_channels,
+                                                output_channels = output_channels,
+                                                kernel_size = kernel_size,
+                                                bias = bias,
+                                                normalization = normalization,
+                                                activation = activation
+                                               )
+    self.transformations_2 = global_transformations(input_channels, output_channels)
+
+
+
+ +
+ +
+ + +

+ forward(x1, x2) + +

+ + +
+ +

Forward model.

+ + +

Parameters:

  • x1 (torch.tensor) – First input data.
  • x2 (torch.tensor) – Second input data.

Returns:

  • result (torch.tensor) – Estimated output.

+ Source code in odak/learn/models/components.py +
def forward(self, x1, x2):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x1             : torch.tensor
+                     First input data.
+    x2             : torch.tensor
+                     Second input data.
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Estimated output.
+    """
+    global_tensor_1 = self.transformations_1(x1, x2)
+    y1 = self.global_features_1(global_tensor_1)
+    y2 = self.global_features_2(y1)
+    global_tensor_2 = self.transformations_2(y1, y2)
+    return global_tensor_2
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ global_transformations + + +

+ + +
+

+ Bases: Module

+ + +

A global feature layer that processes global features from input channels and applies learned transformations to another input tensor.

This implementation is adapted from RSGUnet: https://github.com/MTLab/rsgunet_image_enhance.

Reference: J. Huang, P. Zhu, M. Geng et al. "Range Scaling Global U-Net for Perceptual Image Enhancement on Mobile Devices."

+ + + + + + +
+ Source code in odak/learn/models/components.py +
class global_transformations(torch.nn.Module):
+    """
+    A global feature layer that processes global features from input channels and
+    applies learned transformations to another input tensor.
+
+    This implementation is adapted from RSGUnet:
+    https://github.com/MTLab/rsgunet_image_enhance.
+
+    Reference:
+    J. Huang, P. Zhu, M. Geng et al. "Range Scaling Global U-Net for Perceptual Image Enhancement on Mobile Devices."
+    """
+    def __init__(
+                 self,
+                 input_channels,
+                 output_channels
+                ):
+        """
+        A global feature layer.
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        output_channels : int
+                          Number of output channels.
+        """
+        super().__init__()
+        self.global_feature_1 = torch.nn.Sequential(
+            torch.nn.Linear(input_channels, output_channels),
+            torch.nn.LeakyReLU(0.2, inplace = True),
+        )
+        self.global_feature_2 = torch.nn.Sequential(
+            torch.nn.Linear(output_channels, output_channels),
+            torch.nn.LeakyReLU(0.2, inplace = True)
+        )
+
+
+    def forward(self, x1, x2):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x1             : torch.tensor
+                         First input data.
+        x2             : torch.tensor
+                         Second input data.
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Estimated output.
+        """
+        y = torch.mean(x2, dim = (2, 3))
+        y1 = self.global_feature_1(y)
+        y2 = self.global_feature_2(y1)
+        y1 = y1.unsqueeze(2).unsqueeze(3)
+        y2 = y2.unsqueeze(2).unsqueeze(3)
+        result = x1 * y1 + y2
+        return result
+
+
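A usage sketch with illustrative sizes, assuming the odak.learn.models.components import path; x2 supplies the global statistics and x1 must carry output_channels channels:

import torch
from odak.learn.models.components import global_transformations

gt = global_transformations(input_channels = 64, output_channels = 32)
x1 = torch.randn(1, 32, 32, 32)  # scaled and shifted per channel
x2 = torch.randn(1, 64, 32, 32)  # reduced to a global feature by spatial averaging
y = gt(x1, x2)
print(y.shape)                   # torch.Size([1, 32, 32, 32])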
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(input_channels, output_channels) + +

+ + +
+ +

A global feature layer.

+ + +

Parameters:

  • input_channels (int) – Number of input channels.
  • output_channels (int) – Number of output channels.

+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             input_channels,
+             output_channels
+            ):
+    """
+    A global feature layer.
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    output_channels : int
+                      Number of output channels.
+    """
+    super().__init__()
+    self.global_feature_1 = torch.nn.Sequential(
+        torch.nn.Linear(input_channels, output_channels),
+        torch.nn.LeakyReLU(0.2, inplace = True),
+    )
+    self.global_feature_2 = torch.nn.Sequential(
+        torch.nn.Linear(output_channels, output_channels),
+        torch.nn.LeakyReLU(0.2, inplace = True)
+    )
+
+
+
+ +
+ +
+ + +

+ forward(x1, x2) + +

+ + +
+ +

Forward model.

+ + +

Parameters:

  • x1 (torch.tensor) – First input data.
  • x2 (torch.tensor) – Second input data.

Returns:

  • result (torch.tensor) – Estimated output.

+ Source code in odak/learn/models/components.py +
def forward(self, x1, x2):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x1             : torch.tensor
+                     First input data.
+    x2             : torch.tensor
+                     Second input data.
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Estimated output.
+    """
+    y = torch.mean(x2, dim = (2, 3))
+    y1 = self.global_feature_1(y)
+    y2 = self.global_feature_2(y1)
+    y1 = y1.unsqueeze(2).unsqueeze(3)
+    y2 = y2.unsqueeze(2).unsqueeze(3)
+    result = x1 * y1 + y2
+    return result
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ non_local_layer + + +

+ + +
+

+ Bases: Module

+ + +

Self-Attention Layer [zi = Wzyi + xi] (non-local block : ref https://arxiv.org/abs/1711.07971)

+ + + + + + +
+ Source code in odak/learn/models/components.py +
class non_local_layer(torch.nn.Module):
+    """
+    Self-Attention Layer [zi = Wzyi + xi] (non-local block : ref https://arxiv.org/abs/1711.07971)
+    """
+    def __init__(
+                 self,
+                 input_channels = 1024,
+                 bottleneck_channels = 512,
+                 kernel_size = 1,
+                 bias = False,
+                ):
+        """
+
+        Parameters
+        ----------
+        input_channels      : int
+                              Number of input channels.
+        bottleneck_channels : int
+                              Number of middle channels.
+        kernel_size         : int
+                              Kernel size.
+        bias                : bool 
+                              Set to True to let convolutional layers have bias term.
+        """
+        super(non_local_layer, self).__init__()
+        self.input_channels = input_channels
+        self.bottleneck_channels = bottleneck_channels
+        self.g = torch.nn.Conv2d(
+                                 self.input_channels, 
+                                 self.bottleneck_channels,
+                                 kernel_size = kernel_size,
+                                 padding = kernel_size // 2,
+                                 bias = bias
+                                )
+        self.W_z = torch.nn.Sequential(
+                                       torch.nn.Conv2d(
+                                                       self.bottleneck_channels,
+                                                       self.input_channels, 
+                                                       kernel_size = kernel_size,
+                                                       bias = bias,
+                                                       padding = kernel_size // 2
+                                                      ),
+                                       torch.nn.BatchNorm2d(self.input_channels)
+                                      )
+        torch.nn.init.constant_(self.W_z[1].weight, 0)   
+        torch.nn.init.constant_(self.W_z[1].bias, 0)
+
+
+    def forward(self, x):
+        """
+        Forward model [zi = Wzyi + xi]
+
+        Parameters
+        ----------
+        x               : torch.tensor
+                          First input data.                       
+
+
+        Returns
+        ----------
+        z               : torch.tensor
+                          Estimated output.
+        """
+        batch_size, channels, height, width = x.size()
+        theta = x.view(batch_size, channels, -1).permute(0, 2, 1)
+        phi = x.view(batch_size, channels, -1).permute(0, 2, 1)
+        g = self.g(x).view(batch_size, self.bottleneck_channels, -1).permute(0, 2, 1)
+        attn = torch.bmm(theta, phi.transpose(1, 2)) / (height * width)
+        attn = torch.nn.functional.softmax(attn, dim=-1)
+        y = torch.bmm(attn, g).permute(0, 2, 1).contiguous().view(batch_size, self.bottleneck_channels, height, width)
+        W_y = self.W_z(y)
+        z = W_y + x
+        return z
+
+
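A usage sketch with illustrative sizes, assuming the odak.learn.models.components import path; note the attention matrix is (H·W) x (H·W), so modest spatial sizes keep memory in check:

import torch
from odak.learn.models.components import non_local_layer

attention = non_local_layer(input_channels = 64, bottleneck_channels = 32)
x = torch.randn(1, 64, 16, 16)
z = attention(x)  # residual connection: z = W_z(y) + x
print(z.shape)    # torch.Size([1, 64, 16, 16])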
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(input_channels=1024, bottleneck_channels=512, kernel_size=1, bias=False) + +

+ + +
+ + + +

Parameters:

  • input_channels (int, default: 1024) – Number of input channels.
  • bottleneck_channels (int, default: 512) – Number of middle channels.
  • kernel_size (int, default: 1) – Kernel size.
  • bias (bool, default: False) – Set to True to let convolutional layers have bias term.

+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             input_channels = 1024,
+             bottleneck_channels = 512,
+             kernel_size = 1,
+             bias = False,
+            ):
+    """
+
+    Parameters
+    ----------
+    input_channels      : int
+                          Number of input channels.
+    bottleneck_channels : int
+                          Number of middle channels.
+    kernel_size         : int
+                          Kernel size.
+    bias                : bool 
+                          Set to True to let convolutional layers have bias term.
+    """
+    super(non_local_layer, self).__init__()
+    self.input_channels = input_channels
+    self.bottleneck_channels = bottleneck_channels
+    self.g = torch.nn.Conv2d(
+                             self.input_channels, 
+                             self.bottleneck_channels,
+                             kernel_size = kernel_size,
+                             padding = kernel_size // 2,
+                             bias = bias
+                            )
+    self.W_z = torch.nn.Sequential(
+                                   torch.nn.Conv2d(
+                                                   self.bottleneck_channels,
+                                                   self.input_channels, 
+                                                   kernel_size = kernel_size,
+                                                   bias = bias,
+                                                   padding = kernel_size // 2
+                                                  ),
+                                   torch.nn.BatchNorm2d(self.input_channels)
+                                  )
+    torch.nn.init.constant_(self.W_z[1].weight, 0)   
+    torch.nn.init.constant_(self.W_z[1].bias, 0)
+
+
+
forward(x)

Forward model [z_i = W_z * y_i + x_i].

Parameters:

  • x (torch.tensor) – First input data.

Returns:

  • z (torch.tensor) – Estimated output.
+ Source code in odak/learn/models/components.py +
def forward(self, x):
+    """
+    Forward model [z_i = W_z * y_i + x_i]
+
+    Parameters
+    ----------
+    x               : torch.tensor
+                      First input data.                       
+
+
+    Returns
+    ----------
+    z               : torch.tensor
+                      Estimated output.
+    """
+    batch_size, channels, height, width = x.size()
+    theta = x.view(batch_size, channels, -1).permute(0, 2, 1)
+    phi = x.view(batch_size, channels, -1).permute(0, 2, 1)
+    g = self.g(x).view(batch_size, self.bottleneck_channels, -1).permute(0, 2, 1)
+    attn = torch.bmm(theta, phi.transpose(1, 2)) / (height * width)
+    attn = torch.nn.functional.softmax(attn, dim=-1)
+    y = torch.bmm(attn, g).permute(0, 2, 1).contiguous().view(batch_size, self.bottleneck_channels, height, width)
+    W_y = self.W_z(y)
+    z = W_y + x
+    return z
+
+
+
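A minimal usage sketch for this non-local (self-attention) block, assuming non_local_layer is importable from odak.learn.models.components as the source path above indicates; tensor shapes are illustrative.

import torch
from odak.learn.models.components import non_local_layer  # import path assumed from the source path above

layer = non_local_layer(input_channels = 1024, bottleneck_channels = 512)
x = torch.randn(2, 1024, 16, 16)   # (batch, channels, height, width)
z = layer(x)                       # attention-weighted features added back onto the input
print(z.shape)                     # torch.Size([2, 1024, 16, 16])

Since the batch normalization inside W_z is zero-initialized, the block starts out as an identity mapping (z equals x) and learns its attention contribution during training.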
normalization

Bases: Module

A normalization layer.
+ Source code in odak/learn/models/components.py +
class normalization(torch.nn.Module):
+    """
+    A normalization layer.
+    """
+    def __init__(
+                 self,
+                 dim = 1,
+                ):
+        """
+        Normalization layer.
+
+
+        Parameters
+        ----------
+        dim             : int
+                          Dimension (axis) to normalize.
+        """
+        super().__init__()
+        self.k = torch.nn.Parameter(torch.ones(1, dim, 1, 1))
+
+
+    def forward(self, x):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x             : torch.tensor
+                        Input data.
+
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Estimated output.      
+        """
+        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
+        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
+        mean = torch.mean(x, dim = 1, keepdim = True)
+        result =  (x - mean) * (var + eps).rsqrt() * self.k
+        return result 
+
+
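A minimal usage sketch, assuming normalization is importable from odak.learn.models.components as the source path above indicates.

import torch
from odak.learn.models.components import normalization  # import path assumed from the source path above

layer = normalization(dim = 8)       # one learnable scale per channel
x = torch.randn(4, 8, 32, 32)
y = layer(x)                         # normalize across the channel axis at every spatial location
print(y.mean(dim = 1).abs().max())   # close to zero at initialization, since the scale starts as ones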
__init__(dim=1)

Normalization layer.

Parameters:

  • dim (int, default: 1) – Dimension (axis) to normalize.
+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             dim = 1,
+            ):
+    """
+    Normalization layer.
+
+
+    Parameters
+    ----------
+    dim             : int
+                      Dimension (axis) to normalize.
+    """
+    super().__init__()
+    self.k = torch.nn.Parameter(torch.ones(1, dim, 1, 1))
+
+
+
forward(x)

Forward model.

Parameters:

  • x (torch.tensor) – Input data.

Returns:

  • result (torch.tensor) – Estimated output.
+ Source code in odak/learn/models/components.py +
def forward(self, x):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x             : torch.tensor
+                    Input data.
+
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Estimated output.      
+    """
+    eps = 1e-5 if x.dtype == torch.float32 else 1e-3
+    var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
+    mean = torch.mean(x, dim = 1, keepdim = True)
+    result =  (x - mean) * (var + eps).rsqrt() * self.k
+    return result 
+
+
+
positional_encoder

Bases: Module

A positional encoder module.
+ Source code in odak/learn/models/components.py +
class positional_encoder(torch.nn.Module):
+    """
+    A positional encoder module.
+    """
+
+    def __init__(self, L):
+        """
+        A positional encoder module.
+
+        Parameters
+        ----------
+        L                   : int
+                              Positional encoding level.
+        """
+        super(positional_encoder, self).__init__()
+        self.L = L
+
+
+    def forward(self, x):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x               : torch.tensor
+                          Input data.
+
+        Returns
+        ----------
+        result          : torch.tensor
+                          Result of the forward operation
+        """
+        B, C = x.shape
+        x = x.view(B, C, 1)
+        results = [x]
+        for i in range(1, self.L + 1):
+            freq = (2 ** i) * math.pi
+            cos_x = torch.cos(freq * x)
+            sin_x = torch.sin(freq * x)
+            results.append(cos_x)
+            results.append(sin_x)
+        results = torch.cat(results, dim=2)
+        results = results.permute(0, 2, 1)
+        results = results.reshape(B, -1)
+        return results
+
+
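A minimal usage sketch, assuming positional_encoder is importable from odak.learn.models.components as the source path above indicates.

import torch
from odak.learn.models.components import positional_encoder  # import path assumed from the source path above

encoder = positional_encoder(L = 4)
x = torch.rand(10, 3)                # ten 3D coordinates in [0, 1]
features = encoder(x)                # identity plus cos/sin pairs for each of the L frequency levels
print(features.shape)                # torch.Size([10, 27]), i.e. 3 * (2 * 4 + 1)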
__init__(L)

A positional encoder module.

Parameters:

  • L (int) – Positional encoding level.
+ Source code in odak/learn/models/components.py +
def __init__(self, L):
+    """
+    A positional encoder module.
+
+    Parameters
+    ----------
+    L                   : int
+                          Positional encoding level.
+    """
+    super(positional_encoder, self).__init__()
+    self.L = L
+
+
+
forward(x)

Forward model.

Parameters:

  • x (torch.tensor) – Input data.

Returns:

  • result (torch.tensor) – Result of the forward operation.
+ Source code in odak/learn/models/components.py +
def forward(self, x):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x               : torch.tensor
+                      Input data.
+
+    Returns
+    ----------
+    result          : torch.tensor
+                      Result of the forward operation
+    """
+    B, C = x.shape
+    x = x.view(B, C, 1)
+    results = [x]
+    for i in range(1, self.L + 1):
+        freq = (2 ** i) * math.pi
+        cos_x = torch.cos(freq * x)
+        sin_x = torch.sin(freq * x)
+        results.append(cos_x)
+        results.append(sin_x)
+    results = torch.cat(results, dim=2)
+    results = results.permute(0, 2, 1)
+    results = results.reshape(B, -1)
+    return results
+
+
+
residual_attention_layer

Bases: Module

A residual block with an attention layer.
+ Source code in odak/learn/models/components.py +
class residual_attention_layer(torch.nn.Module):
+    """
+    A residual block with an attention layer.
+    """
+    def __init__(
+                 self,
+                 input_channels = 2,
+                 output_channels = 2,
+                 kernel_size = 1,
+                 bias = False,
+                 activation = torch.nn.ReLU()
+                ):
+        """
+        An attention layer class.
+
+
+        Parameters
+        ----------
+        input_channels  : int or optional
+                          Number of input channels.
+        output_channels : int or optional
+                          Number of middle channels.
+        kernel_size     : int or optional
+                          Kernel size.
+        bias            : bool or optional
+                          Set to True to let convolutional layers have bias term.
+        activation      : torch.nn or optional
+                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+        """
+        super().__init__()
+        self.activation = activation
+        self.convolution0 = torch.nn.Sequential(
+                                                torch.nn.Conv2d(
+                                                                input_channels,
+                                                                output_channels,
+                                                                kernel_size = kernel_size,
+                                                                padding = kernel_size // 2,
+                                                                bias = bias
+                                                               ),
+                                                torch.nn.BatchNorm2d(output_channels)
+                                               )
+        self.convolution1 = torch.nn.Sequential(
+                                                torch.nn.Conv2d(
+                                                                input_channels,
+                                                                output_channels,
+                                                                kernel_size = kernel_size,
+                                                                padding = kernel_size // 2,
+                                                                bias = bias
+                                                               ),
+                                                torch.nn.BatchNorm2d(output_channels)
+                                               )
+        self.final_layer = torch.nn.Sequential(
+                                               self.activation,
+                                               torch.nn.Conv2d(
+                                                               output_channels,
+                                                               output_channels,
+                                                               kernel_size = kernel_size,
+                                                               padding = kernel_size // 2,
+                                                               bias = bias
+                                                              )
+                                              )
+
+
+    def forward(self, x0, x1):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x0             : torch.tensor
+                         First input data.
+
+        x1             : torch.tensor
+                         Second input data.
+
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Estimated output.      
+        """
+        y0 = self.convolution0(x0)
+        y1 = self.convolution1(x1)
+        y2 = torch.add(y0, y1)
+        result = self.final_layer(y2) * x0
+        return result
+
+
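A minimal usage sketch, assuming residual_attention_layer is importable from odak.learn.models.components as the source path above indicates. Note that output_channels has to match the channel count of x0, since the attention map is multiplied elementwise onto x0.

import torch
from odak.learn.models.components import residual_attention_layer  # import path assumed from the source path above

layer = residual_attention_layer(input_channels = 2, output_channels = 2, kernel_size = 1)
x0 = torch.randn(1, 2, 64, 64)       # features to be gated
x1 = torch.randn(1, 2, 64, 64)       # features guiding the attention
result = layer(x0, x1)               # attention map built from both inputs, multiplied onto x0
print(result.shape)                  # torch.Size([1, 2, 64, 64])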
__init__(input_channels=2, output_channels=2, kernel_size=1, bias=False, activation=torch.nn.ReLU())

An attention layer class.

Parameters:

  • input_channels (int, default: 2) – Number of input channels.
  • output_channels (int, default: 2) – Number of middle channels.
  • kernel_size (int, default: 1) – Kernel size.
  • bias (bool, default: False) – Set to True to let convolutional layers have bias term.
  • activation (torch.nn, default: torch.nn.ReLU()) – Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             input_channels = 2,
+             output_channels = 2,
+             kernel_size = 1,
+             bias = False,
+             activation = torch.nn.ReLU()
+            ):
+    """
+    An attention layer class.
+
+
+    Parameters
+    ----------
+    input_channels  : int or optional
+                      Number of input channels.
+    output_channels : int or optional
+                      Number of middle channels.
+    kernel_size     : int or optional
+                      Kernel size.
+    bias            : bool or optional
+                      Set to True to let convolutional layers have bias term.
+    activation      : torch.nn or optional
+                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+    """
+    super().__init__()
+    self.activation = activation
+    self.convolution0 = torch.nn.Sequential(
+                                            torch.nn.Conv2d(
+                                                            input_channels,
+                                                            output_channels,
+                                                            kernel_size = kernel_size,
+                                                            padding = kernel_size // 2,
+                                                            bias = bias
+                                                           ),
+                                            torch.nn.BatchNorm2d(output_channels)
+                                           )
+    self.convolution1 = torch.nn.Sequential(
+                                            torch.nn.Conv2d(
+                                                            input_channels,
+                                                            output_channels,
+                                                            kernel_size = kernel_size,
+                                                            padding = kernel_size // 2,
+                                                            bias = bias
+                                                           ),
+                                            torch.nn.BatchNorm2d(output_channels)
+                                           )
+    self.final_layer = torch.nn.Sequential(
+                                           self.activation,
+                                           torch.nn.Conv2d(
+                                                           output_channels,
+                                                           output_channels,
+                                                           kernel_size = kernel_size,
+                                                           padding = kernel_size // 2,
+                                                           bias = bias
+                                                          )
+                                          )
+
+
+
forward(x0, x1)

Forward model.

Parameters:

  • x0 (torch.tensor) – First input data.
  • x1 (torch.tensor) – Second input data.

Returns:

  • result (torch.tensor) – Estimated output.
+ Source code in odak/learn/models/components.py +
def forward(self, x0, x1):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x0             : torch.tensor
+                     First input data.
+
+    x1             : torch.tensor
+                     Second input data.
+
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Estimated output.      
+    """
+    y0 = self.convolution0(x0)
+    y1 = self.convolution1(x1)
+    y2 = torch.add(y0, y1)
+    result = self.final_layer(y2) * x0
+    return result
+
+
+
residual_layer

Bases: Module

A residual layer.
+ Source code in odak/learn/models/components.py +
class residual_layer(torch.nn.Module):
+    """
+    A residual layer.
+    """
+    def __init__(
+                 self,
+                 input_channels = 2,
+                 mid_channels = 16,
+                 kernel_size = 3,
+                 bias = False,
+                 activation = torch.nn.ReLU()
+                ):
+        """
+        A convolutional layer class.
+
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        mid_channels    : int
+                          Number of middle channels.
+        kernel_size     : int
+                          Kernel size.
+        bias            : bool 
+                          Set to True to let convolutional layers have bias term.
+        activation      : torch.nn
+                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+        """
+        super().__init__()
+        self.activation = activation
+        self.convolution = double_convolution(
+                                              input_channels,
+                                              mid_channels = mid_channels,
+                                              output_channels = input_channels,
+                                              kernel_size = kernel_size,
+                                              bias = bias,
+                                              activation = activation
+                                             )
+
+
+    def forward(self, x):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x             : torch.tensor
+                        Input data.
+
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Estimated output.      
+        """
+        x0 = self.convolution(x)
+        return x + x0
+
+
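A minimal usage sketch, assuming residual_layer is importable from odak.learn.models.components as the source path above indicates.

import torch
from odak.learn.models.components import residual_layer  # import path assumed from the source path above

layer = residual_layer(input_channels = 2, mid_channels = 16, kernel_size = 3)
x = torch.randn(1, 2, 64, 64)
y = layer(x)                         # double convolution plus identity skip connection
print(y.shape)                       # torch.Size([1, 2, 64, 64])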
__init__(input_channels=2, mid_channels=16, kernel_size=3, bias=False, activation=torch.nn.ReLU())

A convolutional layer class.

Parameters:

  • input_channels (int, default: 2) – Number of input channels.
  • mid_channels (int, default: 16) – Number of middle channels.
  • kernel_size (int, default: 3) – Kernel size.
  • bias (bool, default: False) – Set to True to let convolutional layers have bias term.
  • activation (torch.nn, default: torch.nn.ReLU()) – Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             input_channels = 2,
+             mid_channels = 16,
+             kernel_size = 3,
+             bias = False,
+             activation = torch.nn.ReLU()
+            ):
+    """
+    A convolutional layer class.
+
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    mid_channels    : int
+                      Number of middle channels.
+    kernel_size     : int
+                      Kernel size.
+    bias            : bool 
+                      Set to True to let convolutional layers have bias term.
+    activation      : torch.nn
+                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+    """
+    super().__init__()
+    self.activation = activation
+    self.convolution = double_convolution(
+                                          input_channels,
+                                          mid_channels = mid_channels,
+                                          output_channels = input_channels,
+                                          kernel_size = kernel_size,
+                                          bias = bias,
+                                          activation = activation
+                                         )
+
+
+
forward(x)

Forward model.

Parameters:

  • x (torch.tensor) – Input data.

Returns:

  • result (torch.tensor) – Estimated output.
+ Source code in odak/learn/models/components.py +
def forward(self, x):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x             : torch.tensor
+                    Input data.
+
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Estimated output.      
+    """
+    x0 = self.convolution(x)
+    return x + x0
+
+
+
spatial_gate

Bases: Module

Spatial attention module that applies a convolution layer after channel pooling.
This class is heavily inspired by https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py.
+ Source code in odak/learn/models/components.py +
class spatial_gate(torch.nn.Module):
+    """
+    Spatial attention module that applies a convolution layer after channel pooling.
+    This class is heavily inspired by https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py.
+    """
+    def __init__(self):
+        """
+        Initializes the spatial gate module.
+        """
+        super().__init__()
+        kernel_size = 7
+        self.spatial = convolution_layer(2, 1, kernel_size, bias = False, activation = torch.nn.Identity())
+
+
+    def channel_pool(self, x):
+        """
+        Applies max and average pooling on the channels.
+
+        Parameters
+        ----------
+        x             : torch.tensor
+                        Input tensor.
+
+        Returns
+        -------
+        output        : torch.tensor
+                        Output tensor.
+        """
+        max_pool = torch.max(x, 1)[0].unsqueeze(1)
+        avg_pool = torch.mean(x, 1).unsqueeze(1)
+        output = torch.cat((max_pool, avg_pool), dim=1)
+        return output
+
+
+    def forward(self, x):
+        """
+        Forward pass of the SpatialGate module.
+
+        Applies spatial attention to the input tensor.
+
+        Parameters
+        ----------
+        x            : torch.tensor
+                       Input tensor to the SpatialGate module.
+
+        Returns
+        -------
+        scaled_x     : torch.tensor
+                       Output tensor after applying spatial attention.
+        """
+        x_compress = self.channel_pool(x)
+        x_out = self.spatial(x_compress)
+        scale = torch.sigmoid(x_out)
+        scaled_x = x * scale
+        return scaled_x
+
+
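A minimal usage sketch, assuming spatial_gate is importable from odak.learn.models.components as the source path above indicates. The gate works for any channel count, since channel pooling reduces the input to two channels before the convolution.

import torch
from odak.learn.models.components import spatial_gate  # import path assumed from the source path above

gate = spatial_gate()
x = torch.randn(1, 32, 64, 64)
y = gate(x)                          # sigmoid mask from channel-pooled features, applied per pixel
print(y.shape)                       # torch.Size([1, 32, 64, 64])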
__init__()

Initializes the spatial gate module.
+ Source code in odak/learn/models/components.py +
def __init__(self):
+    """
+    Initializes the spatial gate module.
+    """
+    super().__init__()
+    kernel_size = 7
+    self.spatial = convolution_layer(2, 1, kernel_size, bias = False, activation = torch.nn.Identity())
+
+
+
channel_pool(x)

Applies max and average pooling on the channels.

Parameters:

  • x (torch.tensor) – Input tensor.

Returns:

  • output (torch.tensor) – Output tensor.
+ Source code in odak/learn/models/components.py +
def channel_pool(self, x):
+    """
+    Applies max and average pooling on the channels.
+
+    Parameters
+    ----------
+    x             : torch.tensor
+                    Input tensor.
+
+    Returns
+    -------
+    output        : torch.tensor
+                    Output tensor.
+    """
+    max_pool = torch.max(x, 1)[0].unsqueeze(1)
+    avg_pool = torch.mean(x, 1).unsqueeze(1)
+    output = torch.cat((max_pool, avg_pool), dim=1)
+    return output
+
+
+
forward(x)

Forward pass of the SpatialGate module.

Applies spatial attention to the input tensor.

Parameters:

  • x (torch.tensor) – Input tensor to the SpatialGate module.

Returns:

  • scaled_x (torch.tensor) – Output tensor after applying spatial attention.
+ Source code in odak/learn/models/components.py +
def forward(self, x):
+    """
+    Forward pass of the SpatialGate module.
+
+    Applies spatial attention to the input tensor.
+
+    Parameters
+    ----------
+    x            : torch.tensor
+                   Input tensor to the SpatialGate module.
+
+    Returns
+    -------
+    scaled_x     : torch.tensor
+                   Output tensor after applying spatial attention.
+    """
+    x_compress = self.channel_pool(x)
+    x_out = self.spatial(x_compress)
+    scale = torch.sigmoid(x_out)
+    scaled_x = x * scale
+    return scaled_x
+
+
+
spatially_adaptive_convolution

Bases: Module

A spatially adaptive convolution layer.

References

C. Zheng et al. "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions."
C. Xu et al. "Squeezesegv3: Spatially-adaptive Convolution for Efficient Point-Cloud Segmentation."
C. Zheng et al. "Windowing Decomposition Convolutional Neural Network for Image Enhancement."
+ Source code in odak/learn/models/components.py +
class spatially_adaptive_convolution(torch.nn.Module):
+    """
+    A spatially adaptive convolution layer.
+
+    References
+    ----------
+
+    C. Zheng et al. "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions."
+    C. Xu et al. "Squeezesegv3: Spatially-adaptive Convolution for Efficient Point-Cloud Segmentation."
+    C. Zheng et al. "Windowing Decomposition Convolutional Neural Network for Image Enhancement."
+    """
+    def __init__(
+                 self,
+                 input_channels = 2,
+                 output_channels = 2,
+                 kernel_size = 3,
+                 stride = 1,
+                 padding = 1,
+                 bias = False,
+                 activation = torch.nn.LeakyReLU(0.2, inplace = True)
+                ):
+        """
+        Initializes a spatially adaptive convolution layer.
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        output_channels : int
+                          Number of output channels.
+        kernel_size     : int
+                          Size of the convolution kernel.
+        stride          : int
+                          Stride of the convolution.
+        padding         : int
+                          Padding added to both sides of the input.
+        bias            : bool
+                          If True, includes a bias term in the convolution.
+        activation      : torch.nn.Module
+                          Activation function to apply. If None, no activation is applied.
+        """
+        super(spatially_adaptive_convolution, self).__init__()
+        self.kernel_size = kernel_size
+        self.input_channels = input_channels
+        self.output_channels = output_channels
+        self.stride = stride
+        self.padding = padding
+        self.standard_convolution = torch.nn.Conv2d(
+                                                    in_channels = input_channels,
+                                                    out_channels = self.output_channels,
+                                                    kernel_size = kernel_size,
+                                                    stride = stride,
+                                                    padding = padding,
+                                                    bias = bias
+                                                   )
+        self.weight = torch.nn.Parameter(data = self.standard_convolution.weight, requires_grad = True)
+        self.activation = activation
+
+
+    def forward(self, x, sv_kernel_feature):
+        """
+        Forward pass for the spatially adaptive convolution layer.
+
+        Parameters
+        ----------
+        x                  : torch.tensor
+                            Input data tensor.
+                            Dimension: (1, C, H, W)
+        sv_kernel_feature   : torch.tensor
+                            Spatially varying kernel features.
+                            Dimension: (1, C_i * kernel_size * kernel_size, H, W)
+
+        Returns
+        -------
+        sa_output          : torch.tensor
+                            Estimated output tensor.
+                            Dimension: (1, output_channels, H_out, W_out)
+        """
+        # Pad input and sv_kernel_feature if necessary
+        if sv_kernel_feature.size(-1) * self.stride != x.size(-1) or sv_kernel_feature.size(
+                -2) * self.stride != x.size(-2):
+            diffY = sv_kernel_feature.size(-2) % self.stride
+            diffX = sv_kernel_feature.size(-1) % self.stride
+            sv_kernel_feature = torch.nn.functional.pad(sv_kernel_feature, (diffX // 2, diffX - diffX // 2,
+                                                                            diffY // 2, diffY - diffY // 2))
+            diffY = x.size(-2) % self.stride
+            diffX = x.size(-1) % self.stride
+            x = torch.nn.functional.pad(x, (diffX // 2, diffX - diffX // 2,
+                                            diffY // 2, diffY - diffY // 2))
+
+        # Unfold the input tensor for matrix multiplication
+        input_feature = torch.nn.functional.unfold(
+                                                   x,
+                                                   kernel_size = (self.kernel_size, self.kernel_size),
+                                                   stride = self.stride,
+                                                   padding = self.padding
+                                                  )
+
+        # Resize sv_kernel_feature to match the input feature
+        sv_kernel = sv_kernel_feature.reshape(
+                                              1,
+                                              self.input_channels * self.kernel_size * self.kernel_size,
+                                              (x.size(-2) // self.stride) * (x.size(-1) // self.stride)
+                                             )
+
+        # Resize weight to match the input channels and kernel size
+        si_kernel = self.weight.reshape(
+                                        self.output_channels,
+                                        self.input_channels * self.kernel_size * self.kernel_size
+                                       )
+
+        # Apply spatially varying kernels
+        sv_feature = input_feature * sv_kernel
+
+        # Perform matrix multiplication
+        sa_output = torch.matmul(si_kernel, sv_feature).reshape(
+                                                                1, self.output_channels,
+                                                                (x.size(-2) // self.stride),
+                                                                (x.size(-1) // self.stride)
+                                                               )
+        return sa_output
+
+
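A minimal usage sketch, assuming spatially_adaptive_convolution is importable from odak.learn.models.components as the source path above indicates, and assuming the weight reshape in forward uses self.output_channels (as corrected above). The spatially varying kernel tensor carries one entry per input channel and kernel tap at every pixel.

import torch
from odak.learn.models.components import spatially_adaptive_convolution  # import path assumed from the source path above

layer = spatially_adaptive_convolution(input_channels = 2, output_channels = 2, kernel_size = 3)
x = torch.randn(1, 2, 32, 32)
sv_kernel_feature = torch.randn(1, 2 * 3 * 3, 32, 32)  # (1, input_channels * kernel_size * kernel_size, H, W)
y = layer(x, sv_kernel_feature)
print(y.shape)                       # torch.Size([1, 2, 32, 32])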
__init__(input_channels=2, output_channels=2, kernel_size=3, stride=1, padding=1, bias=False, activation=torch.nn.LeakyReLU(0.2, inplace=True))

Initializes a spatially adaptive convolution layer.

Parameters:

  • input_channels (int, default: 2) – Number of input channels.
  • output_channels (int, default: 2) – Number of output channels.
  • kernel_size (int, default: 3) – Size of the convolution kernel.
  • stride (int, default: 1) – Stride of the convolution.
  • padding (int, default: 1) – Padding added to both sides of the input.
  • bias (bool, default: False) – If True, includes a bias term in the convolution.
  • activation (torch.nn.Module, default: torch.nn.LeakyReLU(0.2, inplace=True)) – Activation function to apply. If None, no activation is applied.
+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             input_channels = 2,
+             output_channels = 2,
+             kernel_size = 3,
+             stride = 1,
+             padding = 1,
+             bias = False,
+             activation = torch.nn.LeakyReLU(0.2, inplace = True)
+            ):
+    """
+    Initializes a spatially adaptive convolution layer.
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    output_channels : int
+                      Number of output channels.
+    kernel_size     : int
+                      Size of the convolution kernel.
+    stride          : int
+                      Stride of the convolution.
+    padding         : int
+                      Padding added to both sides of the input.
+    bias            : bool
+                      If True, includes a bias term in the convolution.
+    activation      : torch.nn.Module
+                      Activation function to apply. If None, no activation is applied.
+    """
+    super(spatially_adaptive_convolution, self).__init__()
+    self.kernel_size = kernel_size
+    self.input_channels = input_channels
+    self.output_channels = output_channels
+    self.stride = stride
+    self.padding = padding
+    self.standard_convolution = torch.nn.Conv2d(
+                                                in_channels = input_channels,
+                                                out_channels = self.output_channels,
+                                                kernel_size = kernel_size,
+                                                stride = stride,
+                                                padding = padding,
+                                                bias = bias
+                                               )
+    self.weight = torch.nn.Parameter(data = self.standard_convolution.weight, requires_grad = True)
+    self.activation = activation
+
+
+
forward(x, sv_kernel_feature)

Forward pass for the spatially adaptive convolution layer.

Parameters:

  • x (torch.tensor) – Input data tensor. Dimension: (1, C, H, W).
  • sv_kernel_feature (torch.tensor) – Spatially varying kernel features. Dimension: (1, C_i * kernel_size * kernel_size, H, W).

Returns:

  • sa_output (torch.tensor) – Estimated output tensor. Dimension: (1, output_channels, H_out, W_out).
+ Source code in odak/learn/models/components.py +
def forward(self, x, sv_kernel_feature):
+    """
+    Forward pass for the spatially adaptive convolution layer.
+
+    Parameters
+    ----------
+    x                  : torch.tensor
+                        Input data tensor.
+                        Dimension: (1, C, H, W)
+    sv_kernel_feature   : torch.tensor
+                        Spatially varying kernel features.
+                        Dimension: (1, C_i * kernel_size * kernel_size, H, W)
+
+    Returns
+    -------
+    sa_output          : torch.tensor
+                        Estimated output tensor.
+                        Dimension: (1, output_channels, H_out, W_out)
+    """
+    # Pad input and sv_kernel_feature if necessary
+    if sv_kernel_feature.size(-1) * self.stride != x.size(-1) or sv_kernel_feature.size(
+            -2) * self.stride != x.size(-2):
+        diffY = sv_kernel_feature.size(-2) % self.stride
+        diffX = sv_kernel_feature.size(-1) % self.stride
+        sv_kernel_feature = torch.nn.functional.pad(sv_kernel_feature, (diffX // 2, diffX - diffX // 2,
+                                                                        diffY // 2, diffY - diffY // 2))
+        diffY = x.size(-2) % self.stride
+        diffX = x.size(-1) % self.stride
+        x = torch.nn.functional.pad(x, (diffX // 2, diffX - diffX // 2,
+                                        diffY // 2, diffY - diffY // 2))
+
+    # Unfold the input tensor for matrix multiplication
+    input_feature = torch.nn.functional.unfold(
+                                               x,
+                                               kernel_size = (self.kernel_size, self.kernel_size),
+                                               stride = self.stride,
+                                               padding = self.padding
+                                              )
+
+    # Resize sv_kernel_feature to match the input feature
+    sv_kernel = sv_kernel_feature.reshape(
+                                          1,
+                                          self.input_channels * self.kernel_size * self.kernel_size,
+                                          (x.size(-2) // self.stride) * (x.size(-1) // self.stride)
+                                         )
+
+    # Resize weight to match the input channels and kernel size
+    si_kernel = self.weight.reshape(
+                                    self.output_channels,
+                                    self.input_channels * self.kernel_size * self.kernel_size
+                                   )
+
+    # Apply spatially varying kernels
+    sv_feature = input_feature * sv_kernel
+
+    # Perform matrix multiplication
+    sa_output = torch.matmul(si_kernel, sv_feature).reshape(
+                                                            1, self.output_channels,
+                                                            (x.size(-2) // self.stride),
+                                                            (x.size(-1) // self.stride)
+                                                           )
+    return sa_output
+
+
+
spatially_adaptive_module

Bases: Module

A spatially adaptive module that combines learned spatially adaptive convolutions.

References

Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Akşit, "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions," SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24), December, 2024.
+ Source code in odak/learn/models/components.py +
class spatially_adaptive_module(torch.nn.Module):
+    """
+    A spatially adaptive module that combines learned spatially adaptive convolutions.
+
+    References
+    ----------
+
+    Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Akşit, "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions," SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24), December, 2024.
+    """
+    def __init__(
+                 self,
+                 input_channels = 2,
+                 output_channels = 2,
+                 kernel_size = 3,
+                 stride = 1,
+                 padding = 1,
+                 bias = False,
+                 activation = torch.nn.LeakyReLU(0.2, inplace = True)
+                ):
+        """
+        Initializes a spatially adaptive module.
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        output_channels : int
+                          Number of output channels.
+        kernel_size     : int
+                          Size of the convolution kernel.
+        stride          : int
+                          Stride of the convolution.
+        padding         : int
+                          Padding added to both sides of the input.
+        bias            : bool
+                          If True, includes a bias term in the convolution.
+        activation      : torch.nn
+                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+        """
+        super(spatially_adaptive_module, self).__init__()
+        self.kernel_size = kernel_size
+        self.input_channels = input_channels
+        self.output_channels = output_channels
+        self.stride = stride
+        self.padding = padding
+        self.weight_output_channels = self.output_channels - 1
+        self.standard_convolution = torch.nn.Conv2d(
+                                                    in_channels = input_channels,
+                                                    out_channels = self.weight_output_channels,
+                                                    kernel_size = kernel_size,
+                                                    stride = stride,
+                                                    padding = padding,
+                                                    bias = bias
+                                                   )
+        self.weight = torch.nn.Parameter(data = self.standard_convolution.weight, requires_grad = True)
+        self.activation = activation
+
+
+    def forward(self, x, sv_kernel_feature):
+        """
+        Forward pass for the spatially adaptive module.
+
+        Parameters
+        ----------
+        x                  : torch.tensor
+                            Input data tensor.
+                            Dimension: (1, C, H, W)
+        sv_kernel_feature   : torch.tensor
+                            Spatially varying kernel features.
+                            Dimension: (1, C_i * kernel_size * kernel_size, H, W)
+
+        Returns
+        -------
+        output             : torch.tensor
+                            Combined output tensor from standard and spatially adaptive convolutions.
+                            Dimension: (1, output_channels, H_out, W_out)
+        """
+        # Pad input and sv_kernel_feature if necessary
+        if sv_kernel_feature.size(-1) * self.stride != x.size(-1) or sv_kernel_feature.size(
+                -2) * self.stride != x.size(-2):
+            diffY = sv_kernel_feature.size(-2) % self.stride
+            diffX = sv_kernel_feature.size(-1) % self.stride
+            sv_kernel_feature = torch.nn.functional.pad(sv_kernel_feature, (diffX // 2, diffX - diffX // 2,
+                                                                            diffY // 2, diffY - diffY // 2))
+            diffY = x.size(-2) % self.stride
+            diffX = x.size(-1) % self.stride
+            x = torch.nn.functional.pad(x, (diffX // 2, diffX - diffX // 2,
+                                            diffY // 2, diffY - diffY // 2))
+
+        # Unfold the input tensor for matrix multiplication
+        input_feature = torch.nn.functional.unfold(
+                                                   x,
+                                                   kernel_size = (self.kernel_size, self.kernel_size),
+                                                   stride = self.stride,
+                                                   padding = self.padding
+                                                  )
+
+        # Resize sv_kernel_feature to match the input feature
+        sv_kernel = sv_kernel_feature.reshape(
+                                              1,
+                                              self.input_channels * self.kernel_size * self.kernel_size,
+                                              (x.size(-2) // self.stride) * (x.size(-1) // self.stride)
+                                             )
+
+        # Apply sv_kernel to the input_feature
+        sv_feature = input_feature * sv_kernel
+
+        # Original spatially varying convolution output
+        sv_output = torch.sum(sv_feature, dim = 1).reshape(
+                                                           1,
+                                                           1,
+                                                           (x.size(-2) // self.stride),
+                                                           (x.size(-1) // self.stride)
+                                                           )
+
+        # Reshape weight for spatially adaptive convolution
+        si_kernel = self.weight.reshape(
+                                        self.weight_output_channels,
+                                        self.input_channels * self.kernel_size * self.kernel_size
+                                       )
+
+        # Apply si_kernel on sv convolution output
+        sa_output = torch.matmul(si_kernel, sv_feature).reshape(
+                                                                1, self.weight_output_channels,
+                                                                (x.size(-2) // self.stride),
+                                                                (x.size(-1) // self.stride)
+                                                               )
+
+        # Combine the outputs and apply activation function
+        output = self.activation(torch.cat((sv_output, sa_output), dim = 1))
+        return output
+
+
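A minimal usage sketch, assuming spatially_adaptive_module is importable from odak.learn.models.components as the source path above indicates. The output concatenates one spatially varying channel with output_channels - 1 spatially adaptive channels.

import torch
from odak.learn.models.components import spatially_adaptive_module  # import path assumed from the source path above

layer = spatially_adaptive_module(input_channels = 2, output_channels = 2, kernel_size = 3)
x = torch.randn(1, 2, 32, 32)
sv_kernel_feature = torch.randn(1, 2 * 3 * 3, 32, 32)  # per-pixel spatially varying kernel features
y = layer(x, sv_kernel_feature)
print(y.shape)                       # torch.Size([1, 2, 32, 32])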
__init__(input_channels=2, output_channels=2, kernel_size=3, stride=1, padding=1, bias=False, activation=torch.nn.LeakyReLU(0.2, inplace=True))

Initializes a spatially adaptive module.

Parameters:

  • input_channels (int, default: 2) – Number of input channels.
  • output_channels (int, default: 2) – Number of output channels.
  • kernel_size (int, default: 3) – Size of the convolution kernel.
  • stride (int, default: 1) – Stride of the convolution.
  • padding (int, default: 1) – Padding added to both sides of the input.
  • bias (bool, default: False) – If True, includes a bias term in the convolution.
  • activation (torch.nn, default: torch.nn.LeakyReLU(0.2, inplace=True)) – Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             input_channels = 2,
+             output_channels = 2,
+             kernel_size = 3,
+             stride = 1,
+             padding = 1,
+             bias = False,
+             activation = torch.nn.LeakyReLU(0.2, inplace = True)
+            ):
+    """
+    Initializes a spatially adaptive module.
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    output_channels : int
+                      Number of output channels.
+    kernel_size     : int
+                      Size of the convolution kernel.
+    stride          : int
+                      Stride of the convolution.
+    padding         : int
+                      Padding added to both sides of the input.
+    bias            : bool
+                      If True, includes a bias term in the convolution.
+    activation      : torch.nn
+                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+    """
+    super(spatially_adaptive_module, self).__init__()
+    self.kernel_size = kernel_size
+    self.input_channels = input_channels
+    self.output_channels = output_channels
+    self.stride = stride
+    self.padding = padding
+    self.weight_output_channels = self.output_channels - 1
+    self.standard_convolution = torch.nn.Conv2d(
+                                                in_channels = input_channels,
+                                                out_channels = self.weight_output_channels,
+                                                kernel_size = kernel_size,
+                                                stride = stride,
+                                                padding = padding,
+                                                bias = bias
+                                               )
+    self.weight = torch.nn.Parameter(data = self.standard_convolution.weight, requires_grad = True)
+    self.activation = activation
+
+
+
forward(x, sv_kernel_feature)

Forward pass for the spatially adaptive module.

Parameters:

  • x (torch.tensor) – Input data tensor. Dimension: (1, C, H, W).
  • sv_kernel_feature (torch.tensor) – Spatially varying kernel features. Dimension: (1, C_i * kernel_size * kernel_size, H, W).

Returns:

  • output (torch.tensor) – Combined output tensor from standard and spatially adaptive convolutions. Dimension: (1, output_channels, H_out, W_out).
+ Source code in odak/learn/models/components.py +
def forward(self, x, sv_kernel_feature):
+    """
+    Forward pass for the spatially adaptive module.
+
+    Parameters
+    ----------
+    x                  : torch.tensor
+                        Input data tensor.
+                        Dimension: (1, C, H, W)
+    sv_kernel_feature   : torch.tensor
+                        Spatially varying kernel features.
+                        Dimension: (1, C_i * kernel_size * kernel_size, H, W)
+
+    Returns
+    -------
+    output             : torch.tensor
+                        Combined output tensor from standard and spatially adaptive convolutions.
+                        Dimension: (1, output_channels, H_out, W_out)
+    """
+    # Pad input and sv_kernel_feature if necessary
+    if sv_kernel_feature.size(-1) * self.stride != x.size(-1) or sv_kernel_feature.size(
+            -2) * self.stride != x.size(-2):
+        diffY = sv_kernel_feature.size(-2) % self.stride
+        diffX = sv_kernel_feature.size(-1) % self.stride
+        sv_kernel_feature = torch.nn.functional.pad(sv_kernel_feature, (diffX // 2, diffX - diffX // 2,
+                                                                        diffY // 2, diffY - diffY // 2))
+        diffY = x.size(-2) % self.stride
+        diffX = x.size(-1) % self.stride
+        x = torch.nn.functional.pad(x, (diffX // 2, diffX - diffX // 2,
+                                        diffY // 2, diffY - diffY // 2))
+
+    # Unfold the input tensor for matrix multiplication
+    input_feature = torch.nn.functional.unfold(
+                                               x,
+                                               kernel_size = (self.kernel_size, self.kernel_size),
+                                               stride = self.stride,
+                                               padding = self.padding
+                                              )
+
+    # Resize sv_kernel_feature to match the input feature
+    sv_kernel = sv_kernel_feature.reshape(
+                                          1,
+                                          self.input_channels * self.kernel_size * self.kernel_size,
+                                          (x.size(-2) // self.stride) * (x.size(-1) // self.stride)
+                                         )
+
+    # Apply sv_kernel to the input_feature
+    sv_feature = input_feature * sv_kernel
+
+    # Original spatially varying convolution output
+    sv_output = torch.sum(sv_feature, dim = 1).reshape(
+                                                       1,
+                                                       1,
+                                                       (x.size(-2) // self.stride),
+                                                       (x.size(-1) // self.stride)
+                                                       )
+
+    # Reshape weight for spatially adaptive convolution
+    si_kernel = self.weight.reshape(
+                                    self.weight_output_channels,
+                                    self.input_channels * self.kernel_size * self.kernel_size
+                                   )
+
+    # Apply si_kernel on sv convolution output
+    sa_output = torch.matmul(si_kernel, sv_feature).reshape(
+                                                            1, self.weight_output_channels,
+                                                            (x.size(-2) // self.stride),
+                                                            (x.size(-1) // self.stride)
+                                                           )
+
+    # Combine the outputs and apply activation function
+    output = self.activation(torch.cat((sv_output, sa_output), dim = 1))
+    return output
+
+
+
upsample_convtranspose2d_layer

Bases: Module

An upsampling convtranspose2d layer.
+ Source code in odak/learn/models/components.py +
class upsample_convtranspose2d_layer(torch.nn.Module):
+    """
+    An upsampling convtranspose2d layer.
+    """
+    def __init__(
+                 self,
+                 input_channels,
+                 output_channels,
+                 kernel_size = 2,
+                 stride = 2,
+                 bias = False,
+                ):
+        """
+        An upsampling component using a transposed convolution.
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        output_channels : int
+                          Number of output channels.
+        kernel_size     : int
+                          Kernel size.
+        bias            : bool
+                          Set to True to let convolutional layers have bias term.
+        """
+        super().__init__()
+        self.up = torch.nn.ConvTranspose2d(
+                                           in_channels = input_channels,
+                                           out_channels = output_channels,
+                                           bias = bias,
+                                           kernel_size = kernel_size,
+                                           stride = stride
+                                          )
+
+    def forward(self, x1, x2):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x1             : torch.tensor
+                         First input data.
+        x2             : torch.tensor
+                         Second input data.
+
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Result of the forward operation
+        """
+        x1 = self.up(x1)
+        diffY = x2.size()[2] - x1.size()[2]
+        diffX = x2.size()[3] - x1.size()[3]
+        x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,
+                                          diffY // 2, diffY - diffY // 2])
+        result = x1 + x2
+        return result
+
+
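A minimal usage sketch, assuming upsample_convtranspose2d_layer is importable from odak.learn.models.components as the source path above indicates. The second input acts as a skip connection at the target resolution.

import torch
from odak.learn.models.components import upsample_convtranspose2d_layer  # import path assumed from the source path above

layer = upsample_convtranspose2d_layer(input_channels = 16, output_channels = 8)
x1 = torch.randn(1, 16, 16, 16)      # low-resolution features to upsample
x2 = torch.randn(1, 8, 32, 32)       # skip-connection features at the target resolution
y = layer(x1, x2)                    # transposed convolution, padding to match, then addition
print(y.shape)                       # torch.Size([1, 8, 32, 32])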
__init__(input_channels, output_channels, kernel_size=2, stride=2, bias=False)

An upsampling component using a transposed convolution.

Parameters:

  • input_channels (int) – Number of input channels.
  • output_channels (int) – Number of output channels.
  • kernel_size (int, default: 2) – Kernel size.
  • stride (int, default: 2) – Stride of the transposed convolution.
  • bias (bool, default: False) – Set to True to let convolutional layers have bias term.
+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             input_channels,
+             output_channels,
+             kernel_size = 2,
+             stride = 2,
+             bias = False,
+            ):
+    """
+    An upsampling component using a transposed convolution.
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    output_channels : int
+                      Number of output channels.
+    kernel_size     : int
+                      Kernel size.
+    bias            : bool
+                      Set to True to let convolutional layers have bias term.
+    """
+    super().__init__()
+    self.up = torch.nn.ConvTranspose2d(
+                                       in_channels = input_channels,
+                                       out_channels = output_channels,
+                                       bias = bias,
+                                       kernel_size = kernel_size,
+                                       stride = stride
+                                      )
forward(x1, x2)

Forward model.

Parameters:

  • x1 (torch.tensor) – First input data.
  • x2 (torch.tensor) – Second input data.

Returns:

  • result (torch.tensor) – Result of the forward operation.

Source code in odak/learn/models/components.py
def forward(self, x1, x2):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x1             : torch.tensor
+                     First input data.
+    x2             : torch.tensor
+                     Second input data.
+
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Result of the forward operation
+    """
+    x1 = self.up(x1)
+    diffY = x2.size()[2] - x1.size()[2]
+    diffX = x2.size()[3] - x1.size()[3]
+    x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,
+                                      diffY // 2, diffY - diffY // 2])
+    result = x1 + x2
+    return result
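A minimal sketch of the upsample, pad and add logic used in the forward pass above, written with plain torch calls; the channel counts and tensor sizes below are illustrative only.

import torch

up = torch.nn.ConvTranspose2d(in_channels = 64, out_channels = 32, kernel_size = 2, stride = 2, bias = False)
x1 = torch.randn(1, 64, 15, 15)   # low-resolution feature map to be upsampled
x2 = torch.randn(1, 32, 31, 31)   # higher-resolution feature map providing the target size
x1 = up(x1)                       # (1, 32, 30, 30)
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,
                                  diffY // 2, diffY - diffY // 2])
result = x1 + x2                  # (1, 32, 31, 31)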
upsample_layer

Bases: Module

An upsampling convolutional layer.

Source code in odak/learn/models/components.py
class upsample_layer(torch.nn.Module):
+    """
+    An upsampling convolutional layer.
+    """
+    def __init__(
+                 self,
+                 input_channels,
+                 output_channels,
+                 kernel_size = 3,
+                 bias = False,
+                 activation = torch.nn.ReLU(),
+                 bilinear = True
+                ):
+        """
+        An upsampling component with a double convolution.
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        output_channels : int
+                          Number of output channels.
+        kernel_size     : int
+                          Kernel size.
+        bias            : bool 
+                          Set to True to let convolutional layers have bias term.
+        activation      : torch.nn
+                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+        bilinear        : bool
+                          If set to True, bilinear sampling is used.
+        """
+        super(upsample_layer, self).__init__()
+        if bilinear:
+            self.up = torch.nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = True)
+            self.conv = double_convolution(
+                                           input_channels = input_channels + output_channels,
+                                           mid_channels = input_channels // 2,
+                                           output_channels = output_channels,
+                                           kernel_size = kernel_size,
+                                           bias = bias,
+                                           activation = activation
+                                          )
+        else:
+            self.up = torch.nn.ConvTranspose2d(input_channels , input_channels // 2, kernel_size = 2, stride = 2)
+            self.conv = double_convolution(
+                                           input_channels = input_channels,
+                                           mid_channels = output_channels,
+                                           output_channels = output_channels,
+                                           kernel_size = kernel_size,
+                                           bias = bias,
+                                           activation = activation
+                                          )
+
+
+    def forward(self, x1, x2):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x1             : torch.tensor
+                         First input data.
+        x2             : torch.tensor
+                         Second input data.
+
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Result of the forward operation
+        """ 
+        x1 = self.up(x1)
+        diffY = x2.size()[2] - x1.size()[2]
+        diffX = x2.size()[3] - x1.size()[3]
+        x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,
+                                          diffY // 2, diffY - diffY // 2])
+        x = torch.cat([x2, x1], dim = 1)
+        result = self.conv(x)
+        return result
__init__(input_channels, output_channels, kernel_size=3, bias=False, activation=torch.nn.ReLU(), bilinear=True)

An upsampling component with a double convolution.

Parameters:

  • input_channels (int) – Number of input channels.
  • output_channels (int) – Number of output channels.
  • kernel_size (int, default: 3) – Kernel size.
  • bias (bool, default: False) – Set to True to let convolutional layers have bias term.
  • activation (torch.nn, default: torch.nn.ReLU()) – Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
  • bilinear (bool, default: True) – If set to True, bilinear sampling is used.

Source code in odak/learn/models/components.py
def __init__(
+             self,
+             input_channels,
+             output_channels,
+             kernel_size = 3,
+             bias = False,
+             activation = torch.nn.ReLU(),
+             bilinear = True
+            ):
+    """
+    An upsampling component with a double convolution.
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    output_channels : int
+                      Number of output channels.
+    kernel_size     : int
+                      Kernel size.
+    bias            : bool 
+                      Set to True to let convolutional layers have bias term.
+    activation      : torch.nn
+                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+    bilinear        : bool
+                      If set to True, bilinear sampling is used.
+    """
+    super(upsample_layer, self).__init__()
+    if bilinear:
+        self.up = torch.nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = True)
+        self.conv = double_convolution(
+                                       input_channels = input_channels + output_channels,
+                                       mid_channels = input_channels // 2,
+                                       output_channels = output_channels,
+                                       kernel_size = kernel_size,
+                                       bias = bias,
+                                       activation = activation
+                                      )
+    else:
+        self.up = torch.nn.ConvTranspose2d(input_channels , input_channels // 2, kernel_size = 2, stride = 2)
+        self.conv = double_convolution(
+                                       input_channels = input_channels,
+                                       mid_channels = output_channels,
+                                       output_channels = output_channels,
+                                       kernel_size = kernel_size,
+                                       bias = bias,
+                                       activation = activation
+                                      )
forward(x1, x2)

Forward model.

Parameters:

  • x1 (torch.tensor) – First input data.
  • x2 (torch.tensor) – Second input data.

Returns:

  • result (torch.tensor) – Result of the forward operation.

Source code in odak/learn/models/components.py
def forward(self, x1, x2):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x1             : torch.tensor
+                     First input data.
+    x2             : torch.tensor
+                     Second input data.
+
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Result of the forward operation
+    """ 
+    x1 = self.up(x1)
+    diffY = x2.size()[2] - x1.size()[2]
+    diffX = x2.size()[3] - x1.size()[3]
+    x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,
+                                      diffY // 2, diffY - diffY // 2])
+    x = torch.cat([x2, x1], dim = 1)
+    result = self.conv(x)
+    return result
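A minimal usage sketch for upsample_layer; the import path is assumed from the source note above, and the channel counts and tensor shapes are illustrative only.

import torch
from odak.learn.models.components import upsample_layer  # import path assumed from the source note above

up = upsample_layer(input_channels = 64, output_channels = 32, bilinear = True)
x1 = torch.randn(1, 64, 16, 16)  # decoder feature map to be upsampled
x2 = torch.randn(1, 32, 32, 32)  # skip connection from the encoder
y = up(x1, x2)                   # upsample, pad, concatenate with x2, then double convolution
print(y.shape)                   # torch.Size([1, 32, 32, 32])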
gaussian(x, multiplier=1.0)

A Gaussian non-linear activation.
For more details: Ramasinghe, Sameera, and Simon Lucey. "Beyond periodicity: Towards a unifying framework for activations in coordinate-mlps." In European Conference on Computer Vision, pp. 142-158. Cham: Springer Nature Switzerland, 2022.

Parameters:

  • x (float or torch.tensor) – Input data.
  • multiplier (float or torch.tensor, default: 1.0) – Multiplier.

Returns:

  • result (float or torch.tensor) – Output data.

Source code in odak/learn/models/components.py
def gaussian(x, multiplier = 1.):
+    """
+    A Gaussian non-linear activation.
+    For more details: Ramasinghe, Sameera, and Simon Lucey. "Beyond periodicity: Towards a unifying framework for activations in coordinate-mlps." In European Conference on Computer Vision, pp. 142-158. Cham: Springer Nature Switzerland, 2022.
+
+    Parameters
+    ----------
+    x            : float or torch.tensor
+                   Input data.
+    multiplier   : float or torch.tensor
+                   Multiplier.
+
+    Returns
+    -------
+    result       : float or torch.tensor
+                   Output data.
+    """
+    result = torch.exp(- (multiplier * x) ** 2)
+    return result
swish(x)

A swish non-linear activation.
For more details: https://en.wikipedia.org/wiki/Swish_function

Parameters:

  • x (float or torch.tensor) – Input.

Returns:

  • out (float or torch.tensor) – Output.

Source code in odak/learn/models/components.py
def swish(x):
+    """
+    A swish non-linear activation.
+    For more details: https://en.wikipedia.org/wiki/Swish_function
+
+    Parameters
+    -----------
+    x              : float or torch.tensor
+                     Input.
+
+    Returns
+    -------
+    out            : float or torch.tensor
+                     Output.
+    """
+    out = x * torch.sigmoid(x)
+    return out
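A short sketch of the two activations documented above, applied elementwise to a tensor; the import path is assumed from the source note above.

import torch
from odak.learn.models.components import gaussian, swish  # import path assumed from the source note above

x = torch.linspace(-3., 3., 7)
print(gaussian(x, multiplier = 1.))  # exp(-(x)^2), equal to 1. at x = 0 and decaying towards the edges
print(swish(x))                      # x * sigmoid(x)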
channel_gate

Bases: Module

Channel attention module with various pooling strategies.
This class is heavily inspired by https://github.com/Jongchan/attention-module/commit/e4ee180f1335c09db14d39a65d97c8ca3d1f7b16 (MIT License).

Source code in odak/learn/models/components.py
class channel_gate(torch.nn.Module):
+    """
+    Channel attention module with various pooling strategies.
+    This class is heavily inspired by https://github.com/Jongchan/attention-module/commit/e4ee180f1335c09db14d39a65d97c8ca3d1f7b16 (MIT License).
+    """
+    def __init__(
+                 self, 
+                 gate_channels, 
+                 reduction_ratio = 16, 
+                 pool_types = ['avg', 'max']
+                ):
+        """
+        Initializes the channel gate module.
+
+        Parameters
+        ----------
+        gate_channels   : int
+                          Number of channels of the input feature map.
+        reduction_ratio : int
+                          Reduction ratio for the intermediate layer.
+        pool_types      : list
+                          List of pooling operations to apply.
+        """
+        super().__init__()
+        self.gate_channels = gate_channels
+        hidden_channels = gate_channels // reduction_ratio
+        if hidden_channels == 0:
+            hidden_channels = 1
+        self.mlp = torch.nn.Sequential(
+                                       convolutional_block_attention.Flatten(),
+                                       torch.nn.Linear(gate_channels, hidden_channels),
+                                       torch.nn.ReLU(),
+                                       torch.nn.Linear(hidden_channels, gate_channels)
+                                      )
+        self.pool_types = pool_types
+
+
+    def forward(self, x):
+        """
+        Forward pass of the ChannelGate module.
+
+        Applies channel-wise attention to the input tensor.
+
+        Parameters
+        ----------
+        x            : torch.tensor
+                       Input tensor to the ChannelGate module.
+
+        Returns
+        -------
+        output       : torch.tensor
+                       Output tensor after applying channel attention.
+        """
+        channel_att_sum = None
+        for pool_type in self.pool_types:
+            if pool_type == 'avg':
+                pool = torch.nn.functional.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
+            elif pool_type == 'max':
+                pool = torch.nn.functional.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
+            channel_att_raw = self.mlp(pool)
+            channel_att_sum = channel_att_raw if channel_att_sum is None else channel_att_sum + channel_att_raw
+        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
+        output = x * scale
+        return output
__init__(gate_channels, reduction_ratio=16, pool_types=['avg', 'max'])

Initializes the channel gate module.

Parameters:

  • gate_channels (int) – Number of channels of the input feature map.
  • reduction_ratio (int, default: 16) – Reduction ratio for the intermediate layer.
  • pool_types (list, default: ['avg', 'max']) – List of pooling operations to apply.

Source code in odak/learn/models/components.py
def __init__(
+             self, 
+             gate_channels, 
+             reduction_ratio = 16, 
+             pool_types = ['avg', 'max']
+            ):
+    """
+    Initializes the channel gate module.
+
+    Parameters
+    ----------
+    gate_channels   : int
+                      Number of channels of the input feature map.
+    reduction_ratio : int
+                      Reduction ratio for the intermediate layer.
+    pool_types      : list
+                      List of pooling operations to apply.
+    """
+    super().__init__()
+    self.gate_channels = gate_channels
+    hidden_channels = gate_channels // reduction_ratio
+    if hidden_channels == 0:
+        hidden_channels = 1
+    self.mlp = torch.nn.Sequential(
+                                   convolutional_block_attention.Flatten(),
+                                   torch.nn.Linear(gate_channels, hidden_channels),
+                                   torch.nn.ReLU(),
+                                   torch.nn.Linear(hidden_channels, gate_channels)
+                                  )
+    self.pool_types = pool_types
forward(x)

Forward pass of the ChannelGate module.

Applies channel-wise attention to the input tensor.

Parameters:

  • x (torch.tensor) – Input tensor to the ChannelGate module.

Returns:

  • output (torch.tensor) – Output tensor after applying channel attention.

Source code in odak/learn/models/components.py
def forward(self, x):
+    """
+    Forward pass of the ChannelGate module.
+
+    Applies channel-wise attention to the input tensor.
+
+    Parameters
+    ----------
+    x            : torch.tensor
+                   Input tensor to the ChannelGate module.
+
+    Returns
+    -------
+    output       : torch.tensor
+                   Output tensor after applying channel attention.
+    """
+    channel_att_sum = None
+    for pool_type in self.pool_types:
+        if pool_type == 'avg':
+            pool = torch.nn.functional.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
+        elif pool_type == 'max':
+            pool = torch.nn.functional.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
+        channel_att_raw = self.mlp(pool)
+        channel_att_sum = channel_att_raw if channel_att_sum is None else channel_att_sum + channel_att_raw
+    scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
+    output = x * scale
+    return output
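A minimal usage sketch for channel_gate; the import path is assumed from the source note above, and the shapes are illustrative only.

import torch
from odak.learn.models.components import channel_gate  # import path assumed from the source note above

gate = channel_gate(gate_channels = 32, reduction_ratio = 8, pool_types = ['avg', 'max'])
x = torch.randn(1, 32, 16, 16)
y = gate(x)                      # same shape as x; each channel is rescaled by a sigmoid attention weight
print(y.shape)                   # torch.Size([1, 32, 16, 16])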
convolution_layer

Bases: Module

A convolution layer.

Source code in odak/learn/models/components.py
class convolution_layer(torch.nn.Module):
+    """
+    A convolution layer.
+    """
+    def __init__(
+                 self,
+                 input_channels = 2,
+                 output_channels = 2,
+                 kernel_size = 3,
+                 bias = False,
+                 stride = 1,
+                 normalization = True,
+                 activation = torch.nn.ReLU()
+                ):
+        """
+        A convolutional layer class.
+
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        output_channels : int
+                          Number of output channels.
+        kernel_size     : int
+                          Kernel size.
+        bias            : bool
+                          Set to True to let convolutional layers have bias term.
+        normalization   : bool
+                          If True, adds a Batch Normalization layer after the convolutional layer.
+        activation      : torch.nn
+                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+        """
+        super().__init__()
+        layers = [
+            torch.nn.Conv2d(
+                            input_channels,
+                            output_channels,
+                            kernel_size = kernel_size,
+                            stride = stride,
+                            padding = kernel_size // 2,
+                            bias = bias
+                           )
+        ]
+        if normalization:
+            layers.append(torch.nn.BatchNorm2d(output_channels))
+        if activation:
+            layers.append(activation)
+        self.model = torch.nn.Sequential(*layers)
+
+
+    def forward(self, x):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x             : torch.tensor
+                        Input data.
+
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Estimated output.
+        """
+        result = self.model(x)
+        return result
__init__(input_channels=2, output_channels=2, kernel_size=3, bias=False, stride=1, normalization=True, activation=torch.nn.ReLU())

A convolutional layer class.

Parameters:

  • input_channels (int, default: 2) – Number of input channels.
  • output_channels (int, default: 2) – Number of output channels.
  • kernel_size (int, default: 3) – Kernel size.
  • bias (bool, default: False) – Set to True to let convolutional layers have bias term.
  • normalization (bool, default: True) – If True, adds a Batch Normalization layer after the convolutional layer.
  • activation (torch.nn, default: torch.nn.ReLU()) – Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().

Source code in odak/learn/models/components.py
def __init__(
+             self,
+             input_channels = 2,
+             output_channels = 2,
+             kernel_size = 3,
+             bias = False,
+             stride = 1,
+             normalization = True,
+             activation = torch.nn.ReLU()
+            ):
+    """
+    A convolutional layer class.
+
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    output_channels : int
+                      Number of output channels.
+    kernel_size     : int
+                      Kernel size.
+    bias            : bool
+                      Set to True to let convolutional layers have bias term.
+    normalization   : bool
+                      If True, adds a Batch Normalization layer after the convolutional layer.
+    activation      : torch.nn
+                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+    """
+    super().__init__()
+    layers = [
+        torch.nn.Conv2d(
+                        input_channels,
+                        output_channels,
+                        kernel_size = kernel_size,
+                        stride = stride,
+                        padding = kernel_size // 2,
+                        bias = bias
+                       )
+    ]
+    if normalization:
+        layers.append(torch.nn.BatchNorm2d(output_channels))
+    if activation:
+        layers.append(activation)
+    self.model = torch.nn.Sequential(*layers)
forward(x)

Forward model.

Parameters:

  • x (torch.tensor) – Input data.

Returns:

  • result (torch.tensor) – Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x             : torch.tensor
+                    Input data.
+
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Estimated output.
+    """
+    result = self.model(x)
+    return result
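A minimal usage sketch for convolution_layer; the import path is assumed from the source note above, and the shapes are illustrative only.

import torch
from odak.learn.models.components import convolution_layer  # import path assumed from the source note above

layer = convolution_layer(input_channels = 3, output_channels = 16, kernel_size = 3)
x = torch.randn(1, 3, 64, 64)
y = layer(x)                     # padding = kernel_size // 2 keeps the spatial size when stride = 1
print(y.shape)                   # torch.Size([1, 16, 64, 64])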
convolutional_block_attention

Bases: Module

Convolutional Block Attention Module (CBAM) class.
This class is heavily inspired by https://github.com/Jongchan/attention-module/commit/e4ee180f1335c09db14d39a65d97c8ca3d1f7b16 (MIT License).

Source code in odak/learn/models/components.py
class convolutional_block_attention(torch.nn.Module):
+    """
+    Convolutional Block Attention Module (CBAM) class. 
+    This class is heavily inspired by https://github.com/Jongchan/attention-module/commit/e4ee180f1335c09db14d39a65d97c8ca3d1f7b16 (MIT License).
+    """
+    def __init__(
+                 self, 
+                 gate_channels, 
+                 reduction_ratio = 16, 
+                 pool_types = ['avg', 'max'], 
+                 no_spatial = False
+                ):
+        """
+        Initializes the convolutional block attention module.
+
+        Parameters
+        ----------
+        gate_channels   : int
+                          Number of channels of the input feature map.
+        reduction_ratio : int
+                          Reduction ratio for the channel attention.
+        pool_types      : list
+                          List of pooling operations to apply for channel attention.
+        no_spatial      : bool
+                          If True, spatial attention is not applied.
+        """
+        super(convolutional_block_attention, self).__init__()
+        self.channel_gate = channel_gate(gate_channels, reduction_ratio, pool_types)
+        self.no_spatial = no_spatial
+        if not no_spatial:
+            self.spatial_gate = spatial_gate()
+
+
+    class Flatten(torch.nn.Module):
+        """
+        Flattens the input tensor to a 2D matrix.
+        """
+        def forward(self, x):
+            return x.view(x.size(0), -1)
+
+
+    def forward(self, x):
+        """
+        Forward pass of the convolutional block attention module.
+
+        Parameters
+        ----------
+        x            : torch.tensor
+                       Input tensor to the CBAM module.
+
+        Returns
+        -------
+        x_out        : torch.tensor
+                       Output tensor after applying channel and spatial attention.
+        """
+        x_out = self.channel_gate(x)
+        if not self.no_spatial:
+            x_out = self.spatial_gate(x_out)
+        return x_out
Flatten

Bases: Module

Flattens the input tensor to a 2D matrix.

Source code in odak/learn/models/components.py
class Flatten(torch.nn.Module):
+    """
+    Flattens the input tensor to a 2D matrix.
+    """
+    def forward(self, x):
+        return x.view(x.size(0), -1)
__init__(gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False)

Initializes the convolutional block attention module.

Parameters:

  • gate_channels (int) – Number of channels of the input feature map.
  • reduction_ratio (int, default: 16) – Reduction ratio for the channel attention.
  • pool_types (list, default: ['avg', 'max']) – List of pooling operations to apply for channel attention.
  • no_spatial (bool, default: False) – If True, spatial attention is not applied.

Source code in odak/learn/models/components.py
def __init__(
+             self, 
+             gate_channels, 
+             reduction_ratio = 16, 
+             pool_types = ['avg', 'max'], 
+             no_spatial = False
+            ):
+    """
+    Initializes the convolutional block attention module.
+
+    Parameters
+    ----------
+    gate_channels   : int
+                      Number of channels of the input feature map.
+    reduction_ratio : int
+                      Reduction ratio for the channel attention.
+    pool_types      : list
+                      List of pooling operations to apply for channel attention.
+    no_spatial      : bool
+                      If True, spatial attention is not applied.
+    """
+    super(convolutional_block_attention, self).__init__()
+    self.channel_gate = channel_gate(gate_channels, reduction_ratio, pool_types)
+    self.no_spatial = no_spatial
+    if not no_spatial:
+        self.spatial_gate = spatial_gate()
forward(x)

Forward pass of the convolutional block attention module.

Parameters:

  • x (torch.tensor) – Input tensor to the CBAM module.

Returns:

  • x_out (torch.tensor) – Output tensor after applying channel and spatial attention.

Source code in odak/learn/models/components.py
def forward(self, x):
+    """
+    Forward pass of the convolutional block attention module.
+
+    Parameters
+    ----------
+    x            : torch.tensor
+                   Input tensor to the CBAM module.
+
+    Returns
+    -------
+    x_out        : torch.tensor
+                   Output tensor after applying channel and spatial attention.
+    """
+    x_out = self.channel_gate(x)
+    if not self.no_spatial:
+        x_out = self.spatial_gate(x_out)
+    return x_out
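A minimal usage sketch for convolutional_block_attention; the import path is assumed from the source note above, and the shapes are illustrative only.

import torch
from odak.learn.models.components import convolutional_block_attention  # import path assumed from the source note above

cbam = convolutional_block_attention(gate_channels = 32, reduction_ratio = 8)
x = torch.randn(1, 32, 16, 16)
y = cbam(x)                      # channel attention followed by spatial attention, shape preserved
print(y.shape)                   # torch.Size([1, 32, 16, 16])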
double_convolution

Bases: Module

A double convolution layer.

Source code in odak/learn/models/components.py
class double_convolution(torch.nn.Module):
+    """
+    A double convolution layer.
+    """
+    def __init__(
+                 self,
+                 input_channels = 2,
+                 mid_channels = None,
+                 output_channels = 2,
+                 kernel_size = 3, 
+                 bias = False,
+                 normalization = True,
+                 activation = torch.nn.ReLU()
+                ):
+        """
+        Double convolution model.
+
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        mid_channels    : int
+                          Number of channels in the hidden layer between two convolutions.
+        output_channels : int
+                          Number of output channels.
+        kernel_size     : int
+                          Kernel size.
+        bias            : bool 
+                          Set to True to let convolutional layers have bias term.
+        normalization   : bool
+                          If True, adds a Batch Normalization layer after the convolutional layer.
+        activation      : torch.nn
+                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+        """
+        super().__init__()
+        if isinstance(mid_channels, type(None)):
+            mid_channels = output_channels
+        self.activation = activation
+        self.model = torch.nn.Sequential(
+                                         convolution_layer(
+                                                           input_channels = input_channels,
+                                                           output_channels = mid_channels,
+                                                           kernel_size = kernel_size,
+                                                           bias = bias,
+                                                           normalization = normalization,
+                                                           activation = self.activation
+                                                          ),
+                                         convolution_layer(
+                                                           input_channels = mid_channels,
+                                                           output_channels = output_channels,
+                                                           kernel_size = kernel_size,
+                                                           bias = bias,
+                                                           normalization = normalization,
+                                                           activation = self.activation
+                                                          )
+                                        )
+
+
+    def forward(self, x):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x             : torch.tensor
+                        Input data.
+
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Estimated output.      
+        """
+        result = self.model(x)
+        return result
__init__(input_channels=2, mid_channels=None, output_channels=2, kernel_size=3, bias=False, normalization=True, activation=torch.nn.ReLU())

Double convolution model.

Parameters:

  • input_channels (int, default: 2) – Number of input channels.
  • mid_channels (int, default: None) – Number of channels in the hidden layer between two convolutions.
  • output_channels (int, default: 2) – Number of output channels.
  • kernel_size (int, default: 3) – Kernel size.
  • bias (bool, default: False) – Set to True to let convolutional layers have bias term.
  • normalization (bool, default: True) – If True, adds a Batch Normalization layer after the convolutional layer.
  • activation (torch.nn, default: torch.nn.ReLU()) – Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().

Source code in odak/learn/models/components.py
def __init__(
+             self,
+             input_channels = 2,
+             mid_channels = None,
+             output_channels = 2,
+             kernel_size = 3, 
+             bias = False,
+             normalization = True,
+             activation = torch.nn.ReLU()
+            ):
+    """
+    Double convolution model.
+
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    mid_channels    : int
+                      Number of channels in the hidden layer between two convolutions.
+    output_channels : int
+                      Number of output channels.
+    kernel_size     : int
+                      Kernel size.
+    bias            : bool 
+                      Set to True to let convolutional layers have bias term.
+    normalization   : bool
+                      If True, adds a Batch Normalization layer after the convolutional layer.
+    activation      : torch.nn
+                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+    """
+    super().__init__()
+    if isinstance(mid_channels, type(None)):
+        mid_channels = output_channels
+    self.activation = activation
+    self.model = torch.nn.Sequential(
+                                     convolution_layer(
+                                                       input_channels = input_channels,
+                                                       output_channels = mid_channels,
+                                                       kernel_size = kernel_size,
+                                                       bias = bias,
+                                                       normalization = normalization,
+                                                       activation = self.activation
+                                                      ),
+                                     convolution_layer(
+                                                       input_channels = mid_channels,
+                                                       output_channels = output_channels,
+                                                       kernel_size = kernel_size,
+                                                       bias = bias,
+                                                       normalization = normalization,
+                                                       activation = self.activation
+                                                      )
+                                    )
forward(x)

Forward model.

Parameters:

  • x (torch.tensor) – Input data.

Returns:

  • result (torch.tensor) – Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x             : torch.tensor
+                    Input data.
+
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Estimated output.      
+    """
+    result = self.model(x)
+    return result
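A minimal usage sketch for double_convolution; the import path is assumed from the source note above, and the shapes are illustrative only.

import torch
from odak.learn.models.components import double_convolution  # import path assumed from the source note above

block = double_convolution(input_channels = 3, mid_channels = 8, output_channels = 16, kernel_size = 3)
x = torch.randn(1, 3, 64, 64)
y = block(x)                     # two convolution_layer blocks back to back, spatial size preserved
print(y.shape)                   # torch.Size([1, 16, 64, 64])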
downsample_layer

Bases: Module

A downscaling component followed by a double convolution.

Source code in odak/learn/models/components.py
class downsample_layer(torch.nn.Module):
+    """
+    A downscaling component followed by a double convolution.
+    """
+    def __init__(
+                 self,
+                 input_channels,
+                 output_channels,
+                 kernel_size = 3,
+                 bias = False,
+                 activation = torch.nn.ReLU()
+                ):
+        """
+        A downscaling component with a double convolution.
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        output_channels : int
+                          Number of output channels.
+        kernel_size     : int
+                          Kernel size.
+        bias            : bool 
+                          Set to True to let convolutional layers have bias term.
+        activation      : torch.nn
+                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+        """
+        super().__init__()
+        self.maxpool_conv = torch.nn.Sequential(
+                                                torch.nn.MaxPool2d(2),
+                                                double_convolution(
+                                                                   input_channels = input_channels,
+                                                                   mid_channels = output_channels,
+                                                                   output_channels = output_channels,
+                                                                   kernel_size = kernel_size,
+                                                                   bias = bias,
+                                                                   activation = activation
+                                                                  )
+                                               )
+
+
+    def forward(self, x):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x              : torch.tensor
+                         First input data.
+
+
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Estimated output.      
+        """
+        result = self.maxpool_conv(x)
+        return result
__init__(input_channels, output_channels, kernel_size=3, bias=False, activation=torch.nn.ReLU())

A downscaling component with a double convolution.

Parameters:

  • input_channels (int) – Number of input channels.
  • output_channels (int) – Number of output channels.
  • kernel_size (int, default: 3) – Kernel size.
  • bias (bool, default: False) – Set to True to let convolutional layers have bias term.
  • activation (torch.nn, default: torch.nn.ReLU()) – Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().

Source code in odak/learn/models/components.py
def __init__(
+             self,
+             input_channels,
+             output_channels,
+             kernel_size = 3,
+             bias = False,
+             activation = torch.nn.ReLU()
+            ):
+    """
+    A downscaling component with a double convolution.
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    output_channels : int
+                      Number of output channels.
+    kernel_size     : int
+                      Kernel size.
+    bias            : bool 
+                      Set to True to let convolutional layers have bias term.
+    activation      : torch.nn
+                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+    """
+    super().__init__()
+    self.maxpool_conv = torch.nn.Sequential(
+                                            torch.nn.MaxPool2d(2),
+                                            double_convolution(
+                                                               input_channels = input_channels,
+                                                               mid_channels = output_channels,
+                                                               output_channels = output_channels,
+                                                               kernel_size = kernel_size,
+                                                               bias = bias,
+                                                               activation = activation
+                                                              )
+                                           )
forward(x)

Forward model.

Parameters:

  • x (torch.tensor) – Input data.

Returns:

  • result (torch.tensor) – Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x              : torch.tensor
+                     First input data.
+
+
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Estimated output.      
+    """
+    result = self.maxpool_conv(x)
+    return result
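A minimal usage sketch for downsample_layer; the import path is assumed from the source note above, and the shapes are illustrative only.

import torch
from odak.learn.models.components import downsample_layer  # import path assumed from the source note above

down = downsample_layer(input_channels = 16, output_channels = 32)
x = torch.randn(1, 16, 64, 64)
y = down(x)                      # MaxPool2d(2) halves the spatial size, then a double convolution follows
print(y.shape)                   # torch.Size([1, 32, 32, 32])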
global_feature_module

Bases: Module

A global feature layer that processes global features from input channels and applies them to another input tensor via learned transformations.

Source code in odak/learn/models/components.py
class global_feature_module(torch.nn.Module):
+    """
+    A global feature layer that processes global features from input channels and
+    applies them to another input tensor via learned transformations.
+    """
+    def __init__(
+                 self,
+                 input_channels,
+                 mid_channels,
+                 output_channels,
+                 kernel_size,
+                 bias = False,
+                 normalization = False,
+                 activation = torch.nn.ReLU()
+                ):
+        """
+        A global feature layer.
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        mid_channels  : int
+                          Number of mid channels.
+        output_channels : int
+                          Number of output channels.
+        kernel_size     : int
+                          Kernel size.
+        bias            : bool
+                          Set to True to let convolutional layers have bias term.
+        normalization   : bool
+                          If True, adds a Batch Normalization layer after the convolutional layer.
+        activation      : torch.nn
+                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+        """
+        super().__init__()
+        self.transformations_1 = global_transformations(input_channels, output_channels)
+        self.global_features_1 = double_convolution(
+                                                    input_channels = input_channels,
+                                                    mid_channels = mid_channels,
+                                                    output_channels = output_channels,
+                                                    kernel_size = kernel_size,
+                                                    bias = bias,
+                                                    normalization = normalization,
+                                                    activation = activation
+                                                   )
+        self.global_features_2 = double_convolution(
+                                                    input_channels = input_channels,
+                                                    mid_channels = mid_channels,
+                                                    output_channels = output_channels,
+                                                    kernel_size = kernel_size,
+                                                    bias = bias,
+                                                    normalization = normalization,
+                                                    activation = activation
+                                                   )
+        self.transformations_2 = global_transformations(input_channels, output_channels)
+
+
+    def forward(self, x1, x2):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x1             : torch.tensor
+                         First input data.
+        x2             : torch.tensor
+                         Second input data.
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Estimated output.
+        """
+        global_tensor_1 = self.transformations_1(x1, x2)
+        y1 = self.global_features_1(global_tensor_1)
+        y2 = self.global_features_2(y1)
+        global_tensor_2 = self.transformations_2(y1, y2)
+        return global_tensor_2
__init__(input_channels, mid_channels, output_channels, kernel_size, bias=False, normalization=False, activation=torch.nn.ReLU())

A global feature layer.

Parameters:

  • input_channels (int) – Number of input channels.
  • mid_channels (int) – Number of mid channels.
  • output_channels (int) – Number of output channels.
  • kernel_size (int) – Kernel size.
  • bias (bool, default: False) – Set to True to let convolutional layers have bias term.
  • normalization (bool, default: False) – If True, adds a Batch Normalization layer after the convolutional layer.
  • activation (torch.nn, default: torch.nn.ReLU()) – Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().

Source code in odak/learn/models/components.py
def __init__(
+             self,
+             input_channels,
+             mid_channels,
+             output_channels,
+             kernel_size,
+             bias = False,
+             normalization = False,
+             activation = torch.nn.ReLU()
+            ):
+    """
+    A global feature layer.
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    mid_channels  : int
+                      Number of mid channels.
+    output_channels : int
+                      Number of output channels.
+    kernel_size     : int
+                      Kernel size.
+    bias            : bool
+                      Set to True to let convolutional layers have bias term.
+    normalization   : bool
+                      If True, adds a Batch Normalization layer after the convolutional layer.
+    activation      : torch.nn
+                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+    """
+    super().__init__()
+    self.transformations_1 = global_transformations(input_channels, output_channels)
+    self.global_features_1 = double_convolution(
+                                                input_channels = input_channels,
+                                                mid_channels = mid_channels,
+                                                output_channels = output_channels,
+                                                kernel_size = kernel_size,
+                                                bias = bias,
+                                                normalization = normalization,
+                                                activation = activation
+                                               )
+    self.global_features_2 = double_convolution(
+                                                input_channels = input_channels,
+                                                mid_channels = mid_channels,
+                                                output_channels = output_channels,
+                                                kernel_size = kernel_size,
+                                                bias = bias,
+                                                normalization = normalization,
+                                                activation = activation
+                                               )
+    self.transformations_2 = global_transformations(input_channels, output_channels)
forward(x1, x2)

Forward model.

Parameters:

  • x1 (torch.tensor) – First input data.
  • x2 (torch.tensor) – Second input data.

Returns:

  • result (torch.tensor) – Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x1, x2):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x1             : torch.tensor
+                     First input data.
+    x2             : torch.tensor
+                     Second input data.
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Estimated output.
+    """
+    global_tensor_1 = self.transformations_1(x1, x2)
+    y1 = self.global_features_1(global_tensor_1)
+    y2 = self.global_features_2(y1)
+    global_tensor_2 = self.transformations_2(y1, y2)
+    return global_tensor_2
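A minimal usage sketch for global_feature_module; the import path is assumed from the source note above. The example keeps input_channels equal to output_channels so that the intermediate tensors of the transformations and double convolutions line up; all shapes are illustrative only.

import torch
from odak.learn.models.components import global_feature_module  # import path assumed from the source note above

module = global_feature_module(
                               input_channels = 32,
                               mid_channels = 16,
                               output_channels = 32,
                               kernel_size = 3
                              )
x1 = torch.randn(1, 32, 32, 32)
x2 = torch.randn(1, 32, 32, 32)
y = module(x1, x2)               # global statistics of x2 modulate x1, followed by two double convolutions
print(y.shape)                   # torch.Size([1, 32, 32, 32])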
global_transformations

Bases: Module

A global feature layer that processes global features from input channels and applies learned transformations to another input tensor.

This implementation is adapted from RSGUnet:
https://github.com/MTLab/rsgunet_image_enhance.

Reference:
J. Huang, P. Zhu, M. Geng et al. "Range Scaling Global U-Net for Perceptual Image Enhancement on Mobile Devices."

Source code in odak/learn/models/components.py
class global_transformations(torch.nn.Module):
+    """
+    A global feature layer that processes global features from input channels and
+    applies learned transformations to another input tensor.
+
+    This implementation is adapted from RSGUnet:
+    https://github.com/MTLab/rsgunet_image_enhance.
+
+    Reference:
+    J. Huang, P. Zhu, M. Geng et al. "Range Scaling Global U-Net for Perceptual Image Enhancement on Mobile Devices."
+    """
+    def __init__(
+                 self,
+                 input_channels,
+                 output_channels
+                ):
+        """
+        A global feature layer.
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        output_channels : int
+                          Number of output channels.
+        """
+        super().__init__()
+        self.global_feature_1 = torch.nn.Sequential(
+            torch.nn.Linear(input_channels, output_channels),
+            torch.nn.LeakyReLU(0.2, inplace = True),
+        )
+        self.global_feature_2 = torch.nn.Sequential(
+            torch.nn.Linear(output_channels, output_channels),
+            torch.nn.LeakyReLU(0.2, inplace = True)
+        )
+
+
+    def forward(self, x1, x2):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x1             : torch.tensor
+                         First input data.
+        x2             : torch.tensor
+                         Second input data.
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Estimated output.
+        """
+        y = torch.mean(x2, dim = (2, 3))
+        y1 = self.global_feature_1(y)
+        y2 = self.global_feature_2(y1)
+        y1 = y1.unsqueeze(2).unsqueeze(3)
+        y2 = y2.unsqueeze(2).unsqueeze(3)
+        result = x1 * y1 + y2
+        return result
__init__(input_channels, output_channels)

A global feature layer.

Parameters:

  • input_channels (int) – Number of input channels.
  • output_channels (int) – Number of output channels.

Source code in odak/learn/models/components.py
def __init__(
+             self,
+             input_channels,
+             output_channels
+            ):
+    """
+    A global feature layer.
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    output_channels : int
+                      Number of output channels.
+    """
+    super().__init__()
+    self.global_feature_1 = torch.nn.Sequential(
+        torch.nn.Linear(input_channels, output_channels),
+        torch.nn.LeakyReLU(0.2, inplace = True),
+    )
+    self.global_feature_2 = torch.nn.Sequential(
+        torch.nn.Linear(output_channels, output_channels),
+        torch.nn.LeakyReLU(0.2, inplace = True)
+    )
forward(x1, x2)

Forward model.

Parameters:

  • x1 (torch.tensor) – First input data.
  • x2 (torch.tensor) – Second input data.

Returns:

  • result (torch.tensor) – Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x1, x2):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x1             : torch.tensor
+                     First input data.
+    x2             : torch.tensor
+                     Second input data.
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Estimated output.
+    """
+    y = torch.mean(x2, dim = (2, 3))
+    y1 = self.global_feature_1(y)
+    y2 = self.global_feature_2(y1)
+    y1 = y1.unsqueeze(2).unsqueeze(3)
+    y2 = y2.unsqueeze(2).unsqueeze(3)
+    result = x1 * y1 + y2
+    return result
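A minimal usage sketch for global_transformations; the import path is assumed from the source note above, and the shapes are illustrative only. Note that x2 must carry input_channels channels (its global average feeds the first linear layer) while x1 must carry output_channels channels (it is scaled and shifted by the learned features).

import torch
from odak.learn.models.components import global_transformations  # import path assumed from the source note above

transform = global_transformations(input_channels = 48, output_channels = 32)
x1 = torch.randn(1, 32, 64, 64)  # tensor that receives the learned scale and shift
x2 = torch.randn(1, 48, 64, 64)  # tensor whose global average provides the features
y = transform(x1, x2)            # y = x1 * f1(mean(x2)) + f2(f1(mean(x2)))
print(y.shape)                   # torch.Size([1, 32, 64, 64])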
multi_layer_perceptron

Bases: Module

A multi-layer perceptron model.

Source code in odak/learn/models/models.py
class multi_layer_perceptron(torch.nn.Module):
+    """
+    A multi-layer perceptron model.
+    """
+
+    def __init__(self,
+                 dimensions,
+                 activation = torch.nn.ReLU(),
+                 bias = False,
+                 model_type = 'conventional',
+                 siren_multiplier = 1.,
+                 input_multiplier = None
+                ):
+        """
+        Parameters
+        ----------
+        dimensions        : list
+                            List of integers representing the dimensions of each layer (e.g., [2, 10, 1], where the first layer has two channels and last one has one channel.).
+        activation        : torch.nn
+                            Nonlinear activation function.
+                            Default is `torch.nn.ReLU()`.
+        bias              : bool
+                            If set to True, linear layers will include biases.
+        siren_multiplier  : float
+                            When using `SIREN` model type, this parameter functions as a hyperparameter.
+                            The original SIREN work uses 30.
+                            You can bypass this parameter by providing inputs that are not normalized and larger than one.
+        input_multiplier  : float
+                            Initial value of the input multiplier before the very first layer.
+        model_type        : str
+                            Model type: `conventional`, `swish`, `SIREN`, `FILM SIREN`, `Gaussian`.
+                            `conventional` refers to a standard multi layer perceptron.
+                            For `SIREN,` see: Sitzmann, Vincent, et al. "Implicit neural representations with periodic activation functions." Advances in neural information processing systems 33 (2020): 7462-7473.
+                            For `Swish,` see: Ramachandran, Prajit, Barret Zoph, and Quoc V. Le. "Searching for activation functions." arXiv preprint arXiv:1710.05941 (2017). 
+                            For `FILM SIREN,` see: Chan, Eric R., et al. "pi-gan: Periodic implicit generative adversarial networks for 3d-aware image synthesis." Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2021.
+                            For `Gaussian,` see: Ramasinghe, Sameera, and Simon Lucey. "Beyond periodicity: Towards a unifying framework for activations in coordinate-mlps." In European Conference on Computer Vision, pp. 142-158. Cham: Springer Nature Switzerland, 2022.
+        """
+        super(multi_layer_perceptron, self).__init__()
+        self.activation = activation
+        self.bias = bias
+        self.model_type = model_type
+        self.layers = torch.nn.ModuleList()
+        self.siren_multiplier = siren_multiplier
+        self.dimensions = dimensions
+        for i in range(len(self.dimensions) - 1):
+            self.layers.append(torch.nn.Linear(self.dimensions[i], self.dimensions[i + 1], bias = self.bias))
+        if not isinstance(input_multiplier, type(None)):
+            self.input_multiplier = torch.nn.ParameterList()
+            self.input_multiplier.append(torch.nn.Parameter(torch.ones(1, self.dimensions[0]) * input_multiplier))
+        if self.model_type == 'FILM SIREN':
+            self.alpha = torch.nn.ParameterList()
+            for j in self.dimensions[1:-1]:
+                self.alpha.append(torch.nn.Parameter(torch.randn(2, 1, j)))
+        if self.model_type == 'Gaussian':
+            self.alpha = torch.nn.ParameterList()
+            for j in self.dimensions[1:-1]:
+                self.alpha.append(torch.nn.Parameter(torch.randn(1, 1, j)))
+
+
+    def forward(self, x):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x             : torch.tensor
+                        Input data.
+
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Estimated output.      
+        """
+        if hasattr(self, 'input_multiplier'):
+            result = x * self.input_multiplier[0]
+        else:
+            result = x
+        for layer_id, layer in enumerate(self.layers[:-1]):
+            result = layer(result)
+            if self.model_type == 'conventional':
+                result = self.activation(result)
+            elif self.model_type == 'swish':
+                result = swish(result)
+            elif self.model_type == 'SIREN':
+                result = torch.sin(result * self.siren_multiplier)
+            elif self.model_type == 'FILM SIREN':
+                result = torch.sin(self.alpha[layer_id][0] * result + self.alpha[layer_id][1])
+            elif self.model_type == 'Gaussian': 
+                result = gaussian(result, self.alpha[layer_id][0])
+        result = self.layers[-1](result)
+        return result

__init__(dimensions, activation=torch.nn.ReLU(), bias=False, model_type='conventional', siren_multiplier=1.0, input_multiplier=None)

Parameters:

  • dimensions – List of integers representing the dimensions of each layer (e.g., [2, 10, 1], where the first layer has two channels and the last one has one channel).
  • activation – Nonlinear activation function. Default is `torch.nn.ReLU()`.
  • bias – If set to True, linear layers will include biases.
  • siren_multiplier – When using `SIREN` model type, this parameter functions as a hyperparameter. The original SIREN work uses 30. You can bypass this parameter by providing inputs that are not normalized and larger than one.
  • input_multiplier – Initial value of the input multiplier before the very first layer.
  • model_type – Model type: `conventional`, `swish`, `SIREN`, `FILM SIREN`, `Gaussian`. `conventional` refers to a standard multi layer perceptron. For `SIREN`, see: Sitzmann, Vincent, et al. "Implicit neural representations with periodic activation functions." Advances in Neural Information Processing Systems 33 (2020): 7462-7473. For `Swish`, see: Ramachandran, Prajit, Barret Zoph, and Quoc V. Le. "Searching for activation functions." arXiv preprint arXiv:1710.05941 (2017). For `FILM SIREN`, see: Chan, Eric R., et al. "pi-GAN: Periodic implicit generative adversarial networks for 3D-aware image synthesis." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2021. For `Gaussian`, see: Ramasinghe, Sameera, and Simon Lucey. "Beyond periodicity: Towards a unifying framework for activations in coordinate-MLPs." European Conference on Computer Vision. Cham: Springer Nature Switzerland, 2022. 142-158.

Source code in odak/learn/models/models.py
def __init__(self,
+             dimensions,
+             activation = torch.nn.ReLU(),
+             bias = False,
+             model_type = 'conventional',
+             siren_multiplier = 1.,
+             input_multiplier = None
+            ):
+    """
+    Parameters
+    ----------
+    dimensions        : list
+                        List of integers representing the dimensions of each layer (e.g., [2, 10, 1], where the first layer has two channels and the last one has one channel).
+    activation        : torch.nn
+                        Nonlinear activation function.
+                        Default is `torch.nn.ReLU()`.
+    bias              : bool
+                        If set to True, linear layers will include biases.
+    siren_multiplier  : float
+                        When using `SIREN` model type, this parameter functions as a hyperparameter.
+                        The original SIREN work uses 30.
+                        You can bypass this parameter by providing inputs that are not normalized and larger than one.
+    input_multiplier  : float
+                        Initial value of the input multiplier before the very first layer.
+    model_type        : str
+                        Model type: `conventional`, `swish`, `SIREN`, `FILM SIREN`, `Gaussian`.
+                        `conventional` refers to a standard multi layer perceptron.
+                        For `SIREN,` see: Sitzmann, Vincent, et al. "Implicit neural representations with periodic activation functions." Advances in neural information processing systems 33 (2020): 7462-7473.
+                        For `Swish,` see: Ramachandran, Prajit, Barret Zoph, and Quoc V. Le. "Searching for activation functions." arXiv preprint arXiv:1710.05941 (2017). 
+                        For `FILM SIREN,` see: Chan, Eric R., et al. "pi-gan: Periodic implicit generative adversarial networks for 3d-aware image synthesis." Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2021.
+                        For `Gaussian,` see: Ramasinghe, Sameera, and Simon Lucey. "Beyond periodicity: Towards a unifying framework for activations in coordinate-mlps." In European Conference on Computer Vision, pp. 142-158. Cham: Springer Nature Switzerland, 2022.
+    """
+    super(multi_layer_perceptron, self).__init__()
+    self.activation = activation
+    self.bias = bias
+    self.model_type = model_type
+    self.layers = torch.nn.ModuleList()
+    self.siren_multiplier = siren_multiplier
+    self.dimensions = dimensions
+    for i in range(len(self.dimensions) - 1):
+        self.layers.append(torch.nn.Linear(self.dimensions[i], self.dimensions[i + 1], bias = self.bias))
+    if not isinstance(input_multiplier, type(None)):
+        self.input_multiplier = torch.nn.ParameterList()
+        self.input_multiplier.append(torch.nn.Parameter(torch.ones(1, self.dimensions[0]) * input_multiplier))
+    if self.model_type == 'FILM SIREN':
+        self.alpha = torch.nn.ParameterList()
+        for j in self.dimensions[1:-1]:
+            self.alpha.append(torch.nn.Parameter(torch.randn(2, 1, j)))
+    if self.model_type == 'Gaussian':
+        self.alpha = torch.nn.ParameterList()
+        for j in self.dimensions[1:-1]:
+            self.alpha.append(torch.nn.Parameter(torch.randn(1, 1, j)))

forward(x)

Forward model.

Parameters:

  • x – Input data.

Returns:

  • result (tensor) – Estimated output.

Source code in odak/learn/models/models.py
def forward(self, x):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x             : torch.tensor
+                    Input data.
+
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Estimated output.      
+    """
+    if hasattr(self, 'input_multiplier'):
+        result = x * self.input_multiplier[0]
+    else:
+        result = x
+    for layer_id, layer in enumerate(self.layers[:-1]):
+        result = layer(result)
+        if self.model_type == 'conventional':
+            result = self.activation(result)
+        elif self.model_type == 'swish':
+            result = swish(result)
+        elif self.model_type == 'SIREN':
+            result = torch.sin(result * self.siren_multiplier)
+        elif self.model_type == 'FILM SIREN':
+            result = torch.sin(self.alpha[layer_id][0] * result + self.alpha[layer_id][1])
+        elif self.model_type == 'Gaussian': 
+            result = gaussian(result, self.alpha[layer_id][0])
+    result = self.layers[-1](result)
+    return result
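
A minimal usage sketch for multi_layer_perceptron, assuming the class is importable from the odak.learn.models.models module listed above; the layer dimensions and inputs below are illustrative only.

import torch
from odak.learn.models.models import multi_layer_perceptron

model = multi_layer_perceptron(
                               dimensions = [2, 64, 64, 1],  # 2D coordinates in, one value out
                               model_type = 'SIREN',
                               siren_multiplier = 30.        # value used in the original SIREN work
                              )
x = torch.rand(128, 2)                                       # a batch of normalized 2D coordinates
y = model(x)                                                 # -> torch.Size([128, 1])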

non_local_layer

Bases: Module

Self-Attention Layer [zi = Wzyi + xi] (non-local block : ref https://arxiv.org/abs/1711.07971)

Source code in odak/learn/models/components.py
class non_local_layer(torch.nn.Module):
+    """
+    Self-Attention Layer [zi = Wzyi + xi] (non-local block : ref https://arxiv.org/abs/1711.07971)
+    """
+    def __init__(
+                 self,
+                 input_channels = 1024,
+                 bottleneck_channels = 512,
+                 kernel_size = 1,
+                 bias = False,
+                ):
+        """
+
+        Parameters
+        ----------
+        input_channels      : int
+                              Number of input channels.
+        bottleneck_channels : int
+                              Number of middle channels.
+        kernel_size         : int
+                              Kernel size.
+        bias                : bool 
+                              Set to True to let convolutional layers have bias term.
+        """
+        super(non_local_layer, self).__init__()
+        self.input_channels = input_channels
+        self.bottleneck_channels = bottleneck_channels
+        self.g = torch.nn.Conv2d(
+                                 self.input_channels, 
+                                 self.bottleneck_channels,
+                                 kernel_size = kernel_size,
+                                 padding = kernel_size // 2,
+                                 bias = bias
+                                )
+        self.W_z = torch.nn.Sequential(
+                                       torch.nn.Conv2d(
+                                                       self.bottleneck_channels,
+                                                       self.input_channels, 
+                                                       kernel_size = kernel_size,
+                                                       bias = bias,
+                                                       padding = kernel_size // 2
+                                                      ),
+                                       torch.nn.BatchNorm2d(self.input_channels)
+                                      )
+        torch.nn.init.constant_(self.W_z[1].weight, 0)   
+        torch.nn.init.constant_(self.W_z[1].bias, 0)
+
+
+    def forward(self, x):
+        """
+        Forward model [zi = Wzyi + xi]
+
+        Parameters
+        ----------
+        x               : torch.tensor
+                          First input data.                       
+
+
+        Returns
+        ----------
+        z               : torch.tensor
+                          Estimated output.
+        """
+        batch_size, channels, height, width = x.size()
+        theta = x.view(batch_size, channels, -1).permute(0, 2, 1)
+        phi = x.view(batch_size, channels, -1).permute(0, 2, 1)
+        g = self.g(x).view(batch_size, self.bottleneck_channels, -1).permute(0, 2, 1)
+        attn = torch.bmm(theta, phi.transpose(1, 2)) / (height * width)
+        attn = torch.nn.functional.softmax(attn, dim=-1)
+        y = torch.bmm(attn, g).permute(0, 2, 1).contiguous().view(batch_size, self.bottleneck_channels, height, width)
+        W_y = self.W_z(y)
+        z = W_y + x
+        return z

__init__(input_channels=1024, bottleneck_channels=512, kernel_size=1, bias=False)

Parameters:

  • input_channels – Number of input channels.
  • bottleneck_channels (int, default: 512) – Number of middle channels.
  • kernel_size – Kernel size.
  • bias – Set to True to let convolutional layers have bias term.

Source code in odak/learn/models/components.py
def __init__(
+             self,
+             input_channels = 1024,
+             bottleneck_channels = 512,
+             kernel_size = 1,
+             bias = False,
+            ):
+    """
+
+    Parameters
+    ----------
+    input_channels      : int
+                          Number of input channels.
+    bottleneck_channels : int
+                          Number of middle channels.
+    kernel_size         : int
+                          Kernel size.
+    bias                : bool 
+                          Set to True to let convolutional layers have bias term.
+    """
+    super(non_local_layer, self).__init__()
+    self.input_channels = input_channels
+    self.bottleneck_channels = bottleneck_channels
+    self.g = torch.nn.Conv2d(
+                             self.input_channels, 
+                             self.bottleneck_channels,
+                             kernel_size = kernel_size,
+                             padding = kernel_size // 2,
+                             bias = bias
+                            )
+    self.W_z = torch.nn.Sequential(
+                                   torch.nn.Conv2d(
+                                                   self.bottleneck_channels,
+                                                   self.input_channels, 
+                                                   kernel_size = kernel_size,
+                                                   bias = bias,
+                                                   padding = kernel_size // 2
+                                                  ),
+                                   torch.nn.BatchNorm2d(self.input_channels)
+                                  )
+    torch.nn.init.constant_(self.W_z[1].weight, 0)   
+    torch.nn.init.constant_(self.W_z[1].bias, 0)

forward(x)

Forward model [zi = Wzyi + xi]

Parameters:

  • x – First input data.

Returns:

  • z (tensor) – Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x):
+    """
+    Forward model [zi = Wzyi + xi]
+
+    Parameters
+    ----------
+    x               : torch.tensor
+                      First input data.                       
+
+
+    Returns
+    ----------
+    z               : torch.tensor
+                      Estimated output.
+    """
+    batch_size, channels, height, width = x.size()
+    theta = x.view(batch_size, channels, -1).permute(0, 2, 1)
+    phi = x.view(batch_size, channels, -1).permute(0, 2, 1)
+    g = self.g(x).view(batch_size, self.bottleneck_channels, -1).permute(0, 2, 1)
+    attn = torch.bmm(theta, phi.transpose(1, 2)) / (height * width)
+    attn = torch.nn.functional.softmax(attn, dim=-1)
+    y = torch.bmm(attn, g).permute(0, 2, 1).contiguous().view(batch_size, self.bottleneck_channels, height, width)
+    W_y = self.W_z(y)
+    z = W_y + x
+    return z
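
A minimal usage sketch for non_local_layer, assuming the class is importable from the odak.learn.models.components module listed above; the channel counts and tensor sizes are illustrative.

import torch
from odak.learn.models.components import non_local_layer

layer = non_local_layer(input_channels = 64, bottleneck_channels = 32, kernel_size = 1)
x = torch.randn(1, 64, 32, 32)
z = layer(x)                     # residual self-attention output, same shape as x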

normalization

Bases: Module

A normalization layer.

Source code in odak/learn/models/components.py
class normalization(torch.nn.Module):
+    """
+    A normalization layer.
+    """
+    def __init__(
+                 self,
+                 dim = 1,
+                ):
+        """
+        Normalization layer.
+
+
+        Parameters
+        ----------
+        dim             : int
+                          Dimension (axis) to normalize.
+        """
+        super().__init__()
+        self.k = torch.nn.Parameter(torch.ones(1, dim, 1, 1))
+
+
+    def forward(self, x):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x             : torch.tensor
+                        Input data.
+
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Estimated output.      
+        """
+        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
+        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
+        mean = torch.mean(x, dim = 1, keepdim = True)
+        result =  (x - mean) * (var + eps).rsqrt() * self.k
+        return result 

__init__(dim=1)

Normalization layer.

Parameters:

  • dim – Dimension (axis) to normalize.

Source code in odak/learn/models/components.py
def __init__(
+             self,
+             dim = 1,
+            ):
+    """
+    Normalization layer.
+
+
+    Parameters
+    ----------
+    dim             : int
+                      Dimension (axis) to normalize.
+    """
+    super().__init__()
+    self.k = torch.nn.Parameter(torch.ones(1, dim, 1, 1))

forward(x)

Forward model.

Parameters:

  • x – Input data.

Returns:

  • result (tensor) – Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x             : torch.tensor
+                    Input data.
+
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Estimated output.      
+    """
+    eps = 1e-5 if x.dtype == torch.float32 else 1e-3
+    var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
+    mean = torch.mean(x, dim = 1, keepdim = True)
+    result =  (x - mean) * (var + eps).rsqrt() * self.k
+    return result 
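
A minimal usage sketch for normalization, assuming the class is importable from odak.learn.models.components; dim should match the channel count of the input.

import torch
from odak.learn.models.components import normalization

layer = normalization(dim = 8)
x = torch.randn(4, 8, 16, 16)
y = layer(x)                     # channel-wise normalization scaled by the learned k, same shape as x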

positional_encoder

Bases: Module

A positional encoder module.

Source code in odak/learn/models/components.py
class positional_encoder(torch.nn.Module):
+    """
+    A positional encoder module.
+    """
+
+    def __init__(self, L):
+        """
+        A positional encoder module.
+
+        Parameters
+        ----------
+        L                   : int
+                              Positional encoding level.
+        """
+        super(positional_encoder, self).__init__()
+        self.L = L
+
+
+    def forward(self, x):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x               : torch.tensor
+                          Input data.
+
+        Returns
+        ----------
+        result          : torch.tensor
+                          Result of the forward operation
+        """
+        B, C = x.shape
+        x = x.view(B, C, 1)
+        results = [x]
+        for i in range(1, self.L + 1):
+            freq = (2 ** i) * math.pi
+            cos_x = torch.cos(freq * x)
+            sin_x = torch.sin(freq * x)
+            results.append(cos_x)
+            results.append(sin_x)
+        results = torch.cat(results, dim=2)
+        results = results.permute(0, 2, 1)
+        results = results.reshape(B, -1)
+        return results

__init__(L)

A positional encoder module.

Parameters:

  • L – Positional encoding level.

Source code in odak/learn/models/components.py
def __init__(self, L):
+    """
+    A positional encoder module.
+
+    Parameters
+    ----------
+    L                   : int
+                          Positional encoding level.
+    """
+    super(positional_encoder, self).__init__()
+    self.L = L

forward(x)

Forward model.

Parameters:

  • x – Input data.

Returns:

  • result (tensor) – Result of the forward operation.

Source code in odak/learn/models/components.py
def forward(self, x):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x               : torch.tensor
+                      Input data.
+
+    Returns
+    ----------
+    result          : torch.tensor
+                      Result of the forward operation
+    """
+    B, C = x.shape
+    x = x.view(B, C, 1)
+    results = [x]
+    for i in range(1, self.L + 1):
+        freq = (2 ** i) * math.pi
+        cos_x = torch.cos(freq * x)
+        sin_x = torch.sin(freq * x)
+        results.append(cos_x)
+        results.append(sin_x)
+    results = torch.cat(results, dim=2)
+    results = results.permute(0, 2, 1)
+    results = results.reshape(B, -1)
+    return results
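
A minimal usage sketch for positional_encoder, assuming the class is importable from odak.learn.models.components; each input channel expands to 1 + 2L values.

import torch
from odak.learn.models.components import positional_encoder

encode = positional_encoder(L = 4)
x = torch.rand(16, 2)            # sixteen samples of normalized 2D coordinates
y = encode(x)                    # -> torch.Size([16, 18]), i.e., 2 * (1 + 2 * 4) features per sample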

residual_attention_layer

Bases: Module

A residual block with an attention layer.

Source code in odak/learn/models/components.py
class residual_attention_layer(torch.nn.Module):
+    """
+    A residual block with an attention layer.
+    """
+    def __init__(
+                 self,
+                 input_channels = 2,
+                 output_channels = 2,
+                 kernel_size = 1,
+                 bias = False,
+                 activation = torch.nn.ReLU()
+                ):
+        """
+        An attention layer class.
+
+
+        Parameters
+        ----------
+        input_channels  : int or optional
+                          Number of input channels.
+        output_channels : int or optional
+                          Number of middle channels.
+        kernel_size     : int or optional
+                          Kernel size.
+        bias            : bool or optional
+                          Set to True to let convolutional layers have bias term.
+        activation      : torch.nn or optional
+                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+        """
+        super().__init__()
+        self.activation = activation
+        self.convolution0 = torch.nn.Sequential(
+                                                torch.nn.Conv2d(
+                                                                input_channels,
+                                                                output_channels,
+                                                                kernel_size = kernel_size,
+                                                                padding = kernel_size // 2,
+                                                                bias = bias
+                                                               ),
+                                                torch.nn.BatchNorm2d(output_channels)
+                                               )
+        self.convolution1 = torch.nn.Sequential(
+                                                torch.nn.Conv2d(
+                                                                input_channels,
+                                                                output_channels,
+                                                                kernel_size = kernel_size,
+                                                                padding = kernel_size // 2,
+                                                                bias = bias
+                                                               ),
+                                                torch.nn.BatchNorm2d(output_channels)
+                                               )
+        self.final_layer = torch.nn.Sequential(
+                                               self.activation,
+                                               torch.nn.Conv2d(
+                                                               output_channels,
+                                                               output_channels,
+                                                               kernel_size = kernel_size,
+                                                               padding = kernel_size // 2,
+                                                               bias = bias
+                                                              )
+                                              )
+
+
+    def forward(self, x0, x1):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x0             : torch.tensor
+                         First input data.
+
+        x1             : torch.tensor
+                         Second input data.
+
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Estimated output.      
+        """
+        y0 = self.convolution0(x0)
+        y1 = self.convolution1(x1)
+        y2 = torch.add(y0, y1)
+        result = self.final_layer(y2) * x0
+        return result

__init__(input_channels=2, output_channels=2, kernel_size=1, bias=False, activation=torch.nn.ReLU())

An attention layer class.

Parameters:

  • input_channels – Number of input channels.
  • output_channels (int or optional, default: 2) – Number of middle channels.
  • kernel_size – Kernel size.
  • bias – Set to True to let convolutional layers have bias term.
  • activation – Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().

Source code in odak/learn/models/components.py
def __init__(
+             self,
+             input_channels = 2,
+             output_channels = 2,
+             kernel_size = 1,
+             bias = False,
+             activation = torch.nn.ReLU()
+            ):
+    """
+    An attention layer class.
+
+
+    Parameters
+    ----------
+    input_channels  : int or optional
+                      Number of input channels.
+    output_channels : int or optional
+                      Number of middle channels.
+    kernel_size     : int or optional
+                      Kernel size.
+    bias            : bool or optional
+                      Set to True to let convolutional layers have bias term.
+    activation      : torch.nn or optional
+                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+    """
+    super().__init__()
+    self.activation = activation
+    self.convolution0 = torch.nn.Sequential(
+                                            torch.nn.Conv2d(
+                                                            input_channels,
+                                                            output_channels,
+                                                            kernel_size = kernel_size,
+                                                            padding = kernel_size // 2,
+                                                            bias = bias
+                                                           ),
+                                            torch.nn.BatchNorm2d(output_channels)
+                                           )
+    self.convolution1 = torch.nn.Sequential(
+                                            torch.nn.Conv2d(
+                                                            input_channels,
+                                                            output_channels,
+                                                            kernel_size = kernel_size,
+                                                            padding = kernel_size // 2,
+                                                            bias = bias
+                                                           ),
+                                            torch.nn.BatchNorm2d(output_channels)
+                                           )
+    self.final_layer = torch.nn.Sequential(
+                                           self.activation,
+                                           torch.nn.Conv2d(
+                                                           output_channels,
+                                                           output_channels,
+                                                           kernel_size = kernel_size,
+                                                           padding = kernel_size // 2,
+                                                           bias = bias
+                                                          )
+                                          )

forward(x0, x1)

Forward model.

Parameters:

  • x0 – First input data.
  • x1 – Second input data.

Returns:

  • result (tensor) – Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x0, x1):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x0             : torch.tensor
+                     First input data.
+
+    x1             : torch.tensor
+                     Second input data.
+
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Estimated output.      
+    """
+    y0 = self.convolution0(x0)
+    y1 = self.convolution1(x1)
+    y2 = torch.add(y0, y1)
+    result = self.final_layer(y2) * x0
+    return result
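
A minimal usage sketch for residual_attention_layer, assuming the class is importable from odak.learn.models.components; note that the elementwise product with x0 expects output_channels to match input_channels.

import torch
from odak.learn.models.components import residual_attention_layer

layer = residual_attention_layer(input_channels = 2, output_channels = 2, kernel_size = 1)
x0 = torch.randn(1, 2, 32, 32)
x1 = torch.randn(1, 2, 32, 32)
y = layer(x0, x1)                # attention derived from x0 and x1, applied to x0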

residual_layer

Bases: Module

A residual layer.

Source code in odak/learn/models/components.py
class residual_layer(torch.nn.Module):
+    """
+    A residual layer.
+    """
+    def __init__(
+                 self,
+                 input_channels = 2,
+                 mid_channels = 16,
+                 kernel_size = 3,
+                 bias = False,
+                 activation = torch.nn.ReLU()
+                ):
+        """
+        A convolutional layer class.
+
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        mid_channels    : int
+                          Number of middle channels.
+        kernel_size     : int
+                          Kernel size.
+        bias            : bool 
+                          Set to True to let convolutional layers have bias term.
+        activation      : torch.nn
+                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+        """
+        super().__init__()
+        self.activation = activation
+        self.convolution = double_convolution(
+                                              input_channels,
+                                              mid_channels = mid_channels,
+                                              output_channels = input_channels,
+                                              kernel_size = kernel_size,
+                                              bias = bias,
+                                              activation = activation
+                                             )
+
+
+    def forward(self, x):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x             : torch.tensor
+                        Input data.
+
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Estimated output.      
+        """
+        x0 = self.convolution(x)
+        return x + x0

__init__(input_channels=2, mid_channels=16, kernel_size=3, bias=False, activation=torch.nn.ReLU())

A convolutional layer class.

Parameters:

  • input_channels – Number of input channels.
  • mid_channels – Number of middle channels.
  • kernel_size – Kernel size.
  • bias – Set to True to let convolutional layers have bias term.
  • activation – Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().

Source code in odak/learn/models/components.py
def __init__(
+             self,
+             input_channels = 2,
+             mid_channels = 16,
+             kernel_size = 3,
+             bias = False,
+             activation = torch.nn.ReLU()
+            ):
+    """
+    A convolutional layer class.
+
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    mid_channels    : int
+                      Number of middle channels.
+    kernel_size     : int
+                      Kernel size.
+    bias            : bool 
+                      Set to True to let convolutional layers have bias term.
+    activation      : torch.nn
+                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+    """
+    super().__init__()
+    self.activation = activation
+    self.convolution = double_convolution(
+                                          input_channels,
+                                          mid_channels = mid_channels,
+                                          output_channels = input_channels,
+                                          kernel_size = kernel_size,
+                                          bias = bias,
+                                          activation = activation
+                                         )

forward(x)

Forward model.

Parameters:

  • x – Input data.

Returns:

  • result (tensor) – Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x             : torch.tensor
+                    Input data.
+
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Estimated output.      
+    """
+    x0 = self.convolution(x)
+    return x + x0
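
A minimal usage sketch for residual_layer, assuming the class (and the double_convolution block it wraps) is importable from odak.learn.models.components; tensor sizes are illustrative.

import torch
from odak.learn.models.components import residual_layer

layer = residual_layer(input_channels = 4, mid_channels = 16, kernel_size = 3)
x = torch.randn(1, 4, 32, 32)
y = layer(x)                     # x + double_convolution(x), same shape as x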

spatial_gate

Bases: Module

Spatial attention module that applies a convolution layer after channel pooling.
This class is heavily inspired by https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py.

Source code in odak/learn/models/components.py
class spatial_gate(torch.nn.Module):
+    """
+    Spatial attention module that applies a convolution layer after channel pooling.
+    This class is heavily inspired by https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py.
+    """
+    def __init__(self):
+        """
+        Initializes the spatial gate module.
+        """
+        super().__init__()
+        kernel_size = 7
+        self.spatial = convolution_layer(2, 1, kernel_size, bias = False, activation = torch.nn.Identity())
+
+
+    def channel_pool(self, x):
+        """
+        Applies max and average pooling on the channels.
+
+        Parameters
+        ----------
+        x             : torch.tensor
+                        Input tensor.
+
+        Returns
+        -------
+        output        : torch.tensor
+                        Output tensor.
+        """
+        max_pool = torch.max(x, 1)[0].unsqueeze(1)
+        avg_pool = torch.mean(x, 1).unsqueeze(1)
+        output = torch.cat((max_pool, avg_pool), dim=1)
+        return output
+
+
+    def forward(self, x):
+        """
+        Forward pass of the SpatialGate module.
+
+        Applies spatial attention to the input tensor.
+
+        Parameters
+        ----------
+        x            : torch.tensor
+                       Input tensor to the SpatialGate module.
+
+        Returns
+        -------
+        scaled_x     : torch.tensor
+                       Output tensor after applying spatial attention.
+        """
+        x_compress = self.channel_pool(x)
+        x_out = self.spatial(x_compress)
+        scale = torch.sigmoid(x_out)
+        scaled_x = x * scale
+        return scaled_x

__init__()

Initializes the spatial gate module.

Source code in odak/learn/models/components.py
def __init__(self):
+    """
+    Initializes the spatial gate module.
+    """
+    super().__init__()
+    kernel_size = 7
+    self.spatial = convolution_layer(2, 1, kernel_size, bias = False, activation = torch.nn.Identity())

channel_pool(x)

Applies max and average pooling on the channels.

Parameters:

  • x – Input tensor.

Returns:

  • output (tensor) – Output tensor.

Source code in odak/learn/models/components.py
def channel_pool(self, x):
+    """
+    Applies max and average pooling on the channels.
+
+    Parameters
+    ----------
+    x             : torch.tensor
+                    Input tensor.
+
+    Returns
+    -------
+    output        : torch.tensor
+                    Output tensor.
+    """
+    max_pool = torch.max(x, 1)[0].unsqueeze(1)
+    avg_pool = torch.mean(x, 1).unsqueeze(1)
+    output = torch.cat((max_pool, avg_pool), dim=1)
+    return output

forward(x)

Forward pass of the SpatialGate module.

Applies spatial attention to the input tensor.

Parameters:

  • x – Input tensor to the SpatialGate module.

Returns:

  • scaled_x (tensor) – Output tensor after applying spatial attention.

Source code in odak/learn/models/components.py
def forward(self, x):
+    """
+    Forward pass of the SpatialGate module.
+
+    Applies spatial attention to the input tensor.
+
+    Parameters
+    ----------
+    x            : torch.tensor
+                   Input tensor to the SpatialGate module.
+
+    Returns
+    -------
+    scaled_x     : torch.tensor
+                   Output tensor after applying spatial attention.
+    """
+    x_compress = self.channel_pool(x)
+    x_out = self.spatial(x_compress)
+    scale = torch.sigmoid(x_out)
+    scaled_x = x * scale
+    return scaled_x
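
A minimal usage sketch for spatial_gate, assuming the class (and the convolution_layer it uses internally) is importable from odak.learn.models.components; it works with any channel count since channel pooling reduces the input to two channels.

import torch
from odak.learn.models.components import spatial_gate

gate = spatial_gate()
x = torch.randn(1, 16, 32, 32)
y = gate(x)                      # x scaled by a sigmoid spatial attention map, same shape as x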

spatially_adaptive_convolution

Bases: Module

A spatially adaptive convolution layer.

References

C. Zheng et al. "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions."
C. Xu et al. "Squeezesegv3: Spatially-adaptive Convolution for Efficient Point-Cloud Segmentation."
C. Zheng et al. "Windowing Decomposition Convolutional Neural Network for Image Enhancement."

Source code in odak/learn/models/components.py
class spatially_adaptive_convolution(torch.nn.Module):
+    """
+    A spatially adaptive convolution layer.
+
+    References
+    ----------
+
+    C. Zheng et al. "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions."
+    C. Xu et al. "Squeezesegv3: Spatially-adaptive Convolution for Efficient Point-Cloud Segmentation."
+    C. Zheng et al. "Windowing Decomposition Convolutional Neural Network for Image Enhancement."
+    """
+    def __init__(
+                 self,
+                 input_channels = 2,
+                 output_channels = 2,
+                 kernel_size = 3,
+                 stride = 1,
+                 padding = 1,
+                 bias = False,
+                 activation = torch.nn.LeakyReLU(0.2, inplace = True)
+                ):
+        """
+        Initializes a spatially adaptive convolution layer.
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        output_channels : int
+                          Number of output channels.
+        kernel_size     : int
+                          Size of the convolution kernel.
+        stride          : int
+                          Stride of the convolution.
+        padding         : int
+                          Padding added to both sides of the input.
+        bias            : bool
+                          If True, includes a bias term in the convolution.
+        activation      : torch.nn.Module
+                          Activation function to apply. If None, no activation is applied.
+        """
+        super(spatially_adaptive_convolution, self).__init__()
+        self.kernel_size = kernel_size
+        self.input_channels = input_channels
+        self.output_channels = output_channels
+        self.stride = stride
+        self.padding = padding
+        self.standard_convolution = torch.nn.Conv2d(
+                                                    in_channels = input_channels,
+                                                    out_channels = self.output_channels,
+                                                    kernel_size = kernel_size,
+                                                    stride = stride,
+                                                    padding = padding,
+                                                    bias = bias
+                                                   )
+        self.weight = torch.nn.Parameter(data = self.standard_convolution.weight, requires_grad = True)
+        self.activation = activation
+
+
+    def forward(self, x, sv_kernel_feature):
+        """
+        Forward pass for the spatially adaptive convolution layer.
+
+        Parameters
+        ----------
+        x                  : torch.tensor
+                            Input data tensor.
+                            Dimension: (1, C, H, W)
+        sv_kernel_feature   : torch.tensor
+                            Spatially varying kernel features.
+                            Dimension: (1, C_i * kernel_size * kernel_size, H, W)
+
+        Returns
+        -------
+        sa_output          : torch.tensor
+                            Estimated output tensor.
+                            Dimension: (1, output_channels, H_out, W_out)
+        """
+        # Pad input and sv_kernel_feature if necessary
+        if sv_kernel_feature.size(-1) * self.stride != x.size(-1) or sv_kernel_feature.size(
+                -2) * self.stride != x.size(-2):
+            diffY = sv_kernel_feature.size(-2) % self.stride
+            diffX = sv_kernel_feature.size(-1) % self.stride
+            sv_kernel_feature = torch.nn.functional.pad(sv_kernel_feature, (diffX // 2, diffX - diffX // 2,
+                                                                            diffY // 2, diffY - diffY // 2))
+            diffY = x.size(-2) % self.stride
+            diffX = x.size(-1) % self.stride
+            x = torch.nn.functional.pad(x, (diffX // 2, diffX - diffX // 2,
+                                            diffY // 2, diffY - diffY // 2))
+
+        # Unfold the input tensor for matrix multiplication
+        input_feature = torch.nn.functional.unfold(
+                                                   x,
+                                                   kernel_size = (self.kernel_size, self.kernel_size),
+                                                   stride = self.stride,
+                                                   padding = self.padding
+                                                  )
+
+        # Resize sv_kernel_feature to match the input feature
+        sv_kernel = sv_kernel_feature.reshape(
+                                              1,
+                                              self.input_channels * self.kernel_size * self.kernel_size,
+                                              (x.size(-2) // self.stride) * (x.size(-1) // self.stride)
+                                             )
+
+        # Resize weight to match the input channels and kernel size
+        si_kernel = self.weight.reshape(
+                                        self.output_channels,
+                                        self.input_channels * self.kernel_size * self.kernel_size
+                                       )
+
+        # Apply spatially varying kernels
+        sv_feature = input_feature * sv_kernel
+
+        # Perform matrix multiplication
+        sa_output = torch.matmul(si_kernel, sv_feature).reshape(
+                                                                1, self.output_channels,
+                                                                (x.size(-2) // self.stride),
+                                                                (x.size(-1) // self.stride)
+                                                               )
+        return sa_output

__init__(input_channels=2, output_channels=2, kernel_size=3, stride=1, padding=1, bias=False, activation=torch.nn.LeakyReLU(0.2, inplace=True))

Initializes a spatially adaptive convolution layer.

Parameters:

  • input_channels – Number of input channels.
  • output_channels (int, default: 2) – Number of output channels.
  • kernel_size – Size of the convolution kernel.
  • stride – Stride of the convolution.
  • padding – Padding added to both sides of the input.
  • bias – If True, includes a bias term in the convolution.
  • activation – Activation function to apply. If None, no activation is applied.

Source code in odak/learn/models/components.py
def __init__(
+             self,
+             input_channels = 2,
+             output_channels = 2,
+             kernel_size = 3,
+             stride = 1,
+             padding = 1,
+             bias = False,
+             activation = torch.nn.LeakyReLU(0.2, inplace = True)
+            ):
+    """
+    Initializes a spatially adaptive convolution layer.
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    output_channels : int
+                      Number of output channels.
+    kernel_size     : int
+                      Size of the convolution kernel.
+    stride          : int
+                      Stride of the convolution.
+    padding         : int
+                      Padding added to both sides of the input.
+    bias            : bool
+                      If True, includes a bias term in the convolution.
+    activation      : torch.nn.Module
+                      Activation function to apply. If None, no activation is applied.
+    """
+    super(spatially_adaptive_convolution, self).__init__()
+    self.kernel_size = kernel_size
+    self.input_channels = input_channels
+    self.output_channels = output_channels
+    self.stride = stride
+    self.padding = padding
+    self.standard_convolution = torch.nn.Conv2d(
+                                                in_channels = input_channels,
+                                                out_channels = self.output_channels,
+                                                kernel_size = kernel_size,
+                                                stride = stride,
+                                                padding = padding,
+                                                bias = bias
+                                               )
+    self.weight = torch.nn.Parameter(data = self.standard_convolution.weight, requires_grad = True)
+    self.activation = activation

forward(x, sv_kernel_feature)

Forward pass for the spatially adaptive convolution layer.

Parameters:

  • x – Input data tensor. Dimension: (1, C, H, W).
  • sv_kernel_feature – Spatially varying kernel features. Dimension: (1, C_i * kernel_size * kernel_size, H, W).

Returns:

  • sa_output (tensor) – Estimated output tensor. Dimension: (1, output_channels, H_out, W_out).

Source code in odak/learn/models/components.py
def forward(self, x, sv_kernel_feature):
+    """
+    Forward pass for the spatially adaptive convolution layer.
+
+    Parameters
+    ----------
+    x                  : torch.tensor
+                        Input data tensor.
+                        Dimension: (1, C, H, W)
+    sv_kernel_feature   : torch.tensor
+                        Spatially varying kernel features.
+                        Dimension: (1, C_i * kernel_size * kernel_size, H, W)
+
+    Returns
+    -------
+    sa_output          : torch.tensor
+                        Estimated output tensor.
+                        Dimension: (1, output_channels, H_out, W_out)
+    """
+    # Pad input and sv_kernel_feature if necessary
+    if sv_kernel_feature.size(-1) * self.stride != x.size(-1) or sv_kernel_feature.size(
+            -2) * self.stride != x.size(-2):
+        diffY = sv_kernel_feature.size(-2) % self.stride
+        diffX = sv_kernel_feature.size(-1) % self.stride
+        sv_kernel_feature = torch.nn.functional.pad(sv_kernel_feature, (diffX // 2, diffX - diffX // 2,
+                                                                        diffY // 2, diffY - diffY // 2))
+        diffY = x.size(-2) % self.stride
+        diffX = x.size(-1) % self.stride
+        x = torch.nn.functional.pad(x, (diffX // 2, diffX - diffX // 2,
+                                        diffY // 2, diffY - diffY // 2))
+
+    # Unfold the input tensor for matrix multiplication
+    input_feature = torch.nn.functional.unfold(
+                                               x,
+                                               kernel_size = (self.kernel_size, self.kernel_size),
+                                               stride = self.stride,
+                                               padding = self.padding
+                                              )
+
+    # Resize sv_kernel_feature to match the input feature
+    sv_kernel = sv_kernel_feature.reshape(
+                                          1,
+                                          self.input_channels * self.kernel_size * self.kernel_size,
+                                          (x.size(-2) // self.stride) * (x.size(-1) // self.stride)
+                                         )
+
+    # Resize weight to match the input channels and kernel size
+    si_kernel = self.weight.reshape(
+                                    self.output_channels,
+                                    self.input_channels * self.kernel_size * self.kernel_size
+                                   )
+
+    # Apply spatially varying kernels
+    sv_feature = input_feature * sv_kernel
+
+    # Perform matrix multiplication
+    sa_output = torch.matmul(si_kernel, sv_feature).reshape(
+                                                            1, self.output_channels,
+                                                            (x.size(-2) // self.stride),
+                                                            (x.size(-1) // self.stride)
+                                                           )
+    return sa_output
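
A minimal usage sketch for spatially_adaptive_convolution, assuming the class is importable from odak.learn.models.components; the spatially varying kernel tensor must carry input_channels * kernel_size * kernel_size values per output location.

import torch
from odak.learn.models.components import spatially_adaptive_convolution

layer = spatially_adaptive_convolution(input_channels = 2, output_channels = 2, kernel_size = 3, stride = 1, padding = 1)
x = torch.randn(1, 2, 16, 16)
sv_kernel_feature = torch.randn(1, 2 * 3 * 3, 16, 16)   # (1, C_i * kernel_size * kernel_size, H, W)
y = layer(x, sv_kernel_feature)                          # -> torch.Size([1, 2, 16, 16])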

spatially_adaptive_module

Bases: Module

A spatially adaptive module that combines learned spatially adaptive convolutions.

References

Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Akşit, "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions," SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24), December, 2024.

Source code in odak/learn/models/components.py
class spatially_adaptive_module(torch.nn.Module):
+    """
+    A spatially adaptive module that combines learned spatially adaptive convolutions.
+
+    References
+    ----------
+
+    Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Akşit, "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions," SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24), December, 2024.
+    """
+    def __init__(
+                 self,
+                 input_channels = 2,
+                 output_channels = 2,
+                 kernel_size = 3,
+                 stride = 1,
+                 padding = 1,
+                 bias = False,
+                 activation = torch.nn.LeakyReLU(0.2, inplace = True)
+                ):
+        """
+        Initializes a spatially adaptive module.
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        output_channels : int
+                          Number of output channels.
+        kernel_size     : int
+                          Size of the convolution kernel.
+        stride          : int
+                          Stride of the convolution.
+        padding         : int
+                          Padding added to both sides of the input.
+        bias            : bool
+                          If True, includes a bias term in the convolution.
+        activation      : torch.nn
+                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+        """
+        super(spatially_adaptive_module, self).__init__()
+        self.kernel_size = kernel_size
+        self.input_channels = input_channels
+        self.output_channels = output_channels
+        self.stride = stride
+        self.padding = padding
+        self.weight_output_channels = self.output_channels - 1
+        self.standard_convolution = torch.nn.Conv2d(
+                                                    in_channels = input_channels,
+                                                    out_channels = self.weight_output_channels,
+                                                    kernel_size = kernel_size,
+                                                    stride = stride,
+                                                    padding = padding,
+                                                    bias = bias
+                                                   )
+        self.weight = torch.nn.Parameter(data = self.standard_convolution.weight, requires_grad = True)
+        self.activation = activation
+
+
+    def forward(self, x, sv_kernel_feature):
+        """
+        Forward pass for the spatially adaptive module.
+
+        Parameters
+        ----------
+        x                  : torch.tensor
+                            Input data tensor.
+                            Dimension: (1, C, H, W)
+        sv_kernel_feature   : torch.tensor
+                            Spatially varying kernel features.
+                            Dimension: (1, C_i * kernel_size * kernel_size, H, W)
+
+        Returns
+        -------
+        output             : torch.tensor
+                            Combined output tensor from standard and spatially adaptive convolutions.
+                            Dimension: (1, output_channels, H_out, W_out)
+        """
+        # Pad input and sv_kernel_feature if necessary
+        if sv_kernel_feature.size(-1) * self.stride != x.size(-1) or sv_kernel_feature.size(
+                -2) * self.stride != x.size(-2):
+            diffY = sv_kernel_feature.size(-2) % self.stride
+            diffX = sv_kernel_feature.size(-1) % self.stride
+            sv_kernel_feature = torch.nn.functional.pad(sv_kernel_feature, (diffX // 2, diffX - diffX // 2,
+                                                                            diffY // 2, diffY - diffY // 2))
+            diffY = x.size(-2) % self.stride
+            diffX = x.size(-1) % self.stride
+            x = torch.nn.functional.pad(x, (diffX // 2, diffX - diffX // 2,
+                                            diffY // 2, diffY - diffY // 2))
+
+        # Unfold the input tensor for matrix multiplication
+        input_feature = torch.nn.functional.unfold(
+                                                   x,
+                                                   kernel_size = (self.kernel_size, self.kernel_size),
+                                                   stride = self.stride,
+                                                   padding = self.padding
+                                                  )
+
+        # Resize sv_kernel_feature to match the input feature
+        sv_kernel = sv_kernel_feature.reshape(
+                                              1,
+                                              self.input_channels * self.kernel_size * self.kernel_size,
+                                              (x.size(-2) // self.stride) * (x.size(-1) // self.stride)
+                                             )
+
+        # Apply sv_kernel to the input_feature
+        sv_feature = input_feature * sv_kernel
+
+        # Original spatially varying convolution output
+        sv_output = torch.sum(sv_feature, dim = 1).reshape(
+                                                            1,
+                                                            1,
+                                                            (x.size(-2) // self.stride),
+                                                            (x.size(-1) // self.stride)
+                                                           )
+
+        # Reshape weight for spatially adaptive convolution
+        si_kernel = self.weight.reshape(
+                                        self.weight_output_channels,
+                                        self.input_channels * self.kernel_size * self.kernel_size
+                                       )
+
+        # Apply si_kernel on sv convolution output
+        sa_output = torch.matmul(si_kernel, sv_feature).reshape(
+                                                                1, self.weight_output_channels,
+                                                                (x.size(-2) // self.stride),
+                                                                (x.size(-1) // self.stride)
+                                                               )
+
+        # Combine the outputs and apply activation function
+        output = self.activation(torch.cat((sv_output, sa_output), dim = 1))
+        return output
+
+
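A minimal usage sketch (not part of the odak source; the import path and tensor sizes are assumptions) showing how this module consumes an input field together with a matching spatially varying kernel tensor:

+import torch
+from odak.learn.models import spatially_adaptive_module   # assumed import path
+
+sam = spatially_adaptive_module(input_channels = 2, output_channels = 2, kernel_size = 3)
+x = torch.randn(1, 2, 64, 64)                   # input field features, (1, C, H, W)
+sv_kernel = torch.randn(1, 2 * 3 * 3, 64, 64)   # spatially varying kernels, (1, C * kernel_size * kernel_size, H, W)
+y = sam(x, sv_kernel)                           # combined output, (1, output_channels, H, W) = (1, 2, 64, 64)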
+
+ __init__(input_channels=2, output_channels=2, kernel_size=3, stride=1, padding=1, bias=False, activation=torch.nn.LeakyReLU(0.2, inplace=True))
+
Initializes a spatially adaptive module.

+
Parameters:
+
+   • input_channels (int, default: 2) – Number of input channels.
+   • output_channels (int, default: 2) – Number of output channels.
+   • kernel_size (int, default: 3) – Size of the convolution kernel.
+   • stride (int, default: 1) – Stride of the convolution.
+   • padding (int, default: 1) – Padding added to both sides of the input.
+   • bias (bool, default: False) – If True, includes a bias term in the convolution.
+   • activation (torch.nn, default: torch.nn.LeakyReLU(0.2, inplace=True)) – Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+
+ Source code in odak/learn/models/components.py
def __init__(
+             self,
+             input_channels = 2,
+             output_channels = 2,
+             kernel_size = 3,
+             stride = 1,
+             padding = 1,
+             bias = False,
+             activation = torch.nn.LeakyReLU(0.2, inplace = True)
+            ):
+    """
+    Initializes a spatially adaptive module.
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    output_channels : int
+                      Number of output channels.
+    kernel_size     : int
+                      Size of the convolution kernel.
+    stride          : int
+                      Stride of the convolution.
+    padding         : int
+                      Padding added to both sides of the input.
+    bias            : bool
+                      If True, includes a bias term in the convolution.
+    activation      : torch.nn
+                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+    """
+    super(spatially_adaptive_module, self).__init__()
+    self.kernel_size = kernel_size
+    self.input_channels = input_channels
+    self.output_channels = output_channels
+    self.stride = stride
+    self.padding = padding
+    self.weight_output_channels = self.output_channels - 1
+    self.standard_convolution = torch.nn.Conv2d(
+                                                in_channels = input_channels,
+                                                out_channels = self.weight_output_channels,
+                                                kernel_size = kernel_size,
+                                                stride = stride,
+                                                padding = padding,
+                                                bias = bias
+                                               )
+    self.weight = torch.nn.Parameter(data = self.standard_convolution.weight, requires_grad = True)
+    self.activation = activation
+
+
+
+
+ forward(x, sv_kernel_feature)
+

Forward pass for the spatially adaptive module.

+
Parameters:
+
+   • x (torch.tensor) – Input data tensor. Dimension: (1, C, H, W)
+   • sv_kernel_feature (torch.tensor) – Spatially varying kernel features. Dimension: (1, C_i * kernel_size * kernel_size, H, W)
+
Returns:
+
+   • output (torch.tensor) – Combined output tensor from standard and spatially adaptive convolutions. Dimension: (1, output_channels, H_out, W_out)
+
+ Source code in odak/learn/models/components.py
def forward(self, x, sv_kernel_feature):
+    """
+    Forward pass for the spatially adaptive module.
+
+    Parameters
+    ----------
+    x                  : torch.tensor
+                        Input data tensor.
+                        Dimension: (1, C, H, W)
+    sv_kernel_feature   : torch.tensor
+                        Spatially varying kernel features.
+                        Dimension: (1, C_i * kernel_size * kernel_size, H, W)
+
+    Returns
+    -------
+    output             : torch.tensor
+                        Combined output tensor from standard and spatially adaptive convolutions.
+                        Dimension: (1, output_channels, H_out, W_out)
+    """
+    # Pad input and sv_kernel_feature if necessary
+    if sv_kernel_feature.size(-1) * self.stride != x.size(-1) or sv_kernel_feature.size(
+            -2) * self.stride != x.size(-2):
+        diffY = sv_kernel_feature.size(-2) % self.stride
+        diffX = sv_kernel_feature.size(-1) % self.stride
+        sv_kernel_feature = torch.nn.functional.pad(sv_kernel_feature, (diffX // 2, diffX - diffX // 2,
+                                                                        diffY // 2, diffY - diffY // 2))
+        diffY = x.size(-2) % self.stride
+        diffX = x.size(-1) % self.stride
+        x = torch.nn.functional.pad(x, (diffX // 2, diffX - diffX // 2,
+                                        diffY // 2, diffY - diffY // 2))
+
+    # Unfold the input tensor for matrix multiplication
+    input_feature = torch.nn.functional.unfold(
+                                               x,
+                                               kernel_size = (self.kernel_size, self.kernel_size),
+                                               stride = self.stride,
+                                               padding = self.padding
+                                              )
+
+    # Resize sv_kernel_feature to match the input feature
+    sv_kernel = sv_kernel_feature.reshape(
+                                          1,
+                                          self.input_channels * self.kernel_size * self.kernel_size,
+                                          (x.size(-2) // self.stride) * (x.size(-1) // self.stride)
+                                         )
+
+    # Apply sv_kernel to the input_feature
+    sv_feature = input_feature * sv_kernel
+
+    # Original spatially varying convolution output
+    sv_output = torch.sum(sv_feature, dim = 1).reshape(
+                                                        1,
+                                                        1,
+                                                        (x.size(-2) // self.stride),
+                                                        (x.size(-1) // self.stride)
+                                                       )
+
+    # Reshape weight for spatially adaptive convolution
+    si_kernel = self.weight.reshape(
+                                    self.weight_output_channels,
+                                    self.input_channels * self.kernel_size * self.kernel_size
+                                   )
+
+    # Apply si_kernel on sv convolution output
+    sa_output = torch.matmul(si_kernel, sv_feature).reshape(
+                                                            1, self.weight_output_channels,
+                                                            (x.size(-2) // self.stride),
+                                                            (x.size(-1) // self.stride)
+                                                           )
+
+    # Combine the outputs and apply activation function
+    output = self.activation(torch.cat((sv_output, sa_output), dim = 1))
+    return output
+
+
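The spatially adaptive step above relies on the standard im2col identity: unfolding the input into columns and multiplying by the flattened weight reproduces an ordinary convolution, and scaling those columns by sv_kernel beforehand is what lets the effective kernel vary per pixel. A standalone sketch of that identity in plain PyTorch (illustrative only, not odak code):

+import torch
+
+x = torch.randn(1, 2, 8, 8)
+weight = torch.randn(4, 2, 3, 3)
+# im2col: columns of shape (1, C * k * k, H * W) for stride 1 and padding 1.
+columns = torch.nn.functional.unfold(x, kernel_size = (3, 3), stride = 1, padding = 1)
+flat_weight = weight.reshape(4, 2 * 3 * 3)
+y_unfold = torch.matmul(flat_weight, columns).reshape(1, 4, 8, 8)
+y_conv = torch.nn.functional.conv2d(x, weight, stride = 1, padding = 1)
+print(torch.allclose(y_unfold, y_conv, atol = 1e-5))   # True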
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ spatially_adaptive_unet + + +

+ + +
+

+ Bases: Module

+ + +

Spatially varying U-Net model based on spatially adaptive convolution.

+ + +
+ References +

Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Akşit, "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions," SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24), December, 2024.

+
+ + + + + +
+ Source code in odak/learn/models/models.py +
class spatially_adaptive_unet(torch.nn.Module):
+    """
+    Spatially varying U-Net model based on spatially adaptive convolution.
+
+    References
+    ----------
+
+    Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Akşit, "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions," SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24), December, 2024.
+    """
+    def __init__(
+                 self,
+                 depth=3,
+                 dimensions=8,
+                 input_channels=6,
+                 out_channels=6,
+                 kernel_size=3,
+                 bias=True,
+                 normalization=False,
+                 activation=torch.nn.LeakyReLU(0.2, inplace=True)
+                ):
+        """
+        U-Net model.
+
+        Parameters
+        ----------
+        depth          : int
+                         Number of upsampling and downsampling layers.
+        dimensions     : int
+                         Number of dimensions.
+        input_channels : int
+                         Number of input channels.
+        out_channels   : int
+                         Number of output channels.
+        bias           : bool
+                         Set to True to let convolutional layers learn a bias term.
+        normalization  : bool
+                         If True, adds a Batch Normalization layer after the convolutional layer.
+        activation     : torch.nn
+                         Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
+        """
+        super().__init__()
+        self.depth = depth
+        self.out_channels = out_channels
+        self.inc = convolution_layer(
+                                     input_channels=input_channels,
+                                     output_channels=dimensions,
+                                     kernel_size=kernel_size,
+                                     bias=bias,
+                                     normalization=normalization,
+                                     activation=activation
+                                    )
+
+        self.encoder = torch.nn.ModuleList()
+        for i in range(self.depth + 1):  # Downsampling layers
+            down_in_channels = dimensions * (2 ** i)
+            down_out_channels = 2 * down_in_channels
+            pooling_layer = torch.nn.AvgPool2d(2)
+            double_convolution_layer = double_convolution(
+                                                          input_channels=down_in_channels,
+                                                          mid_channels=down_in_channels,
+                                                          output_channels=down_in_channels,
+                                                          kernel_size=kernel_size,
+                                                          bias=bias,
+                                                          normalization=normalization,
+                                                          activation=activation
+                                                         )
+            sam = spatially_adaptive_module(
+                                            input_channels=down_in_channels,
+                                            output_channels=down_out_channels,
+                                            kernel_size=kernel_size,
+                                            bias=bias,
+                                            activation=activation
+                                           )
+            self.encoder.append(torch.nn.ModuleList([pooling_layer, double_convolution_layer, sam]))
+        self.global_feature_module = torch.nn.ModuleList()
+        double_convolution_layer = double_convolution(
+                                                      input_channels=dimensions * (2 ** (depth + 1)),
+                                                      mid_channels=dimensions * (2 ** (depth + 1)),
+                                                      output_channels=dimensions * (2 ** (depth + 1)),
+                                                      kernel_size=kernel_size,
+                                                      bias=bias,
+                                                      normalization=normalization,
+                                                      activation=activation
+                                                     )
+        global_feature_layer = global_feature_module(
+                                                     input_channels=dimensions * (2 ** (depth + 1)),
+                                                     mid_channels=dimensions * (2 ** (depth + 1)),
+                                                     output_channels=dimensions * (2 ** (depth + 1)),
+                                                     kernel_size=kernel_size,
+                                                     bias=bias,
+                                                     activation=torch.nn.LeakyReLU(0.2, inplace=True)
+                                                    )
+        self.global_feature_module.append(torch.nn.ModuleList([double_convolution_layer, global_feature_layer]))
+        self.decoder = torch.nn.ModuleList()
+        for i in range(depth, -1, -1):
+            up_in_channels = dimensions * (2 ** (i + 1))
+            up_mid_channels = up_in_channels // 2
+            if i == 0:
+                up_out_channels = self.out_channels
+                upsample_layer = upsample_convtranspose2d_layer(
+                                                                input_channels=up_in_channels,
+                                                                output_channels=up_mid_channels,
+                                                                kernel_size=2,
+                                                                stride=2,
+                                                                bias=bias,
+                                                               )
+                conv_layer = torch.nn.Sequential(
+                    convolution_layer(
+                                      input_channels=up_mid_channels,
+                                      output_channels=up_mid_channels,
+                                      kernel_size=kernel_size,
+                                      bias=bias,
+                                      normalization=normalization,
+                                      activation=activation,
+                                     ),
+                    convolution_layer(
+                                      input_channels=up_mid_channels,
+                                      output_channels=up_out_channels,
+                                      kernel_size=1,
+                                      bias=bias,
+                                      normalization=normalization,
+                                      activation=None,
+                                     )
+                )
+                self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))
+            else:
+                up_out_channels = up_in_channels // 2
+                upsample_layer = upsample_convtranspose2d_layer(
+                                                                input_channels=up_in_channels,
+                                                                output_channels=up_mid_channels,
+                                                                kernel_size=2,
+                                                                stride=2,
+                                                                bias=bias,
+                                                               )
+                conv_layer = double_convolution(
+                                                input_channels=up_mid_channels,
+                                                mid_channels=up_mid_channels,
+                                                output_channels=up_out_channels,
+                                                kernel_size=kernel_size,
+                                                bias=bias,
+                                                normalization=normalization,
+                                                activation=activation,
+                                               )
+                self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))
+
+
+    def forward(self, sv_kernel, field):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        sv_kernel : list of torch.tensor
+                    Learned spatially varying kernels.
+                    Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i),
+                    where C_i, H_i, and W_i represent the channel, height, and width
+                    of each feature at a certain scale.
+
+        field     : torch.tensor
+                    Input field data.
+                    Dimension: (1, 6, H, W)
+
+        Returns
+        -------
+        target_field : torch.tensor
+                       Estimated output.
+                       Dimension: (1, 6, H, W)
+        """
+        x = self.inc(field)
+        downsampling_outputs = [x]
+        for i, down_layer in enumerate(self.encoder):
+            x_down = down_layer[0](downsampling_outputs[-1])
+            downsampling_outputs.append(x_down)
+            sam_output = down_layer[2](x_down + down_layer[1](x_down), sv_kernel[self.depth - i])
+            downsampling_outputs.append(sam_output)
+        global_feature = self.global_feature_module[0][0](downsampling_outputs[-1])
+        global_feature = self.global_feature_module[0][1](downsampling_outputs[-1], global_feature)
+        downsampling_outputs.append(global_feature)
+        x_up = downsampling_outputs[-1]
+        for i, up_layer in enumerate(self.decoder):
+            x_up = up_layer[0](x_up, downsampling_outputs[2 * (self.depth - i)])
+            x_up = up_layer[1](x_up)
+        result = x_up
+        return result
+
+
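A minimal usage sketch (not part of the odak source; the import path, the 256 × 256 field resolution, and the explicit kernel shapes are assumptions worked out from the default depth = 3, dimensions = 8, and kernel_size = 3). In practice the sv_kernel list is produced by spatially_varying_kernel_generation_model, documented further below:

+import torch
+from odak.learn.models import spatially_adaptive_unet   # assumed import path
+
+model = spatially_adaptive_unet(depth = 3, dimensions = 8, input_channels = 6, out_channels = 6)
+field = torch.randn(1, 6, 256, 256)
+# One kernel tensor per encoder level, ordered from the coarsest scale (index 0) to the finest.
+sv_kernel = [
+             torch.randn(1, 576, 16, 16),
+             torch.randn(1, 288, 32, 32),
+             torch.randn(1, 144, 64, 64),
+             torch.randn(1, 72, 128, 128)
+            ]
+estimate = model(sv_kernel, field)   # -> (1, 6, 256, 256)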
+
+ __init__(depth=3, dimensions=8, input_channels=6, out_channels=6, kernel_size=3, bias=True, normalization=False, activation=torch.nn.LeakyReLU(0.2, inplace=True))
+

U-Net model.

+
Parameters:
+
+   • depth (int, default: 3) – Number of upsampling and downsampling layers.
+   • dimensions (int, default: 8) – Number of dimensions.
+   • input_channels (int, default: 6) – Number of input channels.
+   • out_channels (int, default: 6) – Number of output channels.
+   • bias (bool, default: True) – Set to True to let convolutional layers learn a bias term.
+   • normalization (bool, default: False) – If True, adds a Batch Normalization layer after the convolutional layer.
+   • activation (torch.nn, default: torch.nn.LeakyReLU(0.2, inplace=True)) – Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
+
+ Source code in odak/learn/models/models.py
def __init__(
+             self,
+             depth=3,
+             dimensions=8,
+             input_channels=6,
+             out_channels=6,
+             kernel_size=3,
+             bias=True,
+             normalization=False,
+             activation=torch.nn.LeakyReLU(0.2, inplace=True)
+            ):
+    """
+    U-Net model.
+
+    Parameters
+    ----------
+    depth          : int
+                     Number of upsampling and downsampling layers.
+    dimensions     : int
+                     Number of dimensions.
+    input_channels : int
+                     Number of input channels.
+    out_channels   : int
+                     Number of output channels.
+    bias           : bool
+                     Set to True to let convolutional layers learn a bias term.
+    normalization  : bool
+                     If True, adds a Batch Normalization layer after the convolutional layer.
+    activation     : torch.nn
+                     Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
+    """
+    super().__init__()
+    self.depth = depth
+    self.out_channels = out_channels
+    self.inc = convolution_layer(
+                                 input_channels=input_channels,
+                                 output_channels=dimensions,
+                                 kernel_size=kernel_size,
+                                 bias=bias,
+                                 normalization=normalization,
+                                 activation=activation
+                                )
+
+    self.encoder = torch.nn.ModuleList()
+    for i in range(self.depth + 1):  # Downsampling layers
+        down_in_channels = dimensions * (2 ** i)
+        down_out_channels = 2 * down_in_channels
+        pooling_layer = torch.nn.AvgPool2d(2)
+        double_convolution_layer = double_convolution(
+                                                      input_channels=down_in_channels,
+                                                      mid_channels=down_in_channels,
+                                                      output_channels=down_in_channels,
+                                                      kernel_size=kernel_size,
+                                                      bias=bias,
+                                                      normalization=normalization,
+                                                      activation=activation
+                                                     )
+        sam = spatially_adaptive_module(
+                                        input_channels=down_in_channels,
+                                        output_channels=down_out_channels,
+                                        kernel_size=kernel_size,
+                                        bias=bias,
+                                        activation=activation
+                                       )
+        self.encoder.append(torch.nn.ModuleList([pooling_layer, double_convolution_layer, sam]))
+    self.global_feature_module = torch.nn.ModuleList()
+    double_convolution_layer = double_convolution(
+                                                  input_channels=dimensions * (2 ** (depth + 1)),
+                                                  mid_channels=dimensions * (2 ** (depth + 1)),
+                                                  output_channels=dimensions * (2 ** (depth + 1)),
+                                                  kernel_size=kernel_size,
+                                                  bias=bias,
+                                                  normalization=normalization,
+                                                  activation=activation
+                                                 )
+    global_feature_layer = global_feature_module(
+                                                 input_channels=dimensions * (2 ** (depth + 1)),
+                                                 mid_channels=dimensions * (2 ** (depth + 1)),
+                                                 output_channels=dimensions * (2 ** (depth + 1)),
+                                                 kernel_size=kernel_size,
+                                                 bias=bias,
+                                                 activation=torch.nn.LeakyReLU(0.2, inplace=True)
+                                                )
+    self.global_feature_module.append(torch.nn.ModuleList([double_convolution_layer, global_feature_layer]))
+    self.decoder = torch.nn.ModuleList()
+    for i in range(depth, -1, -1):
+        up_in_channels = dimensions * (2 ** (i + 1))
+        up_mid_channels = up_in_channels // 2
+        if i == 0:
+            up_out_channels = self.out_channels
+            upsample_layer = upsample_convtranspose2d_layer(
+                                                            input_channels=up_in_channels,
+                                                            output_channels=up_mid_channels,
+                                                            kernel_size=2,
+                                                            stride=2,
+                                                            bias=bias,
+                                                           )
+            conv_layer = torch.nn.Sequential(
+                convolution_layer(
+                                  input_channels=up_mid_channels,
+                                  output_channels=up_mid_channels,
+                                  kernel_size=kernel_size,
+                                  bias=bias,
+                                  normalization=normalization,
+                                  activation=activation,
+                                 ),
+                convolution_layer(
+                                  input_channels=up_mid_channels,
+                                  output_channels=up_out_channels,
+                                  kernel_size=1,
+                                  bias=bias,
+                                  normalization=normalization,
+                                  activation=None,
+                                 )
+            )
+            self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))
+        else:
+            up_out_channels = up_in_channels // 2
+            upsample_layer = upsample_convtranspose2d_layer(
+                                                            input_channels=up_in_channels,
+                                                            output_channels=up_mid_channels,
+                                                            kernel_size=2,
+                                                            stride=2,
+                                                            bias=bias,
+                                                           )
+            conv_layer = double_convolution(
+                                            input_channels=up_mid_channels,
+                                            mid_channels=up_mid_channels,
+                                            output_channels=up_out_channels,
+                                            kernel_size=kernel_size,
+                                            bias=bias,
+                                            normalization=normalization,
+                                            activation=activation,
+                                           )
+            self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))
+
+
+
+ +
+
+ forward(sv_kernel, field)
+

Forward model.

+
Parameters:
+
+   • sv_kernel (list of torch.tensor) – Learned spatially varying kernels. Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i), where C_i, H_i, and W_i represent the channel, height, and width of each feature at a certain scale.
+   • field (torch.tensor) – Input field data. Dimension: (1, 6, H, W)
+
Returns:
+
+   • target_field (torch.tensor) – Estimated output. Dimension: (1, 6, H, W)
+
+ Source code in odak/learn/models/models.py
def forward(self, sv_kernel, field):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    sv_kernel : list of torch.tensor
+                Learned spatially varying kernels.
+                Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i),
+                where C_i, H_i, and W_i represent the channel, height, and width
+                of each feature at a certain scale.
+
+    field     : torch.tensor
+                Input field data.
+                Dimension: (1, 6, H, W)
+
+    Returns
+    -------
+    target_field : torch.tensor
+                   Estimated output.
+                   Dimension: (1, 6, H, W)
+    """
+    x = self.inc(field)
+    downsampling_outputs = [x]
+    for i, down_layer in enumerate(self.encoder):
+        x_down = down_layer[0](downsampling_outputs[-1])
+        downsampling_outputs.append(x_down)
+        sam_output = down_layer[2](x_down + down_layer[1](x_down), sv_kernel[self.depth - i])
+        downsampling_outputs.append(sam_output)
+    global_feature = self.global_feature_module[0][0](downsampling_outputs[-1])
+    global_feature = self.global_feature_module[0][1](downsampling_outputs[-1], global_feature)
+    downsampling_outputs.append(global_feature)
+    x_up = downsampling_outputs[-1]
+    for i, up_layer in enumerate(self.decoder):
+        x_up = up_layer[0](x_up, downsampling_outputs[2 * (self.depth - i)])
+        x_up = up_layer[1](x_up)
+    result = x_up
+    return result
+
+
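As a reading aid (not from the odak source), the snippet below spells out which sv_kernel entry each encoder level consumes for the default depth = 3, dimensions = 8, kernel_size = 3 and an assumed 256 × 256 field; because the forward pass indexes sv_kernel[self.depth - i], index 0 holds the coarsest-scale kernels:

+depth, dimensions, kernel_size, height, width = 3, 8, 3, 256, 256
+for i in range(depth + 1):
+    channels = dimensions * (2 ** i) * kernel_size * kernel_size
+    h, w = height // (2 ** (i + 1)), width // (2 ** (i + 1))
+    print(f'encoder level {i} uses sv_kernel[{depth - i}] of shape (1, {channels}, {h}, {w})')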
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ spatially_varying_kernel_generation_model + + +

+ + +
+

+ Bases: Module

+ + +

Spatially_varying_kernel_generation_model revised from RSGUnet: +https://github.com/MTLab/rsgunet_image_enhance.

+

Refer to: +J. Huang, P. Zhu, M. Geng et al. Range Scaling Global U-Net for Perceptual Image Enhancement on Mobile Devices.

+ + + + + + +
+ Source code in odak/learn/models/models.py +
class spatially_varying_kernel_generation_model(torch.nn.Module):
+    """
+    Spatially_varying_kernel_generation_model revised from RSGUnet:
+    https://github.com/MTLab/rsgunet_image_enhance.
+
+    Refer to:
+    J. Huang, P. Zhu, M. Geng et al. Range Scaling Global U-Net for Perceptual Image Enhancement on Mobile Devices.
+    """
+
+    def __init__(
+                 self,
+                 depth = 3,
+                 dimensions = 8,
+                 input_channels = 7,
+                 kernel_size = 3,
+                 bias = True,
+                 normalization = False,
+                 activation = torch.nn.LeakyReLU(0.2, inplace = True)
+                ):
+        """
+        U-Net model.
+
+        Parameters
+        ----------
+        depth          : int
+                         Number of upsampling and downsampling layers.
+        dimensions     : int
+                         Number of dimensions.
+        input_channels : int
+                         Number of input channels.
+        bias           : bool
+                         Set to True to let convolutional layers learn a bias term.
+        normalization  : bool
+                         If True, adds a Batch Normalization layer after the convolutional layer.
+        activation     : torch.nn
+                         Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
+        """
+        super().__init__()
+        self.depth = depth
+        self.inc = convolution_layer(
+                                     input_channels = input_channels,
+                                     output_channels = dimensions,
+                                     kernel_size = kernel_size,
+                                     bias = bias,
+                                     normalization = normalization,
+                                     activation = activation
+                                    )
+        self.encoder = torch.nn.ModuleList()
+        for i in range(depth + 1):  # downsampling layers
+            if i == 0:
+                in_channels = dimensions * (2 ** i)
+                out_channels = dimensions * (2 ** i)
+            elif i == depth:
+                in_channels = dimensions * (2 ** (i - 1))
+                out_channels = dimensions * (2 ** (i - 1))
+            else:
+                in_channels = dimensions * (2 ** (i - 1))
+                out_channels = 2 * in_channels
+            pooling_layer = torch.nn.AvgPool2d(2)
+            double_convolution_layer = double_convolution(
+                                                          input_channels = in_channels,
+                                                          mid_channels = in_channels,
+                                                          output_channels = out_channels,
+                                                          kernel_size = kernel_size,
+                                                          bias = bias,
+                                                          normalization = normalization,
+                                                          activation = activation
+                                                         )
+            self.encoder.append(pooling_layer)
+            self.encoder.append(double_convolution_layer)
+        self.spatially_varying_feature = torch.nn.ModuleList()  # for kernel generation
+        for i in range(depth, -1, -1):
+            if i == 1:
+                svf_in_channels = dimensions + 2 ** (self.depth + i) + 1
+            else:
+                svf_in_channels = 2 ** (self.depth + i) + 1
+            svf_out_channels = (2 ** (self.depth + i)) * (kernel_size * kernel_size)
+            svf_mid_channels = dimensions * (2 ** (self.depth - 1))
+            spatially_varying_kernel_generation = torch.nn.ModuleList()
+            for j in range(i, -1, -1):
+                pooling_layer = torch.nn.AvgPool2d(2 ** (j + 1))
+                spatially_varying_kernel_generation.append(pooling_layer)
+            kernel_generation_block = torch.nn.Sequential(
+                torch.nn.Conv2d(
+                                in_channels = svf_in_channels,
+                                out_channels = svf_mid_channels,
+                                kernel_size = kernel_size,
+                                padding = kernel_size // 2,
+                                bias = bias
+                               ),
+                activation,
+                torch.nn.Conv2d(
+                                in_channels = svf_mid_channels,
+                                out_channels = svf_mid_channels,
+                                kernel_size = kernel_size,
+                                padding = kernel_size // 2,
+                                bias = bias
+                               ),
+                activation,
+                torch.nn.Conv2d(
+                                in_channels = svf_mid_channels,
+                                out_channels = svf_out_channels,
+                                kernel_size = kernel_size,
+                                padding = kernel_size // 2,
+                                bias = bias
+                               ),
+            )
+            spatially_varying_kernel_generation.append(kernel_generation_block)
+            self.spatially_varying_feature.append(spatially_varying_kernel_generation)
+        self.decoder = torch.nn.ModuleList()
+        global_feature_layer = global_feature_module(  # global feature layer
+                                                     input_channels = dimensions * (2 ** (depth - 1)),
+                                                     mid_channels = dimensions * (2 ** (depth - 1)),
+                                                     output_channels = dimensions * (2 ** (depth - 1)),
+                                                     kernel_size = kernel_size,
+                                                     bias = bias,
+                                                     activation = torch.nn.LeakyReLU(0.2, inplace = True)
+                                                    )
+        self.decoder.append(global_feature_layer)
+        for i in range(depth, 0, -1):
+            if i == 2:
+                up_in_channels = (dimensions // 2) * (2 ** i)
+                up_out_channels = up_in_channels
+                up_mid_channels = up_in_channels
+            elif i == 1:
+                up_in_channels = dimensions * 2
+                up_out_channels = dimensions
+                up_mid_channels = up_out_channels
+            else:
+                up_in_channels = (dimensions // 2) * (2 ** i)
+                up_out_channels = up_in_channels // 2
+                up_mid_channels = up_in_channels
+            upsample_layer = upsample_convtranspose2d_layer(
+                                                            input_channels = up_in_channels,
+                                                            output_channels = up_mid_channels,
+                                                            kernel_size = 2,
+                                                            stride = 2,
+                                                            bias = bias,
+                                                           )
+            conv_layer = double_convolution(
+                                            input_channels = up_mid_channels,
+                                            output_channels = up_out_channels,
+                                            kernel_size = kernel_size,
+                                            bias = bias,
+                                            normalization = normalization,
+                                            activation = activation,
+                                           )
+            self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))
+
+
+    def forward(self, focal_surface, field):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        focal_surface : torch.tensor
+                        Input focal surface data.
+                        Dimension: (1, 1, H, W)
+
+        field         : torch.tensor
+                        Input field data.
+                        Dimension: (1, 6, H, W)
+
+        Returns
+        -------
+        sv_kernel : list of torch.tensor
+                    Learned spatially varying kernels.
+                    Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i),
+                    where C_i, H_i, and W_i represent the channel, height, and width
+                    of each feature at a certain scale.
+        """
+        x = self.inc(torch.cat((focal_surface, field), dim = 1))
+        downsampling_outputs = [focal_surface]
+        downsampling_outputs.append(x)
+        for i, down_layer in enumerate(self.encoder):
+            x_down = down_layer(downsampling_outputs[-1])
+            downsampling_outputs.append(x_down)
+        sv_kernels = []
+        for i, (up_layer, svf_layer) in enumerate(zip(self.decoder, self.spatially_varying_feature)):
+            if i == 0:
+                global_feature = up_layer(downsampling_outputs[-2], downsampling_outputs[-1])
+                downsampling_outputs[-1] = global_feature
+                sv_feature = [global_feature, downsampling_outputs[0]]
+                for j in range(self.depth - i + 1):
+                    sv_feature[1] = svf_layer[self.depth - i](sv_feature[1])
+                    if j > 0:
+                        sv_feature.append(svf_layer[j](downsampling_outputs[2 * j]))
+                sv_feature = [sv_feature[0], sv_feature[1], sv_feature[4], sv_feature[2],
+                              sv_feature[3]]
+                sv_kernel = svf_layer[-1](torch.cat(sv_feature, dim = 1))
+                sv_kernels.append(sv_kernel)
+            else:
+                x_up = up_layer[0](downsampling_outputs[-1],
+                                   downsampling_outputs[2 * (self.depth + 1 - i) + 1])
+                x_up = up_layer[1](x_up)
+                downsampling_outputs[-1] = x_up
+                sv_feature = [x_up, downsampling_outputs[0]]
+                for j in range(self.depth - i + 1):
+                    sv_feature[1] = svf_layer[self.depth - i](sv_feature[1])
+                    if j > 0:
+                        sv_feature.append(svf_layer[j](downsampling_outputs[2 * j]))
+                if i == 1:
+                    sv_feature = [sv_feature[0], sv_feature[1], sv_feature[3], sv_feature[2]]
+                sv_kernel = svf_layer[-1](torch.cat(sv_feature, dim = 1))
+                sv_kernels.append(sv_kernel)
+        return sv_kernels
+
+
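To show how this kernel generation model feeds spatially_adaptive_unet documented above, here is a hypothetical end-to-end sketch (the import path and the 256 × 256 resolution are assumptions, not taken from the odak source):

+import torch
+from odak.learn.models import spatially_adaptive_unet, spatially_varying_kernel_generation_model   # assumed import path
+
+kernel_generator = spatially_varying_kernel_generation_model(depth = 3, dimensions = 8, input_channels = 7)
+light_transport = spatially_adaptive_unet(depth = 3, dimensions = 8, input_channels = 6, out_channels = 6)
+focal_surface = torch.randn(1, 1, 256, 256)          # focal surface, (1, 1, H, W)
+field = torch.randn(1, 6, 256, 256)                  # input field, (1, 6, H, W)
+sv_kernel = kernel_generator(focal_surface, field)   # list of per-scale spatially varying kernels
+estimate = light_transport(sv_kernel, field)         # estimated output field, (1, 6, H, W)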
+
+ __init__(depth=3, dimensions=8, input_channels=7, kernel_size=3, bias=True, normalization=False, activation=torch.nn.LeakyReLU(0.2, inplace=True))
+

U-Net model.

+
Parameters:
+
+   • depth (int, default: 3) – Number of upsampling and downsampling layers.
+   • dimensions (int, default: 8) – Number of dimensions.
+   • input_channels (int, default: 7) – Number of input channels.
+   • bias (bool, default: True) – Set to True to let convolutional layers learn a bias term.
+   • normalization (bool, default: False) – If True, adds a Batch Normalization layer after the convolutional layer.
+   • activation (torch.nn, default: torch.nn.LeakyReLU(0.2, inplace=True)) – Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
+
+ Source code in odak/learn/models/models.py
def __init__(
+             self,
+             depth = 3,
+             dimensions = 8,
+             input_channels = 7,
+             kernel_size = 3,
+             bias = True,
+             normalization = False,
+             activation = torch.nn.LeakyReLU(0.2, inplace = True)
+            ):
+    """
+    U-Net model.
+
+    Parameters
+    ----------
+    depth          : int
+                     Number of upsampling and downsampling layers.
+    dimensions     : int
+                     Number of dimensions.
+    input_channels : int
+                     Number of input channels.
+    bias           : bool
+                     Set to True to let convolutional layers learn a bias term.
+    normalization  : bool
+                     If True, adds a Batch Normalization layer after the convolutional layer.
+    activation     : torch.nn
+                     Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
+    """
+    super().__init__()
+    self.depth = depth
+    self.inc = convolution_layer(
+                                 input_channels = input_channels,
+                                 output_channels = dimensions,
+                                 kernel_size = kernel_size,
+                                 bias = bias,
+                                 normalization = normalization,
+                                 activation = activation
+                                )
+    self.encoder = torch.nn.ModuleList()
+    for i in range(depth + 1):  # downsampling layers
+        if i == 0:
+            in_channels = dimensions * (2 ** i)
+            out_channels = dimensions * (2 ** i)
+        elif i == depth:
+            in_channels = dimensions * (2 ** (i - 1))
+            out_channels = dimensions * (2 ** (i - 1))
+        else:
+            in_channels = dimensions * (2 ** (i - 1))
+            out_channels = 2 * in_channels
+        pooling_layer = torch.nn.AvgPool2d(2)
+        double_convolution_layer = double_convolution(
+                                                      input_channels = in_channels,
+                                                      mid_channels = in_channels,
+                                                      output_channels = out_channels,
+                                                      kernel_size = kernel_size,
+                                                      bias = bias,
+                                                      normalization = normalization,
+                                                      activation = activation
+                                                     )
+        self.encoder.append(pooling_layer)
+        self.encoder.append(double_convolution_layer)
+    self.spatially_varying_feature = torch.nn.ModuleList()  # for kernel generation
+    for i in range(depth, -1, -1):
+        if i == 1:
+            svf_in_channels = dimensions + 2 ** (self.depth + i) + 1
+        else:
+            svf_in_channels = 2 ** (self.depth + i) + 1
+        svf_out_channels = (2 ** (self.depth + i)) * (kernel_size * kernel_size)
+        svf_mid_channels = dimensions * (2 ** (self.depth - 1))
+        spatially_varying_kernel_generation = torch.nn.ModuleList()
+        for j in range(i, -1, -1):
+            pooling_layer = torch.nn.AvgPool2d(2 ** (j + 1))
+            spatially_varying_kernel_generation.append(pooling_layer)
+        kernel_generation_block = torch.nn.Sequential(
+            torch.nn.Conv2d(
+                            in_channels = svf_in_channels,
+                            out_channels = svf_mid_channels,
+                            kernel_size = kernel_size,
+                            padding = kernel_size // 2,
+                            bias = bias
+                           ),
+            activation,
+            torch.nn.Conv2d(
+                            in_channels = svf_mid_channels,
+                            out_channels = svf_mid_channels,
+                            kernel_size = kernel_size,
+                            padding = kernel_size // 2,
+                            bias = bias
+                           ),
+            activation,
+            torch.nn.Conv2d(
+                            in_channels = svf_mid_channels,
+                            out_channels = svf_out_channels,
+                            kernel_size = kernel_size,
+                            padding = kernel_size // 2,
+                            bias = bias
+                           ),
+        )
+        spatially_varying_kernel_generation.append(kernel_generation_block)
+        self.spatially_varying_feature.append(spatially_varying_kernel_generation)
+    self.decoder = torch.nn.ModuleList()
+    global_feature_layer = global_feature_module(  # global feature layer
+                                                 input_channels = dimensions * (2 ** (depth - 1)),
+                                                 mid_channels = dimensions * (2 ** (depth - 1)),
+                                                 output_channels = dimensions * (2 ** (depth - 1)),
+                                                 kernel_size = kernel_size,
+                                                 bias = bias,
+                                                 activation = torch.nn.LeakyReLU(0.2, inplace = True)
+                                                )
+    self.decoder.append(global_feature_layer)
+    for i in range(depth, 0, -1):
+        if i == 2:
+            up_in_channels = (dimensions // 2) * (2 ** i)
+            up_out_channels = up_in_channels
+            up_mid_channels = up_in_channels
+        elif i == 1:
+            up_in_channels = dimensions * 2
+            up_out_channels = dimensions
+            up_mid_channels = up_out_channels
+        else:
+            up_in_channels = (dimensions // 2) * (2 ** i)
+            up_out_channels = up_in_channels // 2
+            up_mid_channels = up_in_channels
+        upsample_layer = upsample_convtranspose2d_layer(
+                                                        input_channels = up_in_channels,
+                                                        output_channels = up_mid_channels,
+                                                        kernel_size = 2,
+                                                        stride = 2,
+                                                        bias = bias,
+                                                       )
+        conv_layer = double_convolution(
+                                        input_channels = up_mid_channels,
+                                        output_channels = up_out_channels,
+                                        kernel_size = kernel_size,
+                                        bias = bias,
+                                        normalization = normalization,
+                                        activation = activation,
+                                       )
+        self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))
+
+
+
+
+ forward(focal_surface, field)
+

Forward model.

+
Parameters:
+
+   • focal_surface (torch.tensor) – Input focal surface data. Dimension: (1, 1, H, W)
+   • field (torch.tensor) – Input field data. Dimension: (1, 6, H, W)
+
Returns:
+
+   • sv_kernel (list of torch.tensor) – Learned spatially varying kernels. Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i), where C_i, H_i, and W_i represent the channel, height, and width of each feature at a certain scale.
+
+ Source code in odak/learn/models/models.py
def forward(self, focal_surface, field):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    focal_surface : torch.tensor
+                    Input focal surface data.
+                    Dimension: (1, 1, H, W)
+
+    field         : torch.tensor
+                    Input field data.
+                    Dimension: (1, 6, H, W)
+
+    Returns
+    -------
+    sv_kernel : list of torch.tensor
+                Learned spatially varying kernels.
+                Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i),
+                where C_i, H_i, and W_i represent the channel, height, and width
+                of each feature at a certain scale.
+    """
+    x = self.inc(torch.cat((focal_surface, field), dim = 1))
+    downsampling_outputs = [focal_surface]
+    downsampling_outputs.append(x)
+    for i, down_layer in enumerate(self.encoder):
+        x_down = down_layer(downsampling_outputs[-1])
+        downsampling_outputs.append(x_down)
+    sv_kernels = []
+    for i, (up_layer, svf_layer) in enumerate(zip(self.decoder, self.spatially_varying_feature)):
+        if i == 0:
+            global_feature = up_layer(downsampling_outputs[-2], downsampling_outputs[-1])
+            downsampling_outputs[-1] = global_feature
+            sv_feature = [global_feature, downsampling_outputs[0]]
+            for j in range(self.depth - i + 1):
+                sv_feature[1] = svf_layer[self.depth - i](sv_feature[1])
+                if j > 0:
+                    sv_feature.append(svf_layer[j](downsampling_outputs[2 * j]))
+            sv_feature = [sv_feature[0], sv_feature[1], sv_feature[4], sv_feature[2],
+                          sv_feature[3]]
+            sv_kernel = svf_layer[-1](torch.cat(sv_feature, dim = 1))
+            sv_kernels.append(sv_kernel)
+        else:
+            x_up = up_layer[0](downsampling_outputs[-1],
+                               downsampling_outputs[2 * (self.depth + 1 - i) + 1])
+            x_up = up_layer[1](x_up)
+            downsampling_outputs[-1] = x_up
+            sv_feature = [x_up, downsampling_outputs[0]]
+            for j in range(self.depth - i + 1):
+                sv_feature[1] = svf_layer[self.depth - i](sv_feature[1])
+                if j > 0:
+                    sv_feature.append(svf_layer[j](downsampling_outputs[2 * j]))
+            if i == 1:
+                sv_feature = [sv_feature[0], sv_feature[1], sv_feature[3], sv_feature[2]]
+            sv_kernel = svf_layer[-1](torch.cat(sv_feature, dim = 1))
+            sv_kernels.append(sv_kernel)
+    return sv_kernels
+
+
+
+ +
+
+ unet
+
+ Bases: Module
+
A U-Net model, heavily inspired by https://github.com/milesial/Pytorch-UNet/tree/master/unet; for more background, see Ronneberger, Olaf, Philipp Fischer, and Thomas Brox, "U-Net: Convolutional networks for biomedical image segmentation," Medical Image Computing and Computer-Assisted Intervention (MICCAI 2015), Part III, Springer, 2015.

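A minimal usage sketch (not part of the odak source; the import path and input size are assumptions). Height and width should be divisible by 2 ** depth so the downsampling and upsampling paths line up:

+import torch
+from odak.learn.models import unet   # assumed import path
+
+model = unet(depth = 4, dimensions = 64, input_channels = 2, output_channels = 1)
+x = torch.randn(1, 2, 64, 64)   # (1, input_channels, H, W) with H and W divisible by 2 ** depth
+y = model(x)                    # -> (1, 1, 64, 64)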
+ Source code in odak/learn/models/models.py
class unet(torch.nn.Module):
+    """
+    A U-Net model, heavily inspired from `https://github.com/milesial/Pytorch-UNet/tree/master/unet` and more can be read from Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional networks for biomedical image segmentation." Medical Image Computing and Computer-Assisted Intervention–MICCAI 2015: 18th International Conference, Munich, Germany, October 5-9, 2015, Proceedings, Part III 18. Springer International Publishing, 2015.
+    """
+
+    def __init__(
+                 self, 
+                 depth = 4,
+                 dimensions = 64, 
+                 input_channels = 2, 
+                 output_channels = 1, 
+                 bilinear = False,
+                 kernel_size = 3,
+                 bias = False,
+                 activation = torch.nn.ReLU(inplace = True),
+                ):
+        """
+        U-Net model.
+
+        Parameters
+        ----------
+        depth             : int
+                            Number of downsampling and upsampling layers.
+        dimensions        : int
+                            Number of dimensions.
+        input_channels    : int
+                            Number of input channels.
+        output_channels   : int
+                            Number of output channels.
+        bilinear          : bool
+                            Uses bilinear upsampling in upsampling layers when set True.
+        bias              : bool
+                            Set True to let convolutional layers learn a bias term.
+        activation        : torch.nn
+                            Non-linear activation layer to be used (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
+        """
+        super(unet, self).__init__()
+        self.inc = double_convolution(
+                                      input_channels = input_channels,
+                                      mid_channels = dimensions,
+                                      output_channels = dimensions,
+                                      kernel_size = kernel_size,
+                                      bias = bias,
+                                      activation = activation
+                                     )      
+
+        self.downsampling_layers = torch.nn.ModuleList()
+        self.upsampling_layers = torch.nn.ModuleList()
+        for i in range(depth): # downsampling layers
+            in_channels = dimensions * (2 ** i)
+            out_channels = dimensions * (2 ** (i + 1))
+            down_layer = downsample_layer(in_channels,
+                                            out_channels,
+                                            kernel_size=kernel_size,
+                                            bias=bias,
+                                            activation=activation
+                                            )
+            self.downsampling_layers.append(down_layer)      
+
+        for i in range(depth - 1, -1, -1):  # upsampling layers
+            up_in_channels = dimensions * (2 ** (i + 1))  
+            up_out_channels = dimensions * (2 ** i) 
+            up_layer = upsample_layer(up_in_channels, up_out_channels, kernel_size=kernel_size, bias=bias, activation=activation, bilinear=bilinear)
+            self.upsampling_layers.append(up_layer)
+        self.outc = torch.nn.Conv2d(
+                                    dimensions, 
+                                    output_channels,
+                                    kernel_size = kernel_size,
+                                    padding = kernel_size // 2,
+                                    bias = bias
+                                   )
+
+
+    def forward(self, x):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x             : torch.tensor
+                        Input data.
+
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Estimated output.      
+        """
+        downsampling_outputs = [self.inc(x)]
+        for down_layer in self.downsampling_layers:
+            x_down = down_layer(downsampling_outputs[-1])
+            downsampling_outputs.append(x_down)
+        x_up = downsampling_outputs[-1]
+        for i, up_layer in enumerate((self.upsampling_layers)):
+            x_up = up_layer(x_up, downsampling_outputs[-(i + 2)])       
+        result = self.outc(x_up)
+        return result
+
+
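As a quick sanity check, here is a minimal usage sketch; it assumes the class is exposed as `odak.learn.models.unet`, and the tensor sizes below are illustrative only.

import torch
from odak.learn.models import unet

# Two-channel input, one-channel output; with depth = 4 it is safest to keep
# the input height and width multiples of 2 ** 4 = 16 so the skip connections line up.
model = unet(depth = 4, dimensions = 64, input_channels = 2, output_channels = 1)
x = torch.randn(1, 2, 128, 128)
with torch.no_grad():
    y = model(x)
print(y.shape)  # expected: torch.Size([1, 1, 128, 128])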
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(depth=4, dimensions=64, input_channels=2, output_channels=1, bilinear=False, kernel_size=3, bias=False, activation=torch.nn.ReLU(inplace=True)) + +

+ + +
+ +

U-Net model.

+ + +

Parameters:

+
    +
  • + depth + – +
    +
                Number of downsampling and upsampling layers.
    +
    +
    +
  • +
  • + dimensions + – +
    +
                Number of dimensions.
    +
    +
    +
  • +
  • + input_channels + – +
    +
                Number of input channels.
    +
    +
    +
  • +
  • + output_channels + – +
    +
                Number of output channels.
    +
    +
    +
  • +
  • + bilinear + – +
    +
                Uses bilinear upsampling in upsampling layers when set True.
    +
    +
    +
  • +
  • + bias + – +
    +
                Set True to let convolutional layers learn a bias term.
    +
    +
    +
  • +
  • + activation + – +
    +
            Non-linear activation layer to be used (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
    +
    +
    +
  • +
+ +
+ Source code in odak/learn/models/models.py +
def __init__(
+             self, 
+             depth = 4,
+             dimensions = 64, 
+             input_channels = 2, 
+             output_channels = 1, 
+             bilinear = False,
+             kernel_size = 3,
+             bias = False,
+             activation = torch.nn.ReLU(inplace = True),
+            ):
+    """
+    U-Net model.
+
+    Parameters
+    ----------
+    depth             : int
+                        Number of downsampling and upsampling layers.
+    dimensions        : int
+                        Number of dimensions.
+    input_channels    : int
+                        Number of input channels.
+    output_channels   : int
+                        Number of output channels.
+    bilinear          : bool
+                        Uses bilinear upsampling in upsampling layers when set True.
+    bias              : bool
+                        Set True to let convolutional layers learn a bias term.
+    activation        : torch.nn
+                        Non-linear activation layer to be used (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
+    """
+    super(unet, self).__init__()
+    self.inc = double_convolution(
+                                  input_channels = input_channels,
+                                  mid_channels = dimensions,
+                                  output_channels = dimensions,
+                                  kernel_size = kernel_size,
+                                  bias = bias,
+                                  activation = activation
+                                 )      
+
+    self.downsampling_layers = torch.nn.ModuleList()
+    self.upsampling_layers = torch.nn.ModuleList()
+    for i in range(depth): # downsampling layers
+        in_channels = dimensions * (2 ** i)
+        out_channels = dimensions * (2 ** (i + 1))
+        down_layer = downsample_layer(in_channels,
+                                        out_channels,
+                                        kernel_size=kernel_size,
+                                        bias=bias,
+                                        activation=activation
+                                        )
+        self.downsampling_layers.append(down_layer)      
+
+    for i in range(depth - 1, -1, -1):  # upsampling layers
+        up_in_channels = dimensions * (2 ** (i + 1))  
+        up_out_channels = dimensions * (2 ** i) 
+        up_layer = upsample_layer(up_in_channels, up_out_channels, kernel_size=kernel_size, bias=bias, activation=activation, bilinear=bilinear)
+        self.upsampling_layers.append(up_layer)
+    self.outc = torch.nn.Conv2d(
+                                dimensions, 
+                                output_channels,
+                                kernel_size = kernel_size,
+                                padding = kernel_size // 2,
+                                bias = bias
+                               )
+
+
+
+ +
+ +
+ + +

+ forward(x) + +

+ + +
+ +

Forward model.

+ + +

Parameters:

+
    +
  • + x + – +
    +
            Input data.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( tensor +) – +
    +

    Estimated output.

    +
    +
  • +
+ +
+ Source code in odak/learn/models/models.py +
def forward(self, x):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x             : torch.tensor
+                    Input data.
+
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Estimated output.      
+    """
+    downsampling_outputs = [self.inc(x)]
+    for down_layer in self.downsampling_layers:
+        x_down = down_layer(downsampling_outputs[-1])
+        downsampling_outputs.append(x_down)
+    x_up = downsampling_outputs[-1]
+    for i, up_layer in enumerate((self.upsampling_layers)):
+        x_up = up_layer(x_up, downsampling_outputs[-(i + 2)])       
+    result = self.outc(x_up)
+    return result
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ upsample_convtranspose2d_layer + + +

+ + +
+

+ Bases: Module

+ + +

An upsampling convtranspose2d layer.

+ + + + + + +
+ Source code in odak/learn/models/components.py +
class upsample_convtranspose2d_layer(torch.nn.Module):
+    """
+    An upsampling convtranspose2d layer.
+    """
+    def __init__(
+                 self,
+                 input_channels,
+                 output_channels,
+                 kernel_size = 2,
+                 stride = 2,
+                 bias = False,
+                ):
+        """
+        An upsampling component using a transposed convolution.
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        output_channels : int
+                          Number of output channels.
+        kernel_size     : int
+                          Kernel size.
+        bias            : bool
+                          Set to True to let convolutional layers have bias term.
+        """
+        super().__init__()
+        self.up = torch.nn.ConvTranspose2d(
+                                           in_channels = input_channels,
+                                           out_channels = output_channels,
+                                           bias = bias,
+                                           kernel_size = kernel_size,
+                                           stride = stride
+                                          )
+
+    def forward(self, x1, x2):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x1             : torch.tensor
+                         First input data.
+        x2             : torch.tensor
+                         Second input data.
+
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Result of the forward operation
+        """
+        x1 = self.up(x1)
+        diffY = x2.size()[2] - x1.size()[2]
+        diffX = x2.size()[3] - x1.size()[3]
+        x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,
+                                          diffY // 2, diffY - diffY // 2])
+        result = x1 + x2
+        return result
+
+
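A minimal usage sketch, assuming the layer is importable from `odak.learn.models.components`; note that `forward` adds the upsampled `x1` to `x2`, so `output_channels` must match the channel count of `x2`.

import torch
from odak.learn.models.components import upsample_convtranspose2d_layer

layer = upsample_convtranspose2d_layer(input_channels = 64, output_channels = 32)
x1 = torch.randn(1, 64, 16, 16)  # coarse feature map to be upsampled
x2 = torch.randn(1, 32, 32, 32)  # skip connection at twice the resolution
y = layer(x1, x2)                # transposed convolution, padding to x2's size, then x1 + x2
print(y.shape)  # expected: torch.Size([1, 32, 32, 32])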
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(input_channels, output_channels, kernel_size=2, stride=2, bias=False) + +

+ + +
+ +

An upsampling component using a transposed convolution.

+ + +

Parameters:

+
    +
  • + input_channels + – +
    +
              Number of input channels.
    +
    +
    +
  • +
  • + output_channels + (int) + – +
    +
              Number of output channels.
    +
    +
    +
  • +
  • + kernel_size + – +
    +
              Kernel size.
    +
    +
    +
  • +
  • + bias + – +
    +
              Set to True to let convolutional layers have bias term.
    +
    +
    +
  • +
+ +
+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             input_channels,
+             output_channels,
+             kernel_size = 2,
+             stride = 2,
+             bias = False,
+            ):
+    """
+    An upsampling component using a transposed convolution.
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    output_channels : int
+                      Number of output channels.
+    kernel_size     : int
+                      Kernel size.
+    bias            : bool
+                      Set to True to let convolutional layers have bias term.
+    """
+    super().__init__()
+    self.up = torch.nn.ConvTranspose2d(
+                                       in_channels = input_channels,
+                                       out_channels = output_channels,
+                                       bias = bias,
+                                       kernel_size = kernel_size,
+                                       stride = stride
+                                      )
+
+
+
+ +
+ +
+ + +

+ forward(x1, x2) + +

+ + +
+ +

Forward model.

+ + +

Parameters:

+
    +
  • + x1 + – +
    +
             First input data.
    +
    +
    +
  • +
  • + x2 + – +
    +
             Second input data.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( tensor +) – +
    +

    Result of the forward operation

    +
    +
  • +
+ +
+ Source code in odak/learn/models/components.py +
def forward(self, x1, x2):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x1             : torch.tensor
+                     First input data.
+    x2             : torch.tensor
+                     Second input data.
+
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Result of the forward operation
+    """
+    x1 = self.up(x1)
+    diffY = x2.size()[2] - x1.size()[2]
+    diffX = x2.size()[3] - x1.size()[3]
+    x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,
+                                      diffY // 2, diffY - diffY // 2])
+    result = x1 + x2
+    return result
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ upsample_layer + + +

+ + +
+

+ Bases: Module

+ + +

An upsampling convolutional layer.

+ + + + + + +
+ Source code in odak/learn/models/components.py +
class upsample_layer(torch.nn.Module):
+    """
+    An upsampling convolutional layer.
+    """
+    def __init__(
+                 self,
+                 input_channels,
+                 output_channels,
+                 kernel_size = 3,
+                 bias = False,
+                 activation = torch.nn.ReLU(),
+                 bilinear = True
+                ):
+        """
+        An upsampling component with a double convolution.
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        output_channels : int
+                          Number of output channels.
+        kernel_size     : int
+                          Kernel size.
+        bias            : bool 
+                          Set to True to let convolutional layers have bias term.
+        activation      : torch.nn
+                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+        bilinear        : bool
+                          If set to True, bilinear sampling is used.
+        """
+        super(upsample_layer, self).__init__()
+        if bilinear:
+            self.up = torch.nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = True)
+            self.conv = double_convolution(
+                                           input_channels = input_channels + output_channels,
+                                           mid_channels = input_channels // 2,
+                                           output_channels = output_channels,
+                                           kernel_size = kernel_size,
+                                           bias = bias,
+                                           activation = activation
+                                          )
+        else:
+            self.up = torch.nn.ConvTranspose2d(input_channels , input_channels // 2, kernel_size = 2, stride = 2)
+            self.conv = double_convolution(
+                                           input_channels = input_channels,
+                                           mid_channels = output_channels,
+                                           output_channels = output_channels,
+                                           kernel_size = kernel_size,
+                                           bias = bias,
+                                           activation = activation
+                                          )
+
+
+    def forward(self, x1, x2):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x1             : torch.tensor
+                         First input data.
+        x2             : torch.tensor
+                         Second input data.
+
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Result of the forward operation
+        """ 
+        x1 = self.up(x1)
+        diffY = x2.size()[2] - x1.size()[2]
+        diffX = x2.size()[3] - x1.size()[3]
+        x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,
+                                          diffY // 2, diffY - diffY // 2])
+        x = torch.cat([x2, x1], dim = 1)
+        result = self.conv(x)
+        return result
+
+
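A minimal usage sketch, assuming the layer is importable from `odak.learn.models.components`; the skip connection `x2` is expected to carry `output_channels` channels at twice the resolution of `x1`.

import torch
from odak.learn.models.components import upsample_layer

layer = upsample_layer(input_channels = 64, output_channels = 32, bilinear = True)
x1 = torch.randn(1, 64, 16, 16)  # coarse features, upsampled by a factor of two
x2 = torch.randn(1, 32, 32, 32)  # skip connection, concatenated after upsampling
y = layer(x1, x2)
print(y.shape)  # expected: torch.Size([1, 32, 32, 32])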
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(input_channels, output_channels, kernel_size=3, bias=False, activation=torch.nn.ReLU(), bilinear=True) + +

+ + +
+ +

An upsampling component with a double convolution.

+ + +

Parameters:

+
    +
  • + input_channels + – +
    +
              Number of input channels.
    +
    +
    +
  • +
  • + output_channels + (int) + – +
    +
              Number of output channels.
    +
    +
    +
  • +
  • + kernel_size + – +
    +
              Kernel size.
    +
    +
    +
  • +
  • + bias + – +
    +
              Set to True to let convolutional layers have bias term.
    +
    +
    +
  • +
  • + activation + – +
    +
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
    +
    +
    +
  • +
  • + bilinear + – +
    +
              If set to True, bilinear sampling is used.
    +
    +
    +
  • +
+ +
+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             input_channels,
+             output_channels,
+             kernel_size = 3,
+             bias = False,
+             activation = torch.nn.ReLU(),
+             bilinear = True
+            ):
+    """
+    An upsampling component with a double convolution.
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    output_channels : int
+                      Number of output channels.
+    kernel_size     : int
+                      Kernel size.
+    bias            : bool 
+                      Set to True to let convolutional layers have bias term.
+    activation      : torch.nn
+                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+    bilinear        : bool
+                      If set to True, bilinear sampling is used.
+    """
+    super(upsample_layer, self).__init__()
+    if bilinear:
+        self.up = torch.nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = True)
+        self.conv = double_convolution(
+                                       input_channels = input_channels + output_channels,
+                                       mid_channels = input_channels // 2,
+                                       output_channels = output_channels,
+                                       kernel_size = kernel_size,
+                                       bias = bias,
+                                       activation = activation
+                                      )
+    else:
+        self.up = torch.nn.ConvTranspose2d(input_channels , input_channels // 2, kernel_size = 2, stride = 2)
+        self.conv = double_convolution(
+                                       input_channels = input_channels,
+                                       mid_channels = output_channels,
+                                       output_channels = output_channels,
+                                       kernel_size = kernel_size,
+                                       bias = bias,
+                                       activation = activation
+                                      )
+
+
+
+ +
+ +
+ + +

+ forward(x1, x2) + +

+ + +
+ +

Forward model.

+ + +

Parameters:

+
    +
  • + x1 + – +
    +
             First input data.
    +
    +
    +
  • +
  • + x2 + – +
    +
             Second input data.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( tensor +) – +
    +

    Result of the forward operation

    +
    +
  • +
+ +
+ Source code in odak/learn/models/components.py +
def forward(self, x1, x2):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x1             : torch.tensor
+                     First input data.
+    x2             : torch.tensor
+                     Second input data.
+
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Result of the forward operation
+    """ 
+    x1 = self.up(x1)
+    diffY = x2.size()[2] - x1.size()[2]
+    diffX = x2.size()[3] - x1.size()[3]
+    x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,
+                                      diffY // 2, diffY - diffY // 2])
+    x = torch.cat([x2, x1], dim = 1)
+    result = self.conv(x)
+    return result
+
+
+
+ +
+ + + +
+ +
+ +
+ + +
+ + +

+ gaussian(x, multiplier=1.0) + +

+ + +
+ +

A Gaussian non-linear activation. +For more details: Ramasinghe, Sameera, and Simon Lucey. "Beyond periodicity: Towards a unifying framework for activations in coordinate-mlps." In European Conference on Computer Vision, pp. 142-158. Cham: Springer Nature Switzerland, 2022.

+ + +

Parameters:

+
    +
  • + x + – +
    +
           Input data.
    +
    +
    +
  • +
  • + multiplier + – +
    +
           Multiplier.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( float or tensor +) – +
    +

    Output data.

    +
    +
  • +
+ +
+ Source code in odak/learn/models/components.py +
def gaussian(x, multiplier = 1.):
+    """
+    A Gaussian non-linear activation.
+    For more details: Ramasinghe, Sameera, and Simon Lucey. "Beyond periodicity: Towards a unifying framework for activations in coordinate-mlps." In European Conference on Computer Vision, pp. 142-158. Cham: Springer Nature Switzerland, 2022.
+
+    Parameters
+    ----------
+    x            : float or torch.tensor
+                   Input data.
+    multiplier   : float or torch.tensor
+                   Multiplier.
+
+    Returns
+    -------
+    result       : float or torch.tensor
+                   Output data.
+    """
+    result = torch.exp(- (multiplier * x) ** 2)
+    return result
+
+
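A short numerical check of the formula above, exp(-(multiplier * x) ** 2), assuming `gaussian` is importable from `odak.learn.models.components`:

import torch
from odak.learn.models.components import gaussian

x = torch.tensor([-1.0, 0.0, 1.0])
print(gaussian(x))                   # exp(-x ** 2): roughly [0.3679, 1.0000, 0.3679]
print(gaussian(x, multiplier = 2.))  # a larger multiplier narrows the bell curve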
+
+ +
+ +
+ + +

+ swish(x) + +

+ + +
+ +

A swish non-linear activation. +For more details: https://en.wikipedia.org/wiki/Swish_function

+ + +

Parameters:

+
    +
  • + x + – +
    +
             Input.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +out ( float or tensor +) – +
    +

    Output.

    +
    +
  • +
+ +
+ Source code in odak/learn/models/components.py +
def swish(x):
+    """
+    A swish non-linear activation.
+    For more details: https://en.wikipedia.org/wiki/Swish_function
+
+    Parameters
+    -----------
+    x              : float or torch.tensor
+                     Input.
+
+    Returns
+    -------
+    out            : float or torch.tensor
+                     Output.
+    """
+    out = x * torch.sigmoid(x)
+    return out
+
+
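Similarly, a short check for `swish`, which computes x * sigmoid(x); the import path is assumed to be `odak.learn.models.components`:

import torch
from odak.learn.models.components import swish

x = torch.tensor([-2.0, 0.0, 2.0])
print(swish(x))  # roughly [-0.2384, 0.0000, 1.7616]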
+
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/odak/learn_perception/index.html b/odak/learn_perception/index.html new file mode 100644 index 00000000..8db02c0e --- /dev/null +++ b/odak/learn_perception/index.html @@ -0,0 +1,29153 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + odak.learn.perception - Odak + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + +
+ + + + + + +
+ + + + +
+ +
+ + + + +
+
+ + + +
+
+
+ + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

odak.learn.perception

+ +
+ + + + +
+ +

odak.learn.perception

+

Defines a number of different perceptual loss functions, which can be used to optimise images where gaze location is known.

+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ BlurLoss + + +

+ + +
+ + +

BlurLoss implements two different blur losses. When blur_source is set to False, it implements blur_match, trying to match the input image to the blurred target image. This tries to match the source input image to a blurred version of the target.

+

When blur_source is set to True, it implements blur_lowpass, matching the blurred version of the input image to the blurred target image. This tries to match only the low frequencies of the source input image to the low frequencies of the target.

+

The interface is similar to other pytorch loss functions, but note that the gaze location must be provided in addition to the source and target images.

+ + + + + + +
+ Source code in odak/learn/perception/blur_loss.py +
class BlurLoss():
+    """ 
+
+    `BlurLoss` implements two different blur losses. When `blur_source` is set to `False`, it implements blur_match, trying to match the input image to the blurred target image. This tries to match the source input image to a blurred version of the target.
+
+    When `blur_source` is set to `True`, it implements blur_lowpass, matching the blurred version of the input image to the blurred target image. This tries to match only the low frequencies of the source input image to the low frequencies of the target.
+
+    The interface is similar to other `pytorch` loss functions, but note that the gaze location must be provided in addition to the source and target images.
+    """
+
+
+    def __init__(self, device=torch.device("cpu"),
+                 alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode="quadratic", blur_source=False, equi=False):
+        """
+        Parameters
+        ----------
+
+        alpha                   : float
+                                    parameter controlling foveation - larger values mean bigger pooling regions.
+        real_image_width        : float 
+                                    The real width of the image as displayed to the user.
+                                    Units don't matter as long as they are the same as for real_viewing_distance.
+        real_viewing_distance   : float 
+                                    The real distance of the observer's eyes to the image plane.
+                                    Units don't matter as long as they are the same as for real_image_width.
+        mode                    : str 
+                                    Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
+                                    as you move away from the fovea. We got best results with "quadratic".
+        blur_source             : bool
+                                    If true, blurs the source image as well as the target before computing the loss.
+        equi                    : bool
+                                    If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular
+                                    format 360 image. The settings real_image_width and real_viewing distance are ignored.
+                                    The gaze argument is instead interpreted as gaze angles, and should be in the range
+                                    [-pi,pi]x[-pi/2,pi/2]
+        """
+        self.target = None
+        self.device = device
+        self.alpha = alpha
+        self.real_image_width = real_image_width
+        self.real_viewing_distance = real_viewing_distance
+        self.mode = mode
+        self.blur = None
+        self.loss_func = torch.nn.MSELoss()
+        self.blur_source = blur_source
+        self.equi = equi
+
+    def blur_image(self, image, gaze):
+        if self.blur is None:
+            self.blur = RadiallyVaryingBlur()
+        return self.blur.blur(image, self.alpha, self.real_image_width, self.real_viewing_distance, gaze, self.mode, self.equi)
+
+    def __call__(self, image, target, gaze=[0.5, 0.5]):
+        """ 
+        Calculates the Blur Loss.
+
+        Parameters
+        ----------
+        image               : torch.tensor
+                                Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
+        target              : torch.tensor
+                                Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
+        gaze                : list
+                                Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.
+
+        Returns
+        -------
+
+        loss                : torch.tensor
+                                The computed loss.
+        """
+        check_loss_inputs("BlurLoss", image, target)
+        blurred_target = self.blur_image(target, gaze)
+        if self.blur_source:
+            blurred_image = self.blur_image(image, gaze)
+            return self.loss_func(blurred_image, blurred_target)
+        else:
+            return self.loss_func(image, blurred_target)
+
+    def to(self, device):
+        self.device = device
+        return self
+
+
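A minimal usage sketch, assuming `BlurLoss` is imported from `odak.learn.perception` and using random NCHW RGB images in [0, 1] for illustration:

import torch
from odak.learn.perception import BlurLoss

loss_fn = BlurLoss(alpha = 0.2, real_image_width = 0.2, real_viewing_distance = 0.7)
image  = torch.rand(1, 3, 256, 256)  # estimate to be optimised
target = torch.rand(1, 3, 256, 256)  # ground truth
loss = loss_fn(image, target, gaze = [0.5, 0.5])  # gaze at the image centre
print(loss)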
+ + + +
+ + + + + + + + + +
+ + +

+ __call__(image, target, gaze=[0.5, 0.5]) + +

+ + +
+ +

Calculates the Blur Loss.

+ + +

Parameters:

+
    +
  • + image + – +
    +
                    Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    +
    +
    +
  • +
  • + target + – +
    +
                    Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
    +
    +
    +
  • +
  • + gaze + – +
    +
                    Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +loss ( tensor +) – +
    +

    The computed loss.

    +
    +
  • +
+ +
+ Source code in odak/learn/perception/blur_loss.py +
def __call__(self, image, target, gaze=[0.5, 0.5]):
+    """ 
+    Calculates the Blur Loss.
+
+    Parameters
+    ----------
+    image               : torch.tensor
+                            Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
+    target              : torch.tensor
+                            Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
+    gaze                : list
+                            Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.
+
+    Returns
+    -------
+
+    loss                : torch.tensor
+                            The computed loss.
+    """
+    check_loss_inputs("BlurLoss", image, target)
+    blurred_target = self.blur_image(target, gaze)
+    if self.blur_source:
+        blurred_image = self.blur_image(image, gaze)
+        return self.loss_func(blurred_image, blurred_target)
+    else:
+        return self.loss_func(image, blurred_target)
+
+
+
+ +
+ +
+ + +

+ __init__(device=torch.device('cpu'), alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode='quadratic', blur_source=False, equi=False) + +

+ + +
+ + + +

Parameters:

+
    +
  • + alpha + – +
    +
                        parameter controlling foveation - larger values mean bigger pooling regions.
    +
    +
    +
  • +
  • + real_image_width + – +
    +
                        The real width of the image as displayed to the user.
    +                    Units don't matter as long as they are the same as for real_viewing_distance.
    +
    +
    +
  • +
  • + real_viewing_distance + – +
    +
                        The real distance of the observer's eyes to the image plane.
    +                    Units don't matter as long as they are the same as for real_image_width.
    +
    +
    +
  • +
  • + mode + – +
    +
                        Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
    +                    as you move away from the fovea. We got best results with "quadratic".
    +
    +
    +
  • +
  • + blur_source + – +
    +
                        If true, blurs the source image as well as the target before computing the loss.
    +
    +
    +
  • +
  • + equi + – +
    +
                        If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular
    +                    format 360 image. The settings real_image_width and real_viewing distance are ignored.
    +                    The gaze argument is instead interpreted as gaze angles, and should be in the range
    +                    [-pi,pi]x[-pi/2,pi/2]
    +
    +
    +
  • +
+ +
+ Source code in odak/learn/perception/blur_loss.py +
def __init__(self, device=torch.device("cpu"),
+             alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode="quadratic", blur_source=False, equi=False):
+    """
+    Parameters
+    ----------
+
+    alpha                   : float
+                                parameter controlling foveation - larger values mean bigger pooling regions.
+    real_image_width        : float 
+                                The real width of the image as displayed to the user.
+                                Units don't matter as long as they are the same as for real_viewing_distance.
+    real_viewing_distance   : float 
+                                The real distance of the observer's eyes to the image plane.
+                                Units don't matter as long as they are the same as for real_image_width.
+    mode                    : str 
+                                Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
+                                as you move away from the fovea. We got best results with "quadratic".
+    blur_source             : bool
+                                If true, blurs the source image as well as the target before computing the loss.
+    equi                    : bool
+                                If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular
+                                format 360 image. The settings real_image_width and real_viewing distance are ignored.
+                                The gaze argument is instead interpreted as gaze angles, and should be in the range
+                                [-pi,pi]x[-pi/2,pi/2]
+    """
+    self.target = None
+    self.device = device
+    self.alpha = alpha
+    self.real_image_width = real_image_width
+    self.real_viewing_distance = real_viewing_distance
+    self.mode = mode
+    self.blur = None
+    self.loss_func = torch.nn.MSELoss()
+    self.blur_source = blur_source
+    self.equi = equi
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ CVVDP + + +

+ + +
+

+ Bases: Module

+ + + + + + + +
+ Source code in odak/learn/perception/learned_perceptual_losses.py +
class CVVDP(nn.Module):
+    def __init__(self, device = torch.device('cpu')):
+        """
+        Initializes the CVVDP model with a specified device.
+
+        Parameters
+        ----------
+        device   : torch.device
+                    The device (CPU/GPU) on which the computations will be performed. Defaults to CPU.
+        """
+        super(CVVDP, self).__init__()
+        try:
+            import pycvvdp
+            self.cvvdp = pycvvdp.cvvdp(display_name = 'standard_4k', device = device)
+        except Exception as e:
+            logging.warning('ColorVideoVDP is missing, consider installing by running "pip install -U git+https://github.com/gfxdisp/ColorVideoVDP"')
+            logging.warning(e)
+
+
+    def forward(self, predictions, targets, dim_order = 'CHW'):
+        """
+        Parameters
+        ----------
+        predictions   : torch.tensor
+                        The predicted images.
+        targets       : torch.tensor
+                        The ground truth images.
+        dim_order     : str
+                        The dimension order of the input images. Defaults to 'CHW' (channels, height, width).
+
+        Returns
+        -------
+        result        : torch.tensor
+                        The computed loss if successful, otherwise 0.0.
+        """
+        try:
+            l_ColorVideoVDP = self.cvvdp.loss(predictions, targets, dim_order = dim_order)
+            return l_ColorVideoVDP
+        except Exception as e:
+            logging.warning('ColorVideoVDP failed to compute.')
+            logging.warning(e)
+            return torch.tensor(0.0)
+
+
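A minimal usage sketch; `pycvvdp` is an optional dependency, and when it is missing the class logs a warning and the forward call returns 0.0:

import torch
from odak.learn.perception import CVVDP

loss_fn = CVVDP(device = torch.device('cpu'))
predictions = torch.rand(3, 256, 256)  # CHW, matching the default dim_order
targets = torch.rand(3, 256, 256)
print(loss_fn(predictions, targets, dim_order = 'CHW'))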
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(device=torch.device('cpu')) + +

+ + +
+ +

Initializes the CVVDP model with a specified device.

+ + +

Parameters:

+
    +
  • + device + – +
    +
        The device (CPU/GPU) on which the computations will be performed. Defaults to CPU.
    +
    +
    +
  • +
+ +
+ Source code in odak/learn/perception/learned_perceptual_losses.py +
def __init__(self, device = torch.device('cpu')):
+    """
+    Initializes the CVVDP model with a specified device.
+
+    Parameters
+    ----------
+    device   : torch.device
+                The device (CPU/GPU) on which the computations will be performed. Defaults to CPU.
+    """
+    super(CVVDP, self).__init__()
+    try:
+        import pycvvdp
+        self.cvvdp = pycvvdp.cvvdp(display_name = 'standard_4k', device = device)
+    except Exception as e:
+        logging.warning('ColorVideoVDP is missing, consider installing by running "pip install -U git+https://github.com/gfxdisp/ColorVideoVDP"')
+        logging.warning(e)
+
+
+
+ +
+ +
+ + +

+ forward(predictions, targets, dim_order='CHW') + +

+ + +
+ + + +

Parameters:

+
    +
  • + predictions + – +
    +
            The predicted images.
    +
    +
    +
  • +
  • + targets + – +
    +
            The ground truth images.
    +
    +
    +
  • +
  • + dim_order + – +
    +
            The dimension order of the input images. Defaults to 'CHW' (channels, height, width).
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( tensor +) – +
    +

    The computed loss if successful, otherwise 0.0.

    +
    +
  • +
+ +
+ Source code in odak/learn/perception/learned_perceptual_losses.py +
def forward(self, predictions, targets, dim_order = 'CHW'):
+    """
+    Parameters
+    ----------
+    predictions   : torch.tensor
+                    The predicted images.
+    targets       : torch.tensor
+                    The ground truth images.
+    dim_order     : str
+                    The dimension order of the input images. Defaults to 'CHW' (channels, height, width).
+
+    Returns
+    -------
+    result        : torch.tensor
+                    The computed loss if successful, otherwise 0.0.
+    """
+    try:
+        l_ColorVideoVDP = self.cvvdp.loss(predictions, targets, dim_order = dim_order)
+        return l_ColorVideoVDP
+    except Exception as e:
+        logging.warning('ColorVideoVDP failed to compute.')
+        logging.warning(e)
+        return torch.tensor(0.0)
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ FVVDP + + +

+ + +
+

+ Bases: Module

+ + + + + + + +
+ Source code in odak/learn/perception/learned_perceptual_losses.py +
class FVVDP(nn.Module):
+    def __init__(self, device = torch.device('cpu')):
+        """
+        Initializes the FVVDP model with a specified device.
+
+        Parameters
+        ----------
+        device   : torch.device
+                    The device (CPU/GPU) on which the computations will be performed. Defaults to CPU.
+        """
+        super(FVVDP, self).__init__()
+        try:
+            import pyfvvdp
+            self.fvvdp = pyfvvdp.fvvdp(display_name = 'standard_4k', heatmap = 'none', device = device)
+        except Exception as e:
+            logging.warning('FovVideoVDP is missing, consider installing by running "pip install pyfvvdp"')
+            logging.warning(e)
+
+
+    def forward(self, predictions, targets, dim_order = 'CHW'):
+        """
+        Parameters
+        ----------
+        predictions   : torch.tensor
+                        The predicted images.
+        targets       : torch.tensor
+                        The ground truth images.
+        dim_order     : str
+                        The dimension order of the input images. Defaults to 'CHW' (channels, height, width).
+
+        Returns
+        -------
+        result        : torch.tensor
+                          The computed loss if successful, otherwise 0.0.
+        """
+        try:
+            l_FovVideoVDP = self.fvvdp.predict(predictions, targets, dim_order = dim_order)[0]
+            return l_FovVideoVDP
+        except Exception as e:
+            logging.warning('FovVideoVDP failed to compute.')
+            logging.warning(e)
+            return torch.tensor(0.0)
+
+
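Usage mirrors `CVVDP`; `pyfvvdp` is an optional dependency, and failures fall back to a 0.0 loss:

import torch
from odak.learn.perception import FVVDP

loss_fn = FVVDP(device = torch.device('cpu'))
print(loss_fn(torch.rand(3, 256, 256), torch.rand(3, 256, 256), dim_order = 'CHW'))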
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(device=torch.device('cpu')) + +

+ + +
+ +

Initializes the FVVDP model with a specified device.

+ + +

Parameters:

+
    +
  • + device + – +
    +
        The device (CPU/GPU) on which the computations will be performed. Defaults to CPU.
    +
    +
    +
  • +
+ +
+ Source code in odak/learn/perception/learned_perceptual_losses.py +
def __init__(self, device = torch.device('cpu')):
+    """
+    Initializes the FVVDP model with a specified device.
+
+    Parameters
+    ----------
+    device   : torch.device
+                The device (CPU/GPU) on which the computations will be performed. Defaults to CPU.
+    """
+    super(FVVDP, self).__init__()
+    try:
+        import pyfvvdp
+        self.fvvdp = pyfvvdp.fvvdp(display_name = 'standard_4k', heatmap = 'none', device = device)
+    except Exception as e:
+        logging.warning('FovVideoVDP is missing, consider installing by running "pip install pyfvvdp"')
+        logging.warning(e)
+
+
+
+ +
+ +
+ + +

+ forward(predictions, targets, dim_order='CHW') + +

+ + +
+ + + +

Parameters:

+
    +
  • + predictions + – +
    +
            The predicted images.
    +
    +
    +
  • +
  • + targets + – +
    +
            The ground truth images.
    +
    +
    +
  • +
  • + dim_order + – +
    +
            The dimension order of the input images. Defaults to 'CHW' (channels, height, width).
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( tensor +) – +
    +

    The computed loss if successful, otherwise 0.0.

    +
    +
  • +
+ +
+ Source code in odak/learn/perception/learned_perceptual_losses.py +
def forward(self, predictions, targets, dim_order = 'CHW'):
+    """
+    Parameters
+    ----------
+    predictions   : torch.tensor
+                    The predicted images.
+    targets       : torch.tensor
+                    The ground truth images.
+    dim_order     : str
+                    The dimension order of the input images. Defaults to 'CHW' (channels, height, width).
+
+    Returns
+    -------
+    result        : torch.tensor
+                      The computed loss if successful, otherwise 0.0.
+    """
+    try:
+        l_FovVideoVDP = self.fvvdp.predict(predictions, targets, dim_order = dim_order)[0]
+        return l_FovVideoVDP
+    except Exception as e:
+        logging.warning('FovVideoVDP failed to compute.')
+        logging.warning(e)
+        return torch.tensor(0.0)
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ LPIPS + + +

+ + +
+

+ Bases: Module

+ + + + + + + +
+ Source code in odak/learn/perception/learned_perceptual_losses.py +
class LPIPS(nn.Module):
+
+    def __init__(self):
+        """
+        Initializes the LPIPS (Learned Perceptual Image Patch Similarity) model.
+
+        """
+        super(LPIPS, self).__init__()
+        try:
+            import torchmetrics
+            self.lpips = torchmetrics.image.lpip.LearnedPerceptualImagePatchSimilarity(net_type = 'squeeze')
+        except Exception as e:
+            logging.warning('torchmetrics is missing, consider installing by running "pip install torchmetrics"')
+            logging.warning(e)
+
+
+    def forward(self, predictions, targets):
+        """
+        Parameters
+        ----------
+        predictions   : torch.tensor
+                        The predicted images.
+        targets       : torch.tensor
+                        The ground truth images.
+
+        Returns
+        -------
+        result        : torch.tensor
+                        The computed loss if successful, otherwise 0.0.
+        """
+        try:
+            lpips_image = predictions
+            lpips_target = targets
+            if len(lpips_image.shape) == 3:
+                lpips_image = lpips_image.unsqueeze(0)
+                lpips_target = lpips_target.unsqueeze(0)
+            if lpips_image.shape[1] == 1:
+                lpips_image = lpips_image.repeat(1, 3, 1, 1)
+                lpips_target = lpips_target.repeat(1, 3, 1, 1)
+            lpips_image = (lpips_image * 2 - 1).clamp(-1, 1)
+            lpips_target = (lpips_target * 2 - 1).clamp(-1, 1)
+            l_LPIPS = self.lpips(lpips_image, lpips_target)
+            return l_LPIPS
+        except Exception as e:
+            logging.warning('LPIPS failed to compute.')
+            logging.warning(e)
+            return torch.tensor(0.0)
+
+
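A minimal usage sketch; `torchmetrics` is an optional dependency, inputs are expected in [0, 1] and are rescaled to [-1, 1] internally, and single-channel images are repeated to three channels:

import torch
from odak.learn.perception import LPIPS

loss_fn = LPIPS()
predictions = torch.rand(1, 3, 128, 128)
targets = torch.rand(1, 3, 128, 128)
print(loss_fn(predictions, targets))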
+ + + +
+ + + + + + + + + +
+ + +

+ __init__() + +

+ + +
+ +

Initializes the LPIPS (Learned Perceptual Image Patch Similarity) model.

+ +
+ Source code in odak/learn/perception/learned_perceptual_losses.py +
def __init__(self):
+    """
+    Initializes the LPIPS (Learned Perceptual Image Patch Similarity) model.
+
+    """
+    super(LPIPS, self).__init__()
+    try:
+        import torchmetrics
+        self.lpips = torchmetrics.image.lpip.LearnedPerceptualImagePatchSimilarity(net_type = 'squeeze')
+    except Exception as e:
+        logging.warning('torchmetrics is missing, consider installing by running "pip install torchmetrics"')
+        logging.warning(e)
+
+
+
+ +
+ +
+ + +

+ forward(predictions, targets) + +

+ + +
+ + + +

Parameters:

+
    +
  • + predictions + – +
    +
            The predicted images.
    +
    +
    +
  • +
  • + targets + – +
    +
            The ground truth images.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( tensor +) – +
    +

    The computed loss if successful, otherwise 0.0.

    +
    +
  • +
+ +
+ Source code in odak/learn/perception/learned_perceptual_losses.py +
def forward(self, predictions, targets):
+    """
+    Parameters
+    ----------
+    predictions   : torch.tensor
+                    The predicted images.
+    targets       : torch.tensor
+                    The ground truth images.
+
+    Returns
+    -------
+    result        : torch.tensor
+                    The computed loss if successful, otherwise 0.0.
+    """
+    try:
+        lpips_image = predictions
+        lpips_target = targets
+        if len(lpips_image.shape) == 3:
+            lpips_image = lpips_image.unsqueeze(0)
+            lpips_target = lpips_target.unsqueeze(0)
+        if lpips_image.shape[1] == 1:
+            lpips_image = lpips_image.repeat(1, 3, 1, 1)
+            lpips_target = lpips_target.repeat(1, 3, 1, 1)
+        lpips_image = (lpips_image * 2 - 1).clamp(-1, 1)
+        lpips_target = (lpips_target * 2 - 1).clamp(-1, 1)
+        l_LPIPS = self.lpips(lpips_image, lpips_target)
+        return l_LPIPS
+    except Exception as e:
+        logging.warning('LPIPS failed to compute.')
+        logging.warning(e)
+        return torch.tensor(0.0)
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ MSSSIM + + +

+ + +
+

+ Bases: Module

+ + +

A class to calculate multi-scale structural similarity index of an image with respect to a ground truth image.

+ + + + + + +
+ Source code in odak/learn/perception/image_quality_losses.py +
class MSSSIM(nn.Module):
+    '''
+    A class to calculate multi-scale structural similarity index of an image with respect to a ground truth image.
+    '''
+
+    def __init__(self):
+        super(MSSSIM, self).__init__()
+
+    def forward(self, predictions, targets):
+        """
+        Parameters
+        ----------
+        predictions : torch.tensor
+                      The predicted images.
+        targets     : torch.tensor
+                      The ground truth images.
+
+        Returns
+        -------
+        result      : torch.tensor 
+                      The computed MS-SSIM value if successful, otherwise 0.0.
+        """
+        try:
+            from torchmetrics.functional.image import multiscale_structural_similarity_index_measure
+            if len(predictions.shape) == 3:
+                predictions = predictions.unsqueeze(0)
+                targets = targets.unsqueeze(0)
+            l_MSSSIM = multiscale_structural_similarity_index_measure(predictions, targets, data_range = 1.0)
+            return l_MSSSIM  
+        except Exception as e:
+            logging.warning('MS-SSIM failed to compute.')
+            logging.warning(e)
+            return torch.tensor(0.0)
+
+
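A minimal usage sketch; identical inputs should give an MS-SSIM of 1.0, and `torchmetrics` is required for the computation:

import torch
from odak.learn.perception import MSSSIM

metric = MSSSIM()
predictions = torch.rand(1, 3, 256, 256)  # NCHW, values in [0, 1]
print(metric(predictions, predictions.clone()))  # expected: tensor(1.)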
+ + + +
+ + + + + + + + + +
+ + +

+ forward(predictions, targets) + +

+ + +
+ + + +

Parameters:

+
    +
  • + predictions + (tensor) + – +
    +
          The predicted images.
    +
    +
    +
  • +
  • + targets + – +
    +
          The ground truth images.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( tensor +) – +
    +

    The computed MS-SSIM value if successful, otherwise 0.0.

    +
    +
  • +
+ +
+ Source code in odak/learn/perception/image_quality_losses.py +
def forward(self, predictions, targets):
+    """
+    Parameters
+    ----------
+    predictions : torch.tensor
+                  The predicted images.
+    targets     : torch.tensor
+                  The ground truth images.
+
+    Returns
+    -------
+    result      : torch.tensor 
+                  The computed MS-SSIM value if successful, otherwise 0.0.
+    """
+    try:
+        from torchmetrics.functional.image import multiscale_structural_similarity_index_measure
+        if len(predictions.shape) == 3:
+            predictions = predictions.unsqueeze(0)
+            targets = targets.unsqueeze(0)
+        l_MSSSIM = multiscale_structural_similarity_index_measure(predictions, targets, data_range = 1.0)
+        return l_MSSSIM  
+    except Exception as e:
+        logging.warning('MS-SSIM failed to compute.')
+        logging.warning(e)
+        return torch.tensor(0.0)
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ MetamerMSELoss + + +

+ + +
+ + +

The MetamerMSELoss class provides a perceptual loss function. This generates a metamer for the target image, and then optimises the source image to be the same as this target image metamer.

+

Please note this is different to MetamericLoss which optimises the source image to be any metamer of the target image.

+

Its interface is similar to other pytorch loss functions, but note that the gaze location must be provided in addition to the source and target images.

+ + + + + + +
+ Source code in odak/learn/perception/metamer_mse_loss.py +
class MetamerMSELoss():
+    """ 
+    The `MetamerMSELoss` class provides a perceptual loss function. This generates a metamer for the target image, and then optimises the source image to be the same as this target image metamer.
+
+    Please note this is different to `MetamericLoss` which optimises the source image to be any metamer of the target image.
+
+    Its interface is similar to other `pytorch` loss functions, but note that the gaze location must be provided in addition to the source and target images.
+    """
+
+
+    def __init__(self, device=torch.device("cpu"),
+                 alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode="quadratic",
+                 n_pyramid_levels=5, n_orientations=2, equi=False):
+        """
+        Parameters
+        ----------
+        alpha                   : float
+                                    parameter controlling foveation - larger values mean bigger pooling regions.
+        real_image_width        : float 
+                                    The real width of the image as displayed to the user.
+                                    Units don't matter as long as they are the same as for real_viewing_distance.
+        real_viewing_distance   : float 
+                                    The real distance of the observer's eyes to the image plane.
+                                    Units don't matter as long as they are the same as for real_image_width.
+        n_pyramid_levels        : int 
+                                    Number of levels of the steerable pyramid. Note that the image is padded
+                                    so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value
+                                    too high will slow down the calculation a lot.
+        mode                    : str 
+                                    Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
+                                    as you move away from the fovea. We got best results with "quadratic".
+        n_orientations          : int 
+                                    Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.
+                                    Increasing this will increase runtime.
+        equi                    : bool
+                                    If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular
+                                    format 360 image. The settings real_image_width and real_viewing distance are ignored.
+                                    The gaze argument is instead interpreted as gaze angles, and should be in the range
+                                    [-pi,pi]x[-pi/2,pi/2]
+        """
+        self.target = None
+        self.target_metamer = None
+        self.metameric_loss = MetamericLoss(device=device, alpha=alpha, real_image_width=real_image_width,
+                                            real_viewing_distance=real_viewing_distance,
+                                            n_pyramid_levels=n_pyramid_levels, n_orientations=n_orientations, use_l2_foveal_loss=False, equi=equi)
+        self.loss_func = torch.nn.MSELoss()
+        self.noise = None
+
+    def gen_metamer(self, image, gaze):
+        """ 
+        Generates a metamer for an image, following the method in [this paper](https://dl.acm.org/doi/abs/10.1145/3450626.3459943)
+        This function can be used on its own to generate a metamer for a desired image.
+
+        Parameters
+        ----------
+        image   : torch.tensor
+                Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions)
+        gaze    : list
+                Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.
+
+        Returns
+        -------
+
+        metamer : torch.tensor
+                The generated metamer image
+        """
+        image = rgb_2_ycrcb(image)
+        image_size = image.size()
+        image = pad_image_for_pyramid(image, self.metameric_loss.n_pyramid_levels)
+
+        target_stats = self.metameric_loss.calc_statsmaps(
+            image, gaze=gaze, alpha=self.metameric_loss.alpha)
+        target_means = target_stats[::2]
+        target_stdevs = target_stats[1::2]
+        # Cache the noise image so repeated calls with the same-sized input can reuse it.
+        if self.noise is None or self.noise.size() != image.size():
+            torch.manual_seed(0)
+            self.noise = torch.rand_like(image)
+        noise_image = self.noise
+        noise_pyramid = self.metameric_loss.pyramid_maker.construct_pyramid(
+            noise_image, self.metameric_loss.n_pyramid_levels)
+        input_pyramid = self.metameric_loss.pyramid_maker.construct_pyramid(
+            image, self.metameric_loss.n_pyramid_levels)
+
+        def match_level(input_level, target_mean, target_std):
+            level = input_level.clone()
+            level -= torch.mean(level)
+            input_std = torch.sqrt(torch.mean(level * level))
+            eps = 1e-6
+            # Safeguard against divide by zero
+            input_std[input_std < eps] = eps
+            level /= input_std
+            level *= target_std
+            level += target_mean
+            return level
+
+        nbands = len(noise_pyramid[0]["b"])
+        noise_pyramid[0]["h"] = match_level(
+            noise_pyramid[0]["h"], target_means[0], target_stdevs[0])
+        for l in range(len(noise_pyramid)-1):
+            for b in range(nbands):
+                noise_pyramid[l]["b"][b] = match_level(
+                    noise_pyramid[l]["b"][b], target_means[1 + l * nbands + b], target_stdevs[1 + l * nbands + b])
+        noise_pyramid[-1]["l"] = input_pyramid[-1]["l"]
+
+        metamer = self.metameric_loss.pyramid_maker.reconstruct_from_pyramid(
+            noise_pyramid)
+        metamer = ycrcb_2_rgb(metamer)
+        # Crop to remove any padding
+        metamer = metamer[:image_size[0], :image_size[1], :image_size[2], :image_size[3]]
+        return metamer
+
+    def __call__(self, image, target, gaze=[0.5, 0.5]):
+        """ 
+        Calculates the Metamer MSE Loss.
+
+        Parameters
+        ----------
+        image   : torch.tensor
+                Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
+        target  : torch.tensor
+                Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
+        gaze    : list
+                Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.
+
+        Returns
+        -------
+
+        loss                : torch.tensor
+                                The computed loss.
+        """
+        check_loss_inputs("MetamerMSELoss", image, target)
+        # Pad image and target if necessary
+        image = pad_image_for_pyramid(image, self.metameric_loss.n_pyramid_levels)
+        target = pad_image_for_pyramid(target, self.metameric_loss.n_pyramid_levels)
+
+        if target is not self.target or self.target is None:
+            self.target_metamer = self.gen_metamer(target, gaze)
+            self.target = target
+
+        return self.loss_func(image, self.target_metamer)
+
+    def to(self, device):
+        self.metameric_loss = self.metameric_loss.to(device)
+        return self
+
+
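To make the interface above concrete, here is a minimal usage sketch. It is not taken from the odak documentation: it assumes `MetamerMSELoss` is importable from `odak.learn.perception`, and the tensor sizes and hyperparameters are illustrative only.

import torch
from odak.learn.perception import MetamerMSELoss

loss_fn = MetamerMSELoss(alpha = 0.2, real_image_width = 0.2, real_viewing_distance = 0.7)
target = torch.rand(1, 3, 256, 256)                           # RGB target in NCHW format.
estimate = torch.rand(1, 3, 256, 256, requires_grad = True)   # Image being optimised.
loss = loss_fn(estimate, target, gaze = [0.5, 0.5])           # Gaze at the image centre.
loss.backward()

Because the target metamer is cached, repeated calls with the same target only pay the metamer generation cost once.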
__call__(image, target, gaze=[0.5, 0.5])

Calculates the Metamer MSE Loss.

Parameters:

  • image – Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions).
  • target – Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions).
  • gaze – Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.

Returns:

  • loss (tensor) – The computed loss.
+ Source code in odak/learn/perception/metamer_mse_loss.py +
def __call__(self, image, target, gaze=[0.5, 0.5]):
+    """ 
+    Calculates the Metamer MSE Loss.
+
+    Parameters
+    ----------
+    image   : torch.tensor
+            Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
+    target  : torch.tensor
+            Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
+    gaze    : list
+            Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.
+
+    Returns
+    -------
+
+    loss                : torch.tensor
+                            The computed loss.
+    """
+    check_loss_inputs("MetamerMSELoss", image, target)
+    # Pad image and target if necessary
+    image = pad_image_for_pyramid(image, self.metameric_loss.n_pyramid_levels)
+    target = pad_image_for_pyramid(target, self.metameric_loss.n_pyramid_levels)
+
+    if target is not self.target or self.target is None:
+        self.target_metamer = self.gen_metamer(target, gaze)
+        self.target = target
+
+    return self.loss_func(image, self.target_metamer)
+
__init__(device=torch.device('cpu'), alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode='quadratic', n_pyramid_levels=5, n_orientations=2, equi=False)

Parameters:

  • alpha – Parameter controlling foveation - larger values mean bigger pooling regions.
  • real_image_width – The real width of the image as displayed to the user. Units don't matter as long as they are the same as for real_viewing_distance.
  • real_viewing_distance – The real distance of the observer's eyes to the image plane. Units don't matter as long as they are the same as for real_image_width.
  • n_pyramid_levels – Number of levels of the steerable pyramid. Note that the image is padded so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value too high will slow down the calculation a lot.
  • mode – Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow as you move away from the fovea. We got best results with "quadratic".
  • n_orientations – Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6. Increasing this will increase runtime.
  • equi – If true, run the loss in equirectangular mode. The input is assumed to be a 360 image in equirectangular format. The settings real_image_width and real_viewing_distance are ignored. The gaze argument is instead interpreted as gaze angles, and should be in the range [-pi,pi]x[-pi/2,pi/2].
+ Source code in odak/learn/perception/metamer_mse_loss.py +
def __init__(self, device=torch.device("cpu"),
+             alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode="quadratic",
+             n_pyramid_levels=5, n_orientations=2, equi=False):
+    """
+    Parameters
+    ----------
+    alpha                   : float
+                                parameter controlling foveation - larger values mean bigger pooling regions.
+    real_image_width        : float 
+                                The real width of the image as displayed to the user.
+                                Units don't matter as long as they are the same as for real_viewing_distance.
+    real_viewing_distance   : float 
+                                The real distance of the observer's eyes to the image plane.
+                                Units don't matter as long as they are the same as for real_image_width.
+    n_pyramid_levels        : int 
+                                Number of levels of the steerable pyramid. Note that the image is padded
+                                so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value
+                                too high will slow down the calculation a lot.
+    mode                    : str 
+                                Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
+                                as you move away from the fovea. We got best results with "quadratic".
+    n_orientations          : int 
+                                Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.
+                                Increasing this will increase runtime.
+    equi                    : bool
+                                If true, run the loss in equirectangular mode. The input is assumed to be a 360 image in
+                                equirectangular format. The settings real_image_width and real_viewing_distance are ignored.
+                                The gaze argument is instead interpreted as gaze angles, and should be in the range
+                                [-pi,pi]x[-pi/2,pi/2].
+    """
+    self.target = None
+    self.target_metamer = None
+    self.metameric_loss = MetamericLoss(device=device, alpha=alpha, real_image_width=real_image_width,
+                                        real_viewing_distance=real_viewing_distance,
+                                        n_pyramid_levels=n_pyramid_levels, n_orientations=n_orientations, use_l2_foveal_loss=False, equi=equi)
+    self.loss_func = torch.nn.MSELoss()
+    self.noise = None
+
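The equi option described above changes how the gaze argument is interpreted. A hypothetical sketch for a 360 panorama (the equirectangular input and gaze angles in radians follow the docstring; the sizes are illustrative, and `MetamerMSELoss` is assumed to be importable from `odak.learn.perception`):

import torch
from odak.learn.perception import MetamerMSELoss

loss_fn = MetamerMSELoss(equi = True)
pano_target = torch.rand(1, 3, 256, 512)                         # Equirectangular 360 image, NCHW.
pano_estimate = torch.rand(1, 3, 256, 512, requires_grad = True)
gaze_angles = [0.0, 0.0]                                         # Azimuth and elevation in radians.
loss = loss_fn(pano_estimate, pano_target, gaze = gaze_angles)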
gen_metamer(image, gaze)

Generates a metamer for an image, following the method in [this paper](https://dl.acm.org/doi/abs/10.1145/3450626.3459943).
This function can be used on its own to generate a metamer for a desired image.

Parameters:

  • image – Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions).
  • gaze – Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.

Returns:

  • metamer (tensor) – The generated metamer image.
+ Source code in odak/learn/perception/metamer_mse_loss.py +
def gen_metamer(self, image, gaze):
+    """ 
+    Generates a metamer for an image, following the method in [this paper](https://dl.acm.org/doi/abs/10.1145/3450626.3459943)
+    This function can be used on its own to generate a metamer for a desired image.
+
+    Parameters
+    ----------
+    image   : torch.tensor
+            Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions)
+    gaze    : list
+            Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.
+
+    Returns
+    -------
+
+    metamer : torch.tensor
+            The generated metamer image
+    """
+    image = rgb_2_ycrcb(image)
+    image_size = image.size()
+    image = pad_image_for_pyramid(image, self.metameric_loss.n_pyramid_levels)
+
+    target_stats = self.metameric_loss.calc_statsmaps(
+        image, gaze=gaze, alpha=self.metameric_loss.alpha)
+    target_means = target_stats[::2]
+    target_stdevs = target_stats[1::2]
+    # Cache the noise image so repeated calls with the same-sized input can reuse it.
+    if self.noise is None or self.noise.size() != image.size():
+        torch.manual_seed(0)
+        self.noise = torch.rand_like(image)
+    noise_image = self.noise
+    noise_pyramid = self.metameric_loss.pyramid_maker.construct_pyramid(
+        noise_image, self.metameric_loss.n_pyramid_levels)
+    input_pyramid = self.metameric_loss.pyramid_maker.construct_pyramid(
+        image, self.metameric_loss.n_pyramid_levels)
+
+    def match_level(input_level, target_mean, target_std):
+        level = input_level.clone()
+        level -= torch.mean(level)
+        input_std = torch.sqrt(torch.mean(level * level))
+        eps = 1e-6
+        # Safeguard against divide by zero
+        input_std[input_std < eps] = eps
+        level /= input_std
+        level *= target_std
+        level += target_mean
+        return level
+
+    nbands = len(noise_pyramid[0]["b"])
+    noise_pyramid[0]["h"] = match_level(
+        noise_pyramid[0]["h"], target_means[0], target_stdevs[0])
+    for l in range(len(noise_pyramid)-1):
+        for b in range(nbands):
+            noise_pyramid[l]["b"][b] = match_level(
+                noise_pyramid[l]["b"][b], target_means[1 + l * nbands + b], target_stdevs[1 + l * nbands + b])
+    noise_pyramid[-1]["l"] = input_pyramid[-1]["l"]
+
+    metamer = self.metameric_loss.pyramid_maker.reconstruct_from_pyramid(
+        noise_pyramid)
+    metamer = ycrcb_2_rgb(metamer)
+    # Crop to remove any padding
+    metamer = metamer[:image_size[0], :image_size[1], :image_size[2], :image_size[3]]
+    return metamer
+
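As noted in the description above, gen_metamer() can also be used on its own to produce a foveated metamer of an image. A minimal sketch, assuming `MetamerMSELoss` is importable from `odak.learn.perception`:

import torch
from odak.learn.perception import MetamerMSELoss

loss_fn = MetamerMSELoss()
target = torch.rand(1, 3, 256, 256)          # Any RGB image in NCHW format.
gaze = [0.4, 0.6]                            # Normalized image coordinates.
metamer = loss_fn.gen_metamer(target, gaze)  # Image with matched local statistics around the gaze point.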
MetamericLoss

The MetamericLoss class provides a perceptual loss function.

Rather than exactly match the source image to the target, it tries to ensure the source is a metamer to the target image.

Its interface is similar to other pytorch loss functions, but note that the gaze location must be provided in addition to the source and target images.
+ Source code in odak/learn/perception/metameric_loss.py +
class MetamericLoss():
+    """
+    The `MetamericLoss` class provides a perceptual loss function.
+
+    Rather than exactly match the source image to the target, it tries to ensure the source is a *metamer* to the target image.
+
+    Its interface is similar to other `pytorch` loss functions, but note that the gaze location must be provided in addition to the source and target images.
+    """
+
+
+    def __init__(self, device=torch.device('cpu'), alpha=0.2, real_image_width=0.2,
+                 real_viewing_distance=0.7, n_pyramid_levels=5, mode="quadratic",
+                 n_orientations=2, use_l2_foveal_loss=True, fovea_weight=20.0, use_radial_weight=False,
+                 use_fullres_l0=False, equi=False):
+        """
+        Parameters
+        ----------
+
+        alpha                   : float
+                                    parameter controlling foveation - larger values mean bigger pooling regions.
+        real_image_width        : float 
+                                    The real width of the image as displayed to the user.
+                                    Units don't matter as long as they are the same as for real_viewing_distance.
+        real_viewing_distance   : float 
+                                    The real distance of the observer's eyes to the image plane.
+                                    Units don't matter as long as they are the same as for real_image_width.
+        n_pyramid_levels        : int 
+                                    Number of levels of the steerable pyramid. Note that the image is padded
+                                    so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value
+                                    too high will slow down the calculation a lot.
+        mode                    : str 
+                                    Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
+                                    as you move away from the fovea. We got best results with "quadratic".
+        n_orientations          : int 
+                                    Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.
+                                    Increasing this will increase runtime.
+        use_l2_foveal_loss      : bool 
+                                    If true, all the pixels that have a pooling size of 1 pixel at the
+                                    largest scale will use a direct L2 loss against the target rather than pooling over pyramid levels.
+                                    In practice this gives better results when the loss is used for holography.
+        fovea_weight            : float 
+                                    A weight to apply to the foveal region if use_l2_foveal_loss is set to True.
+        use_radial_weight       : bool 
+                                    If True, will apply a radial weighting when calculating the difference between
+                                    the source and target stats maps. This weights stats closer to the fovea more than those
+                                    further away.
+        use_fullres_l0          : bool 
+                                    If true, stats for the lowpass residual are replaced with blurred versions
+                                    of the full-resolution source and target images.
+        equi                    : bool
+                                    If true, run the loss in equirectangular mode. The input is assumed to be a 360 image in
+                                    equirectangular format. The settings real_image_width and real_viewing_distance are ignored.
+                                    The gaze argument is instead interpreted as gaze angles, and should be in the range
+                                    [-pi,pi]x[-pi/2,pi/2].
+        """
+        self.target = None
+        self.device = device
+        self.pyramid_maker = None
+        self.alpha = alpha
+        self.real_image_width = real_image_width
+        self.real_viewing_distance = real_viewing_distance
+        self.blurs = None
+        self.n_pyramid_levels = n_pyramid_levels
+        self.n_orientations = n_orientations
+        self.mode = mode
+        self.use_l2_foveal_loss = use_l2_foveal_loss
+        self.fovea_weight = fovea_weight
+        self.use_radial_weight = use_radial_weight
+        self.use_fullres_l0 = use_fullres_l0
+        self.equi = equi
+        if self.use_fullres_l0 and self.use_l2_foveal_loss:
+            raise Exception(
+                "Can't use use_fullres_l0 and use_l2_foveal_loss options together in MetamericLoss!")
+
+    def calc_statsmaps(self, image, gaze=None, alpha=0.01, real_image_width=0.3,
+                       real_viewing_distance=0.6, mode="quadratic", equi=False):
+
+        if self.pyramid_maker is None or \
+                self.pyramid_maker.device != self.device or \
+                len(self.pyramid_maker.band_filters) != self.n_orientations or\
+                self.pyramid_maker.filt_h0.size(0) != image.size(1):
+            self.pyramid_maker = SpatialSteerablePyramid(
+                use_bilinear_downup=False, n_channels=image.size(1),
+                device=self.device, n_orientations=self.n_orientations, filter_type="cropped", filter_size=5)
+
+        if self.blurs is None or len(self.blurs) != self.n_pyramid_levels:
+            self.blurs = [RadiallyVaryingBlur()
+                          for i in range(self.n_pyramid_levels)]
+
+        def find_stats(image_pyr_level, blur):
+            image_means = blur.blur(
+                image_pyr_level, alpha, real_image_width, real_viewing_distance, centre=gaze, mode=mode, equi=self.equi)
+            image_meansq = blur.blur(image_pyr_level*image_pyr_level, alpha,
+                                     real_image_width, real_viewing_distance, centre=gaze, mode=mode, equi=self.equi)
+
+            image_vars = image_meansq - (image_means*image_means)
+            image_vars[image_vars < 1e-7] = 1e-7
+            image_std = torch.sqrt(image_vars)
+            if torch.any(torch.isnan(image_means)):
+                print(image_means)
+                raise Exception("NaN in image means!")
+            if torch.any(torch.isnan(image_std)):
+                print(image_std)
+                raise Exception("NaN in image stdevs!")
+            if self.use_fullres_l0:
+                mask = blur.lod_map > 1e-6
+                mask = mask[None, None, ...]
+                if image_means.size(1) > 1:
+                    mask = mask.repeat(1, image_means.size(1), 1, 1)
+                matte = torch.zeros_like(image_means)
+                matte[mask] = 1.0
+                return image_means * matte, image_std * matte
+            return image_means, image_std
+        output_stats = []
+        image_pyramid = self.pyramid_maker.construct_pyramid(
+            image, self.n_pyramid_levels)
+        means, variances = find_stats(image_pyramid[0]['h'], self.blurs[0])
+        if self.use_l2_foveal_loss:
+            self.fovea_mask = torch.zeros(image.size(), device=image.device)
+            for i in range(self.fovea_mask.size(1)):
+                self.fovea_mask[0, i, ...] = 1.0 - \
+                    (self.blurs[0].lod_map / torch.max(self.blurs[0].lod_map))
+                self.fovea_mask[0, i, self.blurs[0].lod_map < 1e-6] = 1.0
+            self.fovea_mask = torch.pow(self.fovea_mask, 10.0)
+            #self.fovea_mask     = torch.nn.functional.interpolate(self.fovea_mask, scale_factor=0.125, mode="area")
+            #self.fovea_mask     = torch.nn.functional.interpolate(self.fovea_mask, size=(image.size(-2), image.size(-1)), mode="bilinear")
+            periphery_mask = 1.0 - self.fovea_mask
+            self.periphery_mask = periphery_mask.clone()
+            output_stats.append(means * periphery_mask)
+            output_stats.append(variances * periphery_mask)
+        else:
+            output_stats.append(means)
+            output_stats.append(variances)
+
+        for l in range(0, len(image_pyramid)-1):
+            for o in range(len(image_pyramid[l]['b'])):
+                means, variances = find_stats(
+                    image_pyramid[l]['b'][o], self.blurs[l])
+                if self.use_l2_foveal_loss:
+                    output_stats.append(means * periphery_mask)
+                    output_stats.append(variances * periphery_mask)
+                else:
+                    output_stats.append(means)
+                    output_stats.append(variances)
+            if self.use_l2_foveal_loss:
+                periphery_mask = torch.nn.functional.interpolate(
+                    periphery_mask, scale_factor=0.5, mode="area", recompute_scale_factor=False)
+
+        if self.use_l2_foveal_loss:
+            output_stats.append(image_pyramid[-1]["l"] * periphery_mask)
+        elif self.use_fullres_l0:
+            output_stats.append(self.blurs[0].blur(
+                image, alpha, real_image_width, real_viewing_distance, gaze, mode))
+        else:
+            output_stats.append(image_pyramid[-1]["l"])
+        return output_stats
+
+    def metameric_loss_stats(self, statsmap_a, statsmap_b, gaze):
+        loss = 0.0
+        for a, b in zip(statsmap_a, statsmap_b):
+            if self.use_radial_weight:
+                radii = make_radial_map(
+                    [a.size(-2), a.size(-1)], gaze).to(a.device)
+                weights = 1.1 - (radii * radii * radii * radii)
+                weights = weights[None, None, ...].repeat(1, a.size(1), 1, 1)
+                loss += torch.nn.MSELoss()(weights*a, weights*b)
+            else:
+                loss += torch.nn.MSELoss()(a, b)
+        loss /= len(statsmap_a)
+        return loss
+
+    def visualise_loss_map(self, image_stats):
+        loss_map = torch.zeros(image_stats[0].size()[-2:])
+        for i in range(len(image_stats)):
+            stats = image_stats[i]
+            target_stats = self.target_stats[i]
+            stat_mse_map = torch.sqrt(torch.pow(stats - target_stats, 2))
+            stat_mse_map = torch.nn.functional.interpolate(stat_mse_map, size=loss_map.size(
+            ), mode="bilinear", align_corners=False, recompute_scale_factor=False)
+            loss_map += stat_mse_map[0, 0, ...]
+        self.loss_map = loss_map
+
+    def __call__(self, image, target, gaze=[0.5, 0.5], image_colorspace="RGB", visualise_loss=False):
+        """ 
+        Calculates the Metameric Loss.
+
+        Parameters
+        ----------
+        image               : torch.tensor
+                                Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
+        target              : torch.tensor
+                                Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
+        image_colorspace    : str
+                                The current colorspace of your image and target. Ignored if input does not have 3 channels.
+                                accepted values: RGB, YCrCb.
+        gaze                : list
+                                Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.
+        visualise_loss      : bool
+                                Shows a heatmap indicating which parts of the image contributed most to the loss. 
+
+        Returns
+        -------
+
+        loss                : torch.tensor
+                                The computed loss.
+        """
+        check_loss_inputs("MetamericLoss", image, target)
+        # Pad image and target if necessary
+        image = pad_image_for_pyramid(image, self.n_pyramid_levels)
+        target = pad_image_for_pyramid(target, self.n_pyramid_levels)
+        # If input is RGB, convert to YCrCb.
+        if image.size(1) == 3 and image_colorspace == "RGB":
+            image = rgb_2_ycrcb(image)
+            target = rgb_2_ycrcb(target)
+        if self.target is None:
+            self.target = torch.zeros(target.shape).to(target.device)
+        if type(target) == type(self.target):
+            if not torch.all(torch.eq(target, self.target)):
+                self.target = target.detach().clone()
+                self.target_stats = self.calc_statsmaps(
+                    self.target,
+                    gaze=gaze,
+                    alpha=self.alpha,
+                    real_image_width=self.real_image_width,
+                    real_viewing_distance=self.real_viewing_distance,
+                    mode=self.mode
+                )
+                self.target = target.detach().clone()
+            image_stats = self.calc_statsmaps(
+                image,
+                gaze=gaze,
+                alpha=self.alpha,
+                real_image_width=self.real_image_width,
+                real_viewing_distance=self.real_viewing_distance,
+                mode=self.mode
+            )
+            if visualise_loss:
+                self.visualise_loss_map(image_stats)
+            if self.use_l2_foveal_loss:
+                peripheral_loss = self.metameric_loss_stats(
+                    image_stats, self.target_stats, gaze)
+                foveal_loss = torch.nn.MSELoss()(self.fovea_mask*image, self.fovea_mask*target)
+                # New weighting - evenly weight fovea and periphery.
+                loss = peripheral_loss + self.fovea_weight * foveal_loss
+            else:
+                loss = self.metameric_loss_stats(
+                    image_stats, self.target_stats, gaze)
+            return loss
+        else:
+            raise Exception("Target of incorrect type")
+
+    def to(self, device):
+        self.device = device
+        return self
+
+
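As a usage illustration, a minimal, hypothetical optimisation loop with MetamericLoss might look as follows (it assumes `MetamericLoss` is importable from `odak.learn.perception`; sizes, learning rate and iteration count are illustrative):

import torch
from odak.learn.perception import MetamericLoss

loss_fn = MetamericLoss(alpha = 0.2, real_image_width = 0.2, real_viewing_distance = 0.7)
target = torch.rand(1, 3, 256, 256)
estimate = torch.rand(1, 3, 256, 256, requires_grad = True)
optimizer = torch.optim.Adam([estimate], lr = 1e-2)
for step in range(200):
    optimizer.zero_grad()
    loss = loss_fn(estimate, target, gaze = [0.5, 0.5])  # Gaze at the image centre.
    loss.backward()
    optimizer.step()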
__call__(image, target, gaze=[0.5, 0.5], image_colorspace='RGB', visualise_loss=False)

Calculates the Metameric Loss.

Parameters:

  • image – Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions).
  • target – Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions).
  • image_colorspace – The current colorspace of your image and target. Ignored if input does not have 3 channels. Accepted values: RGB, YCrCb.
  • gaze – Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.
  • visualise_loss – Shows a heatmap indicating which parts of the image contributed most to the loss.

Returns:

  • loss (tensor) – The computed loss.
+ Source code in odak/learn/perception/metameric_loss.py +
def __call__(self, image, target, gaze=[0.5, 0.5], image_colorspace="RGB", visualise_loss=False):
+    """ 
+    Calculates the Metameric Loss.
+
+    Parameters
+    ----------
+    image               : torch.tensor
+                            Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
+    target              : torch.tensor
+                            Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
+    image_colorspace    : str
+                            The current colorspace of your image and target. Ignored if input does not have 3 channels.
+                            accepted values: RGB, YCrCb.
+    gaze                : list
+                            Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.
+    visualise_loss      : bool
+                            Shows a heatmap indicating which parts of the image contributed most to the loss. 
+
+    Returns
+    -------
+
+    loss                : torch.tensor
+                            The computed loss.
+    """
+    check_loss_inputs("MetamericLoss", image, target)
+    # Pad image and target if necessary
+    image = pad_image_for_pyramid(image, self.n_pyramid_levels)
+    target = pad_image_for_pyramid(target, self.n_pyramid_levels)
+    # If input is RGB, convert to YCrCb.
+    if image.size(1) == 3 and image_colorspace == "RGB":
+        image = rgb_2_ycrcb(image)
+        target = rgb_2_ycrcb(target)
+    if self.target is None:
+        self.target = torch.zeros(target.shape).to(target.device)
+    if type(target) == type(self.target):
+        if not torch.all(torch.eq(target, self.target)):
+            self.target = target.detach().clone()
+            self.target_stats = self.calc_statsmaps(
+                self.target,
+                gaze=gaze,
+                alpha=self.alpha,
+                real_image_width=self.real_image_width,
+                real_viewing_distance=self.real_viewing_distance,
+                mode=self.mode
+            )
+            self.target = target.detach().clone()
+        image_stats = self.calc_statsmaps(
+            image,
+            gaze=gaze,
+            alpha=self.alpha,
+            real_image_width=self.real_image_width,
+            real_viewing_distance=self.real_viewing_distance,
+            mode=self.mode
+        )
+        if visualise_loss:
+            self.visualise_loss_map(image_stats)
+        if self.use_l2_foveal_loss:
+            peripheral_loss = self.metameric_loss_stats(
+                image_stats, self.target_stats, gaze)
+            foveal_loss = torch.nn.MSELoss()(self.fovea_mask*image, self.fovea_mask*target)
+            # New weighting - evenly weight fovea and periphery.
+            loss = peripheral_loss + self.fovea_weight * foveal_loss
+        else:
+            loss = self.metameric_loss_stats(
+                image_stats, self.target_stats, gaze)
+        return loss
+    else:
+        raise Exception("Target of incorrect type")
+
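The visualise_loss flag above does not display anything itself; it stores a per-pixel heatmap in loss_fn.loss_map (see visualise_loss_map() in the class source). A small sketch of how one might inspect it - matplotlib is an assumption here, not an odak dependency:

import torch
import matplotlib.pyplot as plt
from odak.learn.perception import MetamericLoss

loss_fn = MetamericLoss()
target = torch.rand(1, 3, 256, 256)
estimate = torch.rand(1, 3, 256, 256)
loss = loss_fn(estimate, target, gaze = [0.5, 0.5], visualise_loss = True)
plt.imshow(loss_fn.loss_map.detach().cpu().numpy())  # Higher values mark regions contributing more to the loss.
plt.show()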
__init__(device=torch.device('cpu'), alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, n_pyramid_levels=5, mode='quadratic', n_orientations=2, use_l2_foveal_loss=True, fovea_weight=20.0, use_radial_weight=False, use_fullres_l0=False, equi=False)

Parameters:

  • alpha – Parameter controlling foveation - larger values mean bigger pooling regions.
  • real_image_width – The real width of the image as displayed to the user. Units don't matter as long as they are the same as for real_viewing_distance.
  • real_viewing_distance – The real distance of the observer's eyes to the image plane. Units don't matter as long as they are the same as for real_image_width.
  • n_pyramid_levels – Number of levels of the steerable pyramid. Note that the image is padded so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value too high will slow down the calculation a lot.
  • mode – Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow as you move away from the fovea. We got best results with "quadratic".
  • n_orientations – Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6. Increasing this will increase runtime.
  • use_l2_foveal_loss – If true, all the pixels that have a pooling size of 1 pixel at the largest scale will use a direct L2 loss against the target rather than pooling over pyramid levels. In practice this gives better results when the loss is used for holography.
  • fovea_weight – A weight to apply to the foveal region if use_l2_foveal_loss is set to True.
  • use_radial_weight – If True, will apply a radial weighting when calculating the difference between the source and target stats maps. This weights stats closer to the fovea more than those further away.
  • use_fullres_l0 – If true, stats for the lowpass residual are replaced with blurred versions of the full-resolution source and target images.
  • equi – If true, run the loss in equirectangular mode. The input is assumed to be a 360 image in equirectangular format. The settings real_image_width and real_viewing_distance are ignored. The gaze argument is instead interpreted as gaze angles, and should be in the range [-pi,pi]x[-pi/2,pi/2].
+ Source code in odak/learn/perception/metameric_loss.py +
def __init__(self, device=torch.device('cpu'), alpha=0.2, real_image_width=0.2,
+             real_viewing_distance=0.7, n_pyramid_levels=5, mode="quadratic",
+             n_orientations=2, use_l2_foveal_loss=True, fovea_weight=20.0, use_radial_weight=False,
+             use_fullres_l0=False, equi=False):
+    """
+    Parameters
+    ----------
+
+    alpha                   : float
+                                parameter controlling foveation - larger values mean bigger pooling regions.
+    real_image_width        : float 
+                                The real width of the image as displayed to the user.
+                                Units don't matter as long as they are the same as for real_viewing_distance.
+    real_viewing_distance   : float 
+                                The real distance of the observer's eyes to the image plane.
+                                Units don't matter as long as they are the same as for real_image_width.
+    n_pyramid_levels        : int 
+                                Number of levels of the steerable pyramid. Note that the image is padded
+                                so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value
+                                too high will slow down the calculation a lot.
+    mode                    : str 
+                                Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
+                                as you move away from the fovea. We got best results with "quadratic".
+    n_orientations          : int 
+                                Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.
+                                Increasing this will increase runtime.
+    use_l2_foveal_loss      : bool 
+                                If true, all the pixels that have a pooling size of 1 pixel at the
+                                largest scale will use a direct L2 loss against the target rather than pooling over pyramid levels.
+                                In practice this gives better results when the loss is used for holography.
+    fovea_weight            : float 
+                                A weight to apply to the foveal region if use_l2_foveal_loss is set to True.
+    use_radial_weight       : bool 
+                                If True, will apply a radial weighting when calculating the difference between
+                                the source and target stats maps. This weights stats closer to the fovea more than those
+                                further away.
+    use_fullres_l0          : bool 
+                                If true, stats for the lowpass residual are replaced with blurred versions
+                                of the full-resolution source and target images.
+    equi                    : bool
+                                If true, run the loss in equirectangular mode. The input is assumed to be a 360 image in
+                                equirectangular format. The settings real_image_width and real_viewing_distance are ignored.
+                                The gaze argument is instead interpreted as gaze angles, and should be in the range
+                                [-pi,pi]x[-pi/2,pi/2].
+    """
+    self.target = None
+    self.device = device
+    self.pyramid_maker = None
+    self.alpha = alpha
+    self.real_image_width = real_image_width
+    self.real_viewing_distance = real_viewing_distance
+    self.blurs = None
+    self.n_pyramid_levels = n_pyramid_levels
+    self.n_orientations = n_orientations
+    self.mode = mode
+    self.use_l2_foveal_loss = use_l2_foveal_loss
+    self.fovea_weight = fovea_weight
+    self.use_radial_weight = use_radial_weight
+    self.use_fullres_l0 = use_fullres_l0
+    self.equi = equi
+    if self.use_fullres_l0 and self.use_l2_foveal_loss:
+        raise Exception(
+            "Can't use use_fullres_l0 and use_l2_foveal_loss options together in MetamericLoss!")
+
MetamericLossUniform

Measures metameric loss between a given image and a metamer of the given target image.
This variant of the metameric loss is not foveated - it applies uniform pooling sizes to the whole input image.
+ Source code in odak/learn/perception/metameric_loss_uniform.py +
class MetamericLossUniform():
+    """
+    Measures metameric loss between a given image and a metamer of the given target image.
+    This variant of the metameric loss is not foveated - it applies uniform pooling sizes to the whole input image.
+    """
+
+    def __init__(self, device=torch.device('cpu'), pooling_size=32, n_pyramid_levels=5, n_orientations=2):
+        """
+
+        Parameters
+        ----------
+        pooling_size            : int
+                                  Pooling size, in pixels. For example 32 will pool over 32x32 blocks of the image.
+        n_pyramid_levels        : int 
+                                  Number of levels of the steerable pyramid. Note that the image is padded
+                                  so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value
+                                  too high will slow down the calculation a lot.
+        n_orientations          : int 
+                                  Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.
+                                  Increasing this will increase runtime.
+
+        """
+        self.target = None
+        self.device = device
+        self.pyramid_maker = None
+        self.pooling_size = pooling_size
+        self.n_pyramid_levels = n_pyramid_levels
+        self.n_orientations = n_orientations
+
+    def calc_statsmaps(self, image, pooling_size):
+
+        if self.pyramid_maker is None or \
+                self.pyramid_maker.device != self.device or \
+                len(self.pyramid_maker.band_filters) != self.n_orientations or\
+                self.pyramid_maker.filt_h0.size(0) != image.size(1):
+            self.pyramid_maker = SpatialSteerablePyramid(
+                use_bilinear_downup=False, n_channels=image.size(1),
+                device=self.device, n_orientations=self.n_orientations, filter_type="cropped", filter_size=5)
+
+
+        def find_stats(image_pyr_level, pooling_size):
+            image_means = uniform_blur(image_pyr_level, pooling_size)
+            image_meansq = uniform_blur(image_pyr_level*image_pyr_level, pooling_size)
+            image_vars = image_meansq - (image_means*image_means)
+            image_vars[image_vars < 1e-7] = 1e-7
+            image_std = torch.sqrt(image_vars)
+            if torch.any(torch.isnan(image_means)):
+                print(image_means)
+                raise Exception("NaN in image means!")
+            if torch.any(torch.isnan(image_std)):
+                print(image_std)
+                raise Exception("NaN in image stdevs!")
+            return image_means, image_std
+
+        output_stats = []
+        image_pyramid = self.pyramid_maker.construct_pyramid(
+            image, self.n_pyramid_levels)
+        curr_pooling_size = pooling_size
+        means, variances = find_stats(image_pyramid[0]['h'], curr_pooling_size)
+        output_stats.append(means)
+        output_stats.append(variances)
+
+        for l in range(0, len(image_pyramid)-1):
+            for o in range(len(image_pyramid[l]['b'])):
+                means, variances = find_stats(
+                    image_pyramid[l]['b'][o], curr_pooling_size)
+                output_stats.append(means)
+                output_stats.append(variances)
+            curr_pooling_size /= 2
+
+        output_stats.append(image_pyramid[-1]["l"])
+        return output_stats
+
+    def metameric_loss_stats(self, statsmap_a, statsmap_b):
+        loss = 0.0
+        for a, b in zip(statsmap_a, statsmap_b):
+            loss += torch.nn.MSELoss()(a, b)
+        loss /= len(statsmap_a)
+        return loss
+
+    def visualise_loss_map(self, image_stats):
+        loss_map = torch.zeros(image_stats[0].size()[-2:])
+        for i in range(len(image_stats)):
+            stats = image_stats[i]
+            target_stats = self.target_stats[i]
+            stat_mse_map = torch.sqrt(torch.pow(stats - target_stats, 2))
+            stat_mse_map = torch.nn.functional.interpolate(stat_mse_map, size=loss_map.size(
+            ), mode="bilinear", align_corners=False, recompute_scale_factor=False)
+            loss_map += stat_mse_map[0, 0, ...]
+        self.loss_map = loss_map
+
+    def __call__(self, image, target, image_colorspace="RGB", visualise_loss=False):
+        """ 
+        Calculates the Metameric Loss.
+
+        Parameters
+        ----------
+        image               : torch.tensor
+                                Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
+        target              : torch.tensor
+                                Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
+        image_colorspace    : str
+                                The current colorspace of your image and target. Ignored if input does not have 3 channels.
+                                accepted values: RGB, YCrCb.
+        visualise_loss      : bool
+                                Shows a heatmap indicating which parts of the image contributed most to the loss. 
+
+        Returns
+        -------
+
+        loss                : torch.tensor
+                                The computed loss.
+        """
+        check_loss_inputs("MetamericLossUniform", image, target)
+        # Pad image and target if necessary
+        image = pad_image_for_pyramid(image, self.n_pyramid_levels)
+        target = pad_image_for_pyramid(target, self.n_pyramid_levels)
+        # If input is RGB, convert to YCrCb.
+        if image.size(1) == 3 and image_colorspace == "RGB":
+            image = rgb_2_ycrcb(image)
+            target = rgb_2_ycrcb(target)
+        if self.target is None:
+            self.target = torch.zeros(target.shape).to(target.device)
+        if type(target) == type(self.target):
+            if not torch.all(torch.eq(target, self.target)):
+                self.target = target.detach().clone()
+                self.target_stats = self.calc_statsmaps(self.target, self.pooling_size)
+                self.target = target.detach().clone()
+            image_stats = self.calc_statsmaps(image, self.pooling_size)
+
+            if visualise_loss:
+                self.visualise_loss_map(image_stats)
+            loss = self.metameric_loss_stats(
+                image_stats, self.target_stats)
+            return loss
+        else:
+            raise Exception("Target of incorrect type")
+
+    def gen_metamer(self, image):
+        """ 
+        Generates a metamer for an image, following the method in [this paper](https://dl.acm.org/doi/abs/10.1145/3450626.3459943)
+        This function can be used on its own to generate a metamer for a desired image.
+
+        Parameters
+        ----------
+        image   : torch.tensor
+                  Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions)
+
+        Returns
+        -------
+        metamer : torch.tensor
+                  The generated metamer image
+        """
+        image = rgb_2_ycrcb(image)
+        image_size = image.size()
+        image = pad_image_for_pyramid(image, self.n_pyramid_levels)
+
+        target_stats = self.calc_statsmaps(
+            image, self.pooling_size)
+        target_means = target_stats[::2]
+        target_stdevs = target_stats[1::2]
+        torch.manual_seed(0)
+        noise_image = torch.rand_like(image)
+        noise_pyramid = self.pyramid_maker.construct_pyramid(
+            noise_image, self.n_pyramid_levels)
+        input_pyramid = self.pyramid_maker.construct_pyramid(
+            image, self.n_pyramid_levels)
+
+        def match_level(input_level, target_mean, target_std):
+            level = input_level.clone()
+            level -= torch.mean(level)
+            input_std = torch.sqrt(torch.mean(level * level))
+            eps = 1e-6
+            # Safeguard against divide by zero
+            input_std[input_std < eps] = eps
+            level /= input_std
+            level *= target_std
+            level += target_mean
+            return level
+
+        nbands = len(noise_pyramid[0]["b"])
+        noise_pyramid[0]["h"] = match_level(
+            noise_pyramid[0]["h"], target_means[0], target_stdevs[0])
+        for l in range(len(noise_pyramid)-1):
+            for b in range(nbands):
+                noise_pyramid[l]["b"][b] = match_level(
+                    noise_pyramid[l]["b"][b], target_means[1 + l * nbands + b], target_stdevs[1 + l * nbands + b])
+        noise_pyramid[-1]["l"] = input_pyramid[-1]["l"]
+
+        metamer = self.pyramid_maker.reconstruct_from_pyramid(
+            noise_pyramid)
+        metamer = ycrcb_2_rgb(metamer)
+        # Crop to remove any padding
+        metamer = metamer[:image_size[0], :image_size[1], :image_size[2], :image_size[3]]
+        return metamer
+
+    def to(self, device):
+        self.device = device
+        return self
+
+
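A minimal usage sketch for the uniform variant (assuming `MetamericLossUniform` is importable from `odak.learn.perception`; the pooling size and tensor sizes are illustrative):

import torch
from odak.learn.perception import MetamericLossUniform

loss_fn = MetamericLossUniform(pooling_size = 32)
target = torch.rand(1, 3, 256, 256)
estimate = torch.rand(1, 3, 256, 256, requires_grad = True)
loss = loss_fn(estimate, target)   # No gaze argument - pooling is uniform over the image.
loss.backward()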
__call__(image, target, image_colorspace='RGB', visualise_loss=False)

Calculates the Metameric Loss.

Parameters:

  • image – Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions).
  • target – Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions).
  • image_colorspace – The current colorspace of your image and target. Ignored if input does not have 3 channels. Accepted values: RGB, YCrCb.
  • visualise_loss – Shows a heatmap indicating which parts of the image contributed most to the loss.

Returns:

  • loss (tensor) – The computed loss.
+ Source code in odak/learn/perception/metameric_loss_uniform.py +
def __call__(self, image, target, image_colorspace="RGB", visualise_loss=False):
+    """ 
+    Calculates the Metameric Loss.
+
+    Parameters
+    ----------
+    image               : torch.tensor
+                            Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
+    target              : torch.tensor
+                            Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
+    image_colorspace    : str
+                            The current colorspace of your image and target. Ignored if input does not have 3 channels.
+                            accepted values: RGB, YCrCb.
+    visualise_loss      : bool
+                            Shows a heatmap indicating which parts of the image contributed most to the loss. 
+
+    Returns
+    -------
+
+    loss                : torch.tensor
+                            The computed loss.
+    """
+    check_loss_inputs("MetamericLossUniform", image, target)
+    # Pad image and target if necessary
+    image = pad_image_for_pyramid(image, self.n_pyramid_levels)
+    target = pad_image_for_pyramid(target, self.n_pyramid_levels)
+    # If input is RGB, convert to YCrCb.
+    if image.size(1) == 3 and image_colorspace == "RGB":
+        image = rgb_2_ycrcb(image)
+        target = rgb_2_ycrcb(target)
+    if self.target is None:
+        self.target = torch.zeros(target.shape).to(target.device)
+    if type(target) == type(self.target):
+        if not torch.all(torch.eq(target, self.target)):
+            self.target = target.detach().clone()
+            self.target_stats = self.calc_statsmaps(self.target, self.pooling_size)
+            self.target = target.detach().clone()
+        image_stats = self.calc_statsmaps(image, self.pooling_size)
+
+        if visualise_loss:
+            self.visualise_loss_map(image_stats)
+        loss = self.metameric_loss_stats(
+            image_stats, self.target_stats)
+        return loss
+    else:
+        raise Exception("Target of incorrect type")
+
__init__(device=torch.device('cpu'), pooling_size=32, n_pyramid_levels=5, n_orientations=2)

Parameters:

  • pooling_size – Pooling size, in pixels. For example 32 will pool over 32x32 blocks of the image.
  • n_pyramid_levels – Number of levels of the steerable pyramid. Note that the image is padded so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value too high will slow down the calculation a lot.
  • n_orientations – Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6. Increasing this will increase runtime.
+ Source code in odak/learn/perception/metameric_loss_uniform.py +
def __init__(self, device=torch.device('cpu'), pooling_size=32, n_pyramid_levels=5, n_orientations=2):
+    """
+
+    Parameters
+    ----------
+    pooling_size            : int
+                              Pooling size, in pixels. For example 32 will pool over 32x32 blocks of the image.
+    n_pyramid_levels        : int 
+                              Number of levels of the steerable pyramid. Note that the image is padded
+                              so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value
+                              too high will slow down the calculation a lot.
+    n_orientations          : int 
+                              Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.
+                              Increasing this will increase runtime.
+
+    """
+    self.target = None
+    self.device = device
+    self.pyramid_maker = None
+    self.pooling_size = pooling_size
+    self.n_pyramid_levels = n_pyramid_levels
+    self.n_orientations = n_orientations
+
gen_metamer(image)

Generates a metamer for an image, following the method in [this paper](https://dl.acm.org/doi/abs/10.1145/3450626.3459943).
This function can be used on its own to generate a metamer for a desired image.

Parameters:

  • image – Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions).

Returns:

  • metamer (tensor) – The generated metamer image.
+ Source code in odak/learn/perception/metameric_loss_uniform.py +
def gen_metamer(self, image):
+    """ 
+    Generates a metamer for an image, following the method in [this paper](https://dl.acm.org/doi/abs/10.1145/3450626.3459943)
+    This function can be used on its own to generate a metamer for a desired image.
+
+    Parameters
+    ----------
+    image   : torch.tensor
+              Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions)
+
+    Returns
+    -------
+    metamer : torch.tensor
+              The generated metamer image
+    """
+    image = rgb_2_ycrcb(image)
+    image_size = image.size()
+    image = pad_image_for_pyramid(image, self.n_pyramid_levels)
+
+    target_stats = self.calc_statsmaps(
+        image, self.pooling_size)
+    target_means = target_stats[::2]
+    target_stdevs = target_stats[1::2]
+    torch.manual_seed(0)
+    noise_image = torch.rand_like(image)
+    noise_pyramid = self.pyramid_maker.construct_pyramid(
+        noise_image, self.n_pyramid_levels)
+    input_pyramid = self.pyramid_maker.construct_pyramid(
+        image, self.n_pyramid_levels)
+
+    def match_level(input_level, target_mean, target_std):
+        level = input_level.clone()
+        level -= torch.mean(level)
+        input_std = torch.sqrt(torch.mean(level * level))
+        eps = 1e-6
+        # Safeguard against divide by zero
+        input_std[input_std < eps] = eps
+        level /= input_std
+        level *= target_std
+        level += target_mean
+        return level
+
+    nbands = len(noise_pyramid[0]["b"])
+    noise_pyramid[0]["h"] = match_level(
+        noise_pyramid[0]["h"], target_means[0], target_stdevs[0])
+    for l in range(len(noise_pyramid)-1):
+        for b in range(nbands):
+            noise_pyramid[l]["b"][b] = match_level(
+                noise_pyramid[l]["b"][b], target_means[1 + l * nbands + b], target_stdevs[1 + l * nbands + b])
+    noise_pyramid[-1]["l"] = input_pyramid[-1]["l"]
+
+    metamer = self.pyramid_maker.reconstruct_from_pyramid(
+        noise_pyramid)
+    metamer = ycrcb_2_rgb(metamer)
+    # Crop to remove any padding
+    metamer = metamer[:image_size[0], :image_size[1], :image_size[2], :image_size[3]]
+    return metamer
+
+
+
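A minimal usage sketch for gen_metamer, assuming the enclosing class is MetamericLossUniform exported from odak.learn.perception (adjust the import if your version places it elsewhere):

import torch
from odak.learn.perception import MetamericLossUniform

loss_function = MetamericLossUniform(pooling_size = 32, n_pyramid_levels = 5, n_orientations = 2)
image = torch.rand(1, 3, 256, 256)           # RGB image in NCHW format, values in [0, 1]
metamer = loss_function.gen_metamer(image)   # generated metamer, same shape as the input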
PSNR

Bases: Module

A class to calculate peak-signal-to-noise ratio of an image with respect to a ground truth image.

Source code in odak/learn/perception/image_quality_losses.py
class PSNR(nn.Module):
+    '''
+    A class to calculate peak-signal-to-noise ratio of an image with respect to a ground truth image.
+    '''
+
+    def __init__(self):
+        super(PSNR, self).__init__()
+
+    def forward(self, predictions, targets, peak_value = 1.0):
+        """
+        A function to calculate peak-signal-to-noise ratio of an image with respect to a ground truth image.
+
+        Parameters
+        ----------
+        predictions   : torch.tensor
+                        Image to be tested.
+        targets       : torch.tensor
+                        Ground truth image.
+        peak_value    : float
+                        Peak value that given tensors could have.
+
+        Returns
+        -------
+        result        : torch.tensor
+                        Peak-signal-to-noise ratio.
+        """
+        mse = torch.mean((targets - predictions) ** 2)
+        result = 20 * torch.log10(peak_value / torch.sqrt(mse))
+        return result
+
+
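A short usage sketch for this loss; the import path below follows the source file noted above (some odak versions also re-export PSNR from odak.learn.perception):

import torch
from odak.learn.perception.image_quality_losses import PSNR

psnr = PSNR()
prediction = torch.rand(1, 3, 64, 64)
target = torch.rand(1, 3, 64, 64)
value_db = psnr(prediction, target, peak_value = 1.0)   # 20 * log10(peak_value / sqrt(MSE)), in dB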
forward(predictions, targets, peak_value=1.0)

A function to calculate peak-signal-to-noise ratio of an image with respect to a ground truth image.

Parameters:

  • predictions – Image to be tested.
  • targets – Ground truth image.
  • peak_value – Peak value that given tensors could have.

Returns:

  • result (tensor) – Peak-signal-to-noise ratio.

Source code in odak/learn/perception/image_quality_losses.py
def forward(self, predictions, targets, peak_value = 1.0):
+    """
+    A function to calculate peak-signal-to-noise ratio of an image with respect to a ground truth image.
+
+    Parameters
+    ----------
+    predictions   : torch.tensor
+                    Image to be tested.
+    targets       : torch.tensor
+                    Ground truth image.
+    peak_value    : float
+                    Peak value that given tensors could have.
+
+    Returns
+    -------
+    result        : torch.tensor
+                    Peak-signal-to-noise ratio.
+    """
+    mse = torch.mean((targets - predictions) ** 2)
+    result = 20 * torch.log10(peak_value / torch.sqrt(mse))
+    return result
+
+
+
RadiallyVaryingBlur

The RadiallyVaryingBlur class provides a way to apply a radially varying blur to an image. Given a gaze location and information about the image and foveation, it applies a blur that will achieve the proper pooling size. The pooling size is chosen to appear the same at a range of display sizes and viewing distances, for a given alpha parameter value. For more information on how the pooling sizes are computed, please see link coming soon.

The blur is accelerated by generating and sampling from MIP maps of the input image.

This class caches the foveation information. This means that if it is run repeatedly with the same foveation parameters, gaze location and image size (e.g. in an optimisation loop) it won't recalculate the pooling maps.

If you are repeatedly applying blur to images of different sizes (e.g. a pyramid) for best performance use one instance of this class per image size.

Source code in odak/learn/perception/radially_varying_blur.py
class RadiallyVaryingBlur():
+    """ 
+
+    The `RadiallyVaryingBlur` class provides a way to apply a radially varying blur to an image. Given a gaze location and information about the image and foveation, it applies a blur that will achieve the proper pooling size. The pooling size is chosen to appear the same at a range of display sizes and viewing distances, for a given `alpha` parameter value. For more information on how the pooling sizes are computed, please see [link coming soon]().
+
+    The blur is accelerated by generating and sampling from MIP maps of the input image.
+
+    This class caches the foveation information. This means that if it is run repeatedly with the same foveation parameters, gaze location and image size (e.g. in an optimisation loop) it won't recalculate the pooling maps.
+
+    If you are repeatedly applying blur to images of different sizes (e.g. a pyramid) for best performance use one instance of this class per image size.
+
+    """
+
+    def __init__(self):
+        self.lod_map = None
+        self.equi = None
+
+    def blur(self, image, alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, centre=None, mode="quadratic", equi=False):
+        """
+        Apply the radially varying blur to an image.
+
+        Parameters
+        ----------
+
+        image                   : torch.tensor
+                                    The image to blur, in NCHW format.
+        alpha                   : float
+                                    parameter controlling foveation - larger values mean bigger pooling regions.
+        real_image_width        : float 
+                                    The real width of the image as displayed to the user.
+                                    Units don't matter as long as they are the same as for real_viewing_distance.
+                                    Ignored in equirectangular mode (equi==True)
+        real_viewing_distance   : float 
+                                    The real distance of the observer's eyes to the image plane.
+                                    Units don't matter as long as they are the same as for real_image_width.
+                                    Ignored in equirectangular mode (equi==True)
+        centre                  : tuple of floats
+                                    The centre of the radially varying blur (the gaze location).
+                                    Should be a tuple of floats containing normalised image coordinates in range [0,1]
+                                    In equirectangular mode this should be yaw & pitch angles in [-pi,pi]x[-pi/2,pi/2]
+        mode                    : str 
+                                    Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
+                                    as you move away from the fovea. We got best results with "quadratic".
+        equi                    : bool
+                                    If true, run the blur function in equirectangular mode. The input is assumed to be an equirectangular
+                                    format 360 image. The settings real_image_width and real_viewing_distance are ignored.
+                                    The centre argument is instead interpreted as gaze angles, and should be in the range
+                                    [-pi,pi]x[-pi/2,pi/2]
+
+        Returns
+        -------
+
+        output                  : torch.tensor
+                                    The blurred image
+        """
+        size = (image.size(-2), image.size(-1))
+
+        # LOD map caching
+        if self.lod_map is None or\
+                self.size != size or\
+                self.n_channels != image.size(1) or\
+                self.alpha != alpha or\
+                self.real_image_width != real_image_width or\
+                self.real_viewing_distance != real_viewing_distance or\
+                self.centre != centre or\
+                self.mode != mode or\
+                self.equi != equi:
+            if not equi:
+                self.lod_map = make_pooling_size_map_lod(
+                    centre, (image.size(-2), image.size(-1)), alpha, real_image_width, real_viewing_distance, mode)
+            else:
+                self.lod_map = make_equi_pooling_size_map_lod(
+                    centre, (image.size(-2), image.size(-1)), alpha, mode)
+            self.size = size
+            self.n_channels = image.size(1)
+            self.alpha = alpha
+            self.real_image_width = real_image_width
+            self.real_viewing_distance = real_viewing_distance
+            self.centre = centre
+            self.lod_map = self.lod_map.to(image.device)
+            self.lod_fraction = torch.fmod(self.lod_map, 1.0)
+            self.lod_fraction = self.lod_fraction[None, None, ...].repeat(
+                1, image.size(1), 1, 1)
+            self.mode = mode
+            self.equi = equi
+
+        if self.lod_map.device != image.device:
+            self.lod_map = self.lod_map.to(image.device)
+        if self.lod_fraction.device != image.device:
+            self.lod_fraction = self.lod_fraction.to(image.device)
+
+        mipmap = [image]
+        while mipmap[-1].size(-1) > 1 and mipmap[-1].size(-2) > 1:
+            mipmap.append(torch.nn.functional.interpolate(
+                mipmap[-1], scale_factor=0.5, mode="area", recompute_scale_factor=False))
+        if mipmap[-1].size(-1) == 2:
+            final_mip = torch.mean(mipmap[-1], axis=-1)[..., None]
+            mipmap.append(final_mip)
+        if mipmap[-1].size(-2) == 2:
+            final_mip = torch.mean(mipmap[-2], axis=-2)[..., None, :]
+            mipmap.append(final_mip)
+
+        for l in range(len(mipmap)):
+            if l == len(mipmap)-1:
+                mipmap[l] = mipmap[l] * \
+                    torch.ones(image.size(), device=image.device)
+            else:
+                for l2 in range(l-1, -1, -1):
+                    mipmap[l] = torch.nn.functional.interpolate(mipmap[l], size=(
+                        image.size(-2), image.size(-1)), mode="bilinear", align_corners=False, recompute_scale_factor=False)
+
+        output = torch.zeros(image.size(), device=image.device)
+        for l in range(len(mipmap)):
+            if l == 0:
+                mask = self.lod_map < (l+1)
+            elif l == len(mipmap)-1:
+                mask = self.lod_map >= l
+            else:
+                mask = torch.logical_and(
+                    self.lod_map >= l, self.lod_map < (l+1))
+
+            if l == len(mipmap)-1:
+                blended_levels = mipmap[l]
+            else:
+                blended_levels = (1 - self.lod_fraction) * \
+                    mipmap[l] + self.lod_fraction*mipmap[l+1]
+            mask = mask[None, None, ...]
+            mask = mask.repeat(1, image.size(1), 1, 1)
+            output[mask] = blended_levels[mask]
+
+        return output
+
+
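A brief usage sketch of the blur; the import path is assumed from the source file above, and the gaze centre is given in normalised image coordinates as described in the parameter list:

import torch
from odak.learn.perception.radially_varying_blur import RadiallyVaryingBlur

blur = RadiallyVaryingBlur()
image = torch.rand(1, 3, 512, 512)        # NCHW input
blurred = blur.blur(
                    image,
                    alpha = 0.2,
                    real_image_width = 0.3,       # same units as real_viewing_distance
                    real_viewing_distance = 0.7,
                    centre = (0.5, 0.5),          # gaze at the image centre
                    mode = "quadratic"
                   )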
blur(image, alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, centre=None, mode='quadratic', equi=False)

Apply the radially varying blur to an image.

Parameters:

  • image – The image to blur, in NCHW format.
  • alpha – Parameter controlling foveation - larger values mean bigger pooling regions.
  • real_image_width – The real width of the image as displayed to the user. Units don't matter as long as they are the same as for real_viewing_distance. Ignored in equirectangular mode (equi==True).
  • real_viewing_distance – The real distance of the observer's eyes to the image plane. Units don't matter as long as they are the same as for real_image_width. Ignored in equirectangular mode (equi==True).
  • centre – The centre of the radially varying blur (the gaze location). Should be a tuple of floats containing normalised image coordinates in range [0,1]. In equirectangular mode this should be yaw & pitch angles in [-pi,pi]x[-pi/2,pi/2].
  • mode – Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow as you move away from the fovea. We got best results with "quadratic".
  • equi – If true, run the blur function in equirectangular mode. The input is assumed to be an equirectangular format 360 image. The settings real_image_width and real_viewing_distance are ignored. The centre argument is instead interpreted as gaze angles, and should be in the range [-pi,pi]x[-pi/2,pi/2].

Returns:

  • output (tensor) – The blurred image.

Source code in odak/learn/perception/radially_varying_blur.py
def blur(self, image, alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, centre=None, mode="quadratic", equi=False):
+    """
+    Apply the radially varying blur to an image.
+
+    Parameters
+    ----------
+
+    image                   : torch.tensor
+                                The image to blur, in NCHW format.
+    alpha                   : float
+                                parameter controlling foveation - larger values mean bigger pooling regions.
+    real_image_width        : float 
+                                The real width of the image as displayed to the user.
+                                Units don't matter as long as they are the same as for real_viewing_distance.
+                                Ignored in equirectangular mode (equi==True)
+    real_viewing_distance   : float 
+                                The real distance of the observer's eyes to the image plane.
+                                Units don't matter as long as they are the same as for real_image_width.
+                                Ignored in equirectangular mode (equi==True)
+    centre                  : tuple of floats
+                                The centre of the radially varying blur (the gaze location).
+                                Should be a tuple of floats containing normalised image coordinates in range [0,1]
+                                In equirectangular mode this should be yaw & pitch angles in [-pi,pi]x[-pi/2,pi/2]
+    mode                    : str 
+                                Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
+                                as you move away from the fovea. We got best results with "quadratic".
+    equi                    : bool
+                                If true, run the blur function in equirectangular mode. The input is assumed to be an equirectangular
+                                format 360 image. The settings real_image_width and real_viewing_distance are ignored.
+                                The centre argument is instead interpreted as gaze angles, and should be in the range
+                                [-pi,pi]x[-pi/2,pi/2]
+
+    Returns
+    -------
+
+    output                  : torch.tensor
+                                The blurred image
+    """
+    size = (image.size(-2), image.size(-1))
+
+    # LOD map caching
+    if self.lod_map is None or\
+            self.size != size or\
+            self.n_channels != image.size(1) or\
+            self.alpha != alpha or\
+            self.real_image_width != real_image_width or\
+            self.real_viewing_distance != real_viewing_distance or\
+            self.centre != centre or\
+            self.mode != mode or\
+            self.equi != equi:
+        if not equi:
+            self.lod_map = make_pooling_size_map_lod(
+                centre, (image.size(-2), image.size(-1)), alpha, real_image_width, real_viewing_distance, mode)
+        else:
+            self.lod_map = make_equi_pooling_size_map_lod(
+                centre, (image.size(-2), image.size(-1)), alpha, mode)
+        self.size = size
+        self.n_channels = image.size(1)
+        self.alpha = alpha
+        self.real_image_width = real_image_width
+        self.real_viewing_distance = real_viewing_distance
+        self.centre = centre
+        self.lod_map = self.lod_map.to(image.device)
+        self.lod_fraction = torch.fmod(self.lod_map, 1.0)
+        self.lod_fraction = self.lod_fraction[None, None, ...].repeat(
+            1, image.size(1), 1, 1)
+        self.mode = mode
+        self.equi = equi
+
+    if self.lod_map.device != image.device:
+        self.lod_map = self.lod_map.to(image.device)
+    if self.lod_fraction.device != image.device:
+        self.lod_fraction = self.lod_fraction.to(image.device)
+
+    mipmap = [image]
+    while mipmap[-1].size(-1) > 1 and mipmap[-1].size(-2) > 1:
+        mipmap.append(torch.nn.functional.interpolate(
+            mipmap[-1], scale_factor=0.5, mode="area", recompute_scale_factor=False))
+    if mipmap[-1].size(-1) == 2:
+        final_mip = torch.mean(mipmap[-1], axis=-1)[..., None]
+        mipmap.append(final_mip)
+    if mipmap[-1].size(-2) == 2:
+        final_mip = torch.mean(mipmap[-2], axis=-2)[..., None, :]
+        mipmap.append(final_mip)
+
+    for l in range(len(mipmap)):
+        if l == len(mipmap)-1:
+            mipmap[l] = mipmap[l] * \
+                torch.ones(image.size(), device=image.device)
+        else:
+            for l2 in range(l-1, -1, -1):
+                mipmap[l] = torch.nn.functional.interpolate(mipmap[l], size=(
+                    image.size(-2), image.size(-1)), mode="bilinear", align_corners=False, recompute_scale_factor=False)
+
+    output = torch.zeros(image.size(), device=image.device)
+    for l in range(len(mipmap)):
+        if l == 0:
+            mask = self.lod_map < (l+1)
+        elif l == len(mipmap)-1:
+            mask = self.lod_map >= l
+        else:
+            mask = torch.logical_and(
+                self.lod_map >= l, self.lod_map < (l+1))
+
+        if l == len(mipmap)-1:
+            blended_levels = mipmap[l]
+        else:
+            blended_levels = (1 - self.lod_fraction) * \
+                mipmap[l] + self.lod_fraction*mipmap[l+1]
+        mask = mask[None, None, ...]
+        mask = mask.repeat(1, image.size(1), 1, 1)
+        output[mask] = blended_levels[mask]
+
+    return output
+
+
+
SSIM

Bases: Module

A class to calculate structural similarity index of an image with respect to a ground truth image.

Source code in odak/learn/perception/image_quality_losses.py
class SSIM(nn.Module):
+    '''
+    A class to calculate structural similarity index of an image with respect to a ground truth image.
+    '''
+
+    def __init__(self):
+        super(SSIM, self).__init__()
+
+    def forward(self, predictions, targets):
+        """
+        Parameters
+        ----------
+        predictions : torch.tensor
+                      The predicted images.
+        targets     : torch.tensor
+                      The ground truth images.
+
+        Returns
+        -------
+        result      : torch.tensor 
+                      The computed SSIM value if successful, otherwise 0.0.
+        """
+        try:
+            from torchmetrics.functional.image import structural_similarity_index_measure
+            if len(predictions.shape) == 3:
+                predictions = predictions.unsqueeze(0)
+                targets = targets.unsqueeze(0)
+            l_SSIM = structural_similarity_index_measure(predictions, targets)
+            return l_SSIM
+        except Exception as e:
+            logging.warning('SSIM failed to compute.')
+            logging.warning(e)
+            return torch.tensor(0.0)
+
+
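A brief usage sketch; the wrapper relies on torchmetrics, and if it is not installed the forward pass logs a warning and returns 0.0:

import torch
from odak.learn.perception.image_quality_losses import SSIM

ssim = SSIM()
prediction = torch.rand(1, 3, 64, 64)
target = torch.rand(1, 3, 64, 64)
score = ssim(prediction, target)   # close to 1.0 for near-identical images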
forward(predictions, targets)

Parameters:

  • predictions (tensor) – The predicted images.
  • targets – The ground truth images.

Returns:

  • result (tensor) – The computed SSIM value if successful, otherwise 0.0.

Source code in odak/learn/perception/image_quality_losses.py
def forward(self, predictions, targets):
+    """
+    Parameters
+    ----------
+    predictions : torch.tensor
+                  The predicted images.
+    targets     : torch.tensor
+                  The ground truth images.
+
+    Returns
+    -------
+    result      : torch.tensor 
+                  The computed SSIM value if successful, otherwise 0.0.
+    """
+    try:
+        from torchmetrics.functional.image import structural_similarity_index_measure
+        if len(predictions.shape) == 3:
+            predictions = predictions.unsqueeze(0)
+            targets = targets.unsqueeze(0)
+        l_SSIM = structural_similarity_index_measure(predictions, targets)
+        return l_SSIM
+    except Exception as e:
+        logging.warning('SSIM failed to compute.')
+        logging.warning(e)
+        return torch.tensor(0.0)
+
+
+
SpatialSteerablePyramid

This implements a real-valued steerable pyramid where the filtering is carried out spatially (using convolution) as opposed to multiplication in the Fourier domain. This has a number of optimisations over previous implementations that increase efficiency, but introduce some reconstruction error.

Source code in odak/learn/perception/spatial_steerable_pyramid.py
class SpatialSteerablePyramid():
+    """
+    This implements a real-valued steerable pyramid where the filtering is carried out spatially (using convolution)
+    as opposed to multiplication in the Fourier domain.
+    This has a number of optimisations over previous implementations that increase efficiency, but introduce some
+    reconstruction error.
+    """
+
+
+    def __init__(self, use_bilinear_downup=True, n_channels=1,
+                 filter_size=9, n_orientations=6, filter_type="full",
+                 device=torch.device('cpu')):
+        """
+        Parameters
+        ----------
+
+        use_bilinear_downup     : bool
+                                    This uses bilinear filtering when upsampling/downsampling, rather than the original approach
+                                    of applying a large lowpass kernel and sampling even rows/columns
+        n_channels              : int
+                                    Number of channels in the input images (e.g. 3 for RGB input)
+        filter_size             : int
+                                    Desired size of filters (e.g. 3 will use 3x3 filters).
+        n_orientations          : int
+                                    Number of oriented bands in each level of the pyramid.
+        filter_type             : str
+                                    This can be used to select smaller filters than the original ones if desired.
+                                    full: Original filter sizes
+                                    cropped: Some filters are cut back in size by extracting the centre and scaling as appropriate.
+                                    trained: Same as reduced, but the oriented kernels are replaced by learned 5x5 kernels.
+        device                  : torch.device
+                                    torch device the input images will be supplied from.
+        """
+        self.use_bilinear_downup = use_bilinear_downup
+        self.device = device
+
+        filters = get_steerable_pyramid_filters(
+            filter_size, n_orientations, filter_type)
+
+        def make_pad(filter):
+            filter_size = filter.size(-1)
+            pad_amt = (filter_size-1) // 2
+            return torch.nn.ReflectionPad2d((pad_amt, pad_amt, pad_amt, pad_amt))
+
+        if not self.use_bilinear_downup:
+            self.filt_l = filters["l"].to(device)
+            self.pad_l = make_pad(self.filt_l)
+        self.filt_l0 = filters["l0"].to(device)
+        self.pad_l0 = make_pad(self.filt_l0)
+        self.filt_h0 = filters["h0"].to(device)
+        self.pad_h0 = make_pad(self.filt_h0)
+        for b in range(len(filters["b"])):
+            filters["b"][b] = filters["b"][b].to(device)
+        self.band_filters = filters["b"]
+        self.pad_b = make_pad(self.band_filters[0])
+
+        if n_channels != 1:
+            def add_channels_to_filter(filter):
+                padded = torch.zeros(n_channels, n_channels, filter.size()[
+                                     2], filter.size()[3]).to(device)
+                for channel in range(n_channels):
+                    padded[channel, channel, :, :] = filter
+                return padded
+            self.filt_h0 = add_channels_to_filter(self.filt_h0)
+            for b in range(len(self.band_filters)):
+                self.band_filters[b] = add_channels_to_filter(
+                    self.band_filters[b])
+            self.filt_l0 = add_channels_to_filter(self.filt_l0)
+            if not self.use_bilinear_downup:
+                self.filt_l = add_channels_to_filter(self.filt_l)
+
+    def construct_pyramid(self, image, n_levels, multiple_highpass=False):
+        """
+        Constructs and returns a steerable pyramid for the provided image.
+
+        Parameters
+        ----------
+
+        image               : torch.tensor
+                                The input image, in NCHW format. The number of channels C should match num_channels
+                                when the pyramid maker was created.
+        n_levels            : int
+                                Number of levels in the constructed steerable pyramid.
+        multiple_highpass   : bool
+                                If true, computes a highpass for each level of the pyramid.
+                                These extra levels are redundant (not used for reconstruction).
+
+        Returns
+        -------
+
+        pyramid             : list of dicts of torch.tensor
+                                The computed steerable pyramid.
+                                Each level is an entry in a list. The pyramid is ordered from largest levels to smallest levels.
+                                Each level is stored as a dict, with the following keys:
+                                "h" Highpass residual
+                                "l" Lowpass residual
+                                "b" Oriented bands (a list of torch.tensor)
+        """
+        pyramid = []
+
+        # Make level 0, containing highpass, lowpass and the bands
+        level0 = {}
+        level0['h'] = torch.nn.functional.conv2d(
+            self.pad_h0(image), self.filt_h0)
+        lowpass = torch.nn.functional.conv2d(self.pad_l0(image), self.filt_l0)
+        level0['l'] = lowpass.clone()
+        bands = []
+        for filt_b in self.band_filters:
+            bands.append(torch.nn.functional.conv2d(
+                self.pad_b(lowpass), filt_b))
+        level0['b'] = bands
+        pyramid.append(level0)
+
+        # Make intermediate levels
+        for l in range(n_levels-2):
+            level = {}
+            if self.use_bilinear_downup:
+                lowpass = torch.nn.functional.interpolate(
+                    lowpass, scale_factor=0.5, mode="area", recompute_scale_factor=False)
+            else:
+                lowpass = torch.nn.functional.conv2d(
+                    self.pad_l(lowpass), self.filt_l)
+                lowpass = lowpass[:, :, ::2, ::2]
+            level['l'] = lowpass.clone()
+            bands = []
+            for filt_b in self.band_filters:
+                bands.append(torch.nn.functional.conv2d(
+                    self.pad_b(lowpass), filt_b))
+            level['b'] = bands
+            if multiple_highpass:
+                level['h'] = torch.nn.functional.conv2d(
+                    self.pad_h0(lowpass), self.filt_h0)
+            pyramid.append(level)
+
+        # Make final level (lowpass residual)
+        level = {}
+        if self.use_bilinear_downup:
+            lowpass = torch.nn.functional.interpolate(
+                lowpass, scale_factor=0.5, mode="area", recompute_scale_factor=False)
+        else:
+            lowpass = torch.nn.functional.conv2d(
+                self.pad_l(lowpass), self.filt_l)
+            lowpass = lowpass[:, :, ::2, ::2]
+        level['l'] = lowpass
+        pyramid.append(level)
+
+        return pyramid
+
+    def reconstruct_from_pyramid(self, pyramid):
+        """
+        Reconstructs an input image from a steerable pyramid.
+
+        Parameters
+        ----------
+
+        pyramid : list of dicts of torch.tensor
+                    The steerable pyramid.
+                    Should be in the same format as output by construct_steerable_pyramid().
+                    The number of channels should match num_channels when the pyramid maker was created.
+
+        Returns
+        -------
+
+        image   : torch.tensor
+                    The reconstructed image, in NCHW format.         
+        """
+        def upsample(image, size):
+            if self.use_bilinear_downup:
+                return torch.nn.functional.interpolate(image, size=size, mode="bilinear", align_corners=False, recompute_scale_factor=False)
+            else:
+                zeros = torch.zeros((image.size()[0], image.size()[1], image.size()[
+                                    2]*2, image.size()[3]*2)).to(self.device)
+                zeros[:, :, ::2, ::2] = image
+                zeros = torch.nn.functional.conv2d(
+                    self.pad_l(zeros), self.filt_l)
+                return zeros
+
+        image = pyramid[-1]['l']
+        for level in reversed(pyramid[:-1]):
+            image = upsample(image, level['b'][0].size()[2:])
+            for b in range(len(level['b'])):
+                b_filtered = torch.nn.functional.conv2d(
+                    self.pad_b(level['b'][b]), -self.band_filters[b])
+                image += b_filtered
+
+        image = torch.nn.functional.conv2d(self.pad_l0(image), self.filt_l0)
+        image += torch.nn.functional.conv2d(
+            self.pad_h0(pyramid[0]['h']), self.filt_h0)
+
+        return image
+
+
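A small sketch of constructing and inverting the pyramid; the import path is assumed from the source file above:

import torch
from odak.learn.perception.spatial_steerable_pyramid import SpatialSteerablePyramid

pyramid_maker = SpatialSteerablePyramid(use_bilinear_downup = True, n_channels = 3)
image = torch.rand(1, 3, 256, 256)                  # H and W chosen to stay even through every downsampling step
pyramid = pyramid_maker.construct_pyramid(image, n_levels = 5)
highpass = pyramid[0]['h']                          # highpass residual of the finest level
oriented_bands = pyramid[0]['b']                    # list of oriented band responses
reconstruction = pyramid_maker.reconstruct_from_pyramid(pyramid)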
__init__(use_bilinear_downup=True, n_channels=1, filter_size=9, n_orientations=6, filter_type='full', device=torch.device('cpu'))

Parameters:

  • use_bilinear_downup – This uses bilinear filtering when upsampling/downsampling, rather than the original approach of applying a large lowpass kernel and sampling even rows/columns.
  • n_channels – Number of channels in the input images (e.g. 3 for RGB input).
  • filter_size – Desired size of filters (e.g. 3 will use 3x3 filters).
  • n_orientations – Number of oriented bands in each level of the pyramid.
  • filter_type – This can be used to select smaller filters than the original ones if desired. full: Original filter sizes. cropped: Some filters are cut back in size by extracting the centre and scaling as appropriate. trained: Same as reduced, but the oriented kernels are replaced by learned 5x5 kernels.
  • device – torch device the input images will be supplied from.

Source code in odak/learn/perception/spatial_steerable_pyramid.py
def __init__(self, use_bilinear_downup=True, n_channels=1,
+             filter_size=9, n_orientations=6, filter_type="full",
+             device=torch.device('cpu')):
+    """
+    Parameters
+    ----------
+
+    use_bilinear_downup     : bool
+                                This uses bilinear filtering when upsampling/downsampling, rather than the original approach
+                                of applying a large lowpass kernel and sampling even rows/columns
+    n_channels              : int
+                                Number of channels in the input images (e.g. 3 for RGB input)
+    filter_size             : int
+                                Desired size of filters (e.g. 3 will use 3x3 filters).
+    n_orientations          : int
+                                Number of oriented bands in each level of the pyramid.
+    filter_type             : str
+                                This can be used to select smaller filters than the original ones if desired.
+                                full: Original filter sizes
+                                cropped: Some filters are cut back in size by extracting the centre and scaling as appropriate.
+                                trained: Same as reduced, but the oriented kernels are replaced by learned 5x5 kernels.
+    device                  : torch.device
+                                torch device the input images will be supplied from.
+    """
+    self.use_bilinear_downup = use_bilinear_downup
+    self.device = device
+
+    filters = get_steerable_pyramid_filters(
+        filter_size, n_orientations, filter_type)
+
+    def make_pad(filter):
+        filter_size = filter.size(-1)
+        pad_amt = (filter_size-1) // 2
+        return torch.nn.ReflectionPad2d((pad_amt, pad_amt, pad_amt, pad_amt))
+
+    if not self.use_bilinear_downup:
+        self.filt_l = filters["l"].to(device)
+        self.pad_l = make_pad(self.filt_l)
+    self.filt_l0 = filters["l0"].to(device)
+    self.pad_l0 = make_pad(self.filt_l0)
+    self.filt_h0 = filters["h0"].to(device)
+    self.pad_h0 = make_pad(self.filt_h0)
+    for b in range(len(filters["b"])):
+        filters["b"][b] = filters["b"][b].to(device)
+    self.band_filters = filters["b"]
+    self.pad_b = make_pad(self.band_filters[0])
+
+    if n_channels != 1:
+        def add_channels_to_filter(filter):
+            padded = torch.zeros(n_channels, n_channels, filter.size()[
+                                 2], filter.size()[3]).to(device)
+            for channel in range(n_channels):
+                padded[channel, channel, :, :] = filter
+            return padded
+        self.filt_h0 = add_channels_to_filter(self.filt_h0)
+        for b in range(len(self.band_filters)):
+            self.band_filters[b] = add_channels_to_filter(
+                self.band_filters[b])
+        self.filt_l0 = add_channels_to_filter(self.filt_l0)
+        if not self.use_bilinear_downup:
+            self.filt_l = add_channels_to_filter(self.filt_l)
+
+
+
construct_pyramid(image, n_levels, multiple_highpass=False)

Constructs and returns a steerable pyramid for the provided image.

Parameters:

  • image – The input image, in NCHW format. The number of channels C should match num_channels when the pyramid maker was created.
  • n_levels – Number of levels in the constructed steerable pyramid.
  • multiple_highpass – If true, computes a highpass for each level of the pyramid. These extra levels are redundant (not used for reconstruction).

Returns:

  • pyramid (list of dicts of torch.tensor) – The computed steerable pyramid. Each level is an entry in a list. The pyramid is ordered from largest levels to smallest levels. Each level is stored as a dict, with the following keys: "h" Highpass residual, "l" Lowpass residual, "b" Oriented bands (a list of torch.tensor).

Source code in odak/learn/perception/spatial_steerable_pyramid.py
def construct_pyramid(self, image, n_levels, multiple_highpass=False):
+    """
+    Constructs and returns a steerable pyramid for the provided image.
+
+    Parameters
+    ----------
+
+    image               : torch.tensor
+                            The input image, in NCHW format. The number of channels C should match num_channels
+                            when the pyramid maker was created.
+    n_levels            : int
+                            Number of levels in the constructed steerable pyramid.
+    multiple_highpass   : bool
+                            If true, computes a highpass for each level of the pyramid.
+                            These extra levels are redundant (not used for reconstruction).
+
+    Returns
+    -------
+
+    pyramid             : list of dicts of torch.tensor
+                            The computed steerable pyramid.
+                            Each level is an entry in a list. The pyramid is ordered from largest levels to smallest levels.
+                            Each level is stored as a dict, with the following keys:
+                            "h" Highpass residual
+                            "l" Lowpass residual
+                            "b" Oriented bands (a list of torch.tensor)
+    """
+    pyramid = []
+
+    # Make level 0, containing highpass, lowpass and the bands
+    level0 = {}
+    level0['h'] = torch.nn.functional.conv2d(
+        self.pad_h0(image), self.filt_h0)
+    lowpass = torch.nn.functional.conv2d(self.pad_l0(image), self.filt_l0)
+    level0['l'] = lowpass.clone()
+    bands = []
+    for filt_b in self.band_filters:
+        bands.append(torch.nn.functional.conv2d(
+            self.pad_b(lowpass), filt_b))
+    level0['b'] = bands
+    pyramid.append(level0)
+
+    # Make intermediate levels
+    for l in range(n_levels-2):
+        level = {}
+        if self.use_bilinear_downup:
+            lowpass = torch.nn.functional.interpolate(
+                lowpass, scale_factor=0.5, mode="area", recompute_scale_factor=False)
+        else:
+            lowpass = torch.nn.functional.conv2d(
+                self.pad_l(lowpass), self.filt_l)
+            lowpass = lowpass[:, :, ::2, ::2]
+        level['l'] = lowpass.clone()
+        bands = []
+        for filt_b in self.band_filters:
+            bands.append(torch.nn.functional.conv2d(
+                self.pad_b(lowpass), filt_b))
+        level['b'] = bands
+        if multiple_highpass:
+            level['h'] = torch.nn.functional.conv2d(
+                self.pad_h0(lowpass), self.filt_h0)
+        pyramid.append(level)
+
+    # Make final level (lowpass residual)
+    level = {}
+    if self.use_bilinear_downup:
+        lowpass = torch.nn.functional.interpolate(
+            lowpass, scale_factor=0.5, mode="area", recompute_scale_factor=False)
+    else:
+        lowpass = torch.nn.functional.conv2d(
+            self.pad_l(lowpass), self.filt_l)
+        lowpass = lowpass[:, :, ::2, ::2]
+    level['l'] = lowpass
+    pyramid.append(level)
+
+    return pyramid
+
+
+
reconstruct_from_pyramid(pyramid)

Reconstructs an input image from a steerable pyramid.

Parameters:

  • pyramid (list of dicts of torch.tensor) – The steerable pyramid. Should be in the same format as output by construct_steerable_pyramid(). The number of channels should match num_channels when the pyramid maker was created.

Returns:

  • image (tensor) – The reconstructed image, in NCHW format.

Source code in odak/learn/perception/spatial_steerable_pyramid.py
def reconstruct_from_pyramid(self, pyramid):
+    """
+    Reconstructs an input image from a steerable pyramid.
+
+    Parameters
+    ----------
+
+    pyramid : list of dicts of torch.tensor
+                The steerable pyramid.
+                Should be in the same format as output by construct_steerable_pyramid().
+                The number of channels should match num_channels when the pyramid maker was created.
+
+    Returns
+    -------
+
+    image   : torch.tensor
+                The reconstructed image, in NCHW format.         
+    """
+    def upsample(image, size):
+        if self.use_bilinear_downup:
+            return torch.nn.functional.interpolate(image, size=size, mode="bilinear", align_corners=False, recompute_scale_factor=False)
+        else:
+            zeros = torch.zeros((image.size()[0], image.size()[1], image.size()[
+                                2]*2, image.size()[3]*2)).to(self.device)
+            zeros[:, :, ::2, ::2] = image
+            zeros = torch.nn.functional.conv2d(
+                self.pad_l(zeros), self.filt_l)
+            return zeros
+
+    image = pyramid[-1]['l']
+    for level in reversed(pyramid[:-1]):
+        image = upsample(image, level['b'][0].size()[2:])
+        for b in range(len(level['b'])):
+            b_filtered = torch.nn.functional.conv2d(
+                self.pad_b(level['b'][b]), -self.band_filters[b])
+            image += b_filtered
+
+    image = torch.nn.functional.conv2d(self.pad_l0(image), self.filt_l0)
+    image += torch.nn.functional.conv2d(
+        self.pad_h0(pyramid[0]['h']), self.filt_h0)
+
+    return image
+
+
+
display_color_hvs
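A short usage sketch of this class as a perceptual colour loss; the primary spectra below are random placeholders, and real display spectra should be supplied as a tensor sampled at 301 wavelengths between 400 nm and 700 nm:

import torch
from odak.learn.perception.color_conversion import display_color_hvs

primaries_spectrum = torch.rand(3, 301)   # placeholder display primary spectra (400-700 nm, 1 nm steps)
perceptual_color = display_color_hvs(
                                     read_spectrum = 'tensor',
                                     primaries_spectrum = primaries_spectrum,
                                     device = torch.device('cpu')
                                    )
estimate = torch.rand(1, 3, 128, 128)
ground_truth = torch.rand(1, 3, 128, 128)
loss = perceptual_color(estimate, ground_truth)   # mean-squared error in the third-stage opponent colour space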
Source code in odak/learn/perception/color_conversion.py
class display_color_hvs():
+
+    def __init__(
+                 self,
+                 resolution = [1920, 1080],
+                 distance_from_screen = 800,
+                 pixel_pitch = 0.311,
+                 read_spectrum = 'tensor',
+                 primaries_spectrum = torch.rand(3, 301),
+                 device = torch.device('cpu')):
+        '''
+        Parameters
+        ----------
+        resolution                  : list
+                                      Resolution of the display in pixels.
+        distance_from_screen        : int
+                                      Distance from the screen in mm.
+        pixel_pitch                 : float
+                                      Pixel pitch of the display in mm.
+        read_spectrum               : str
+                                      Spectrum of the display. The default 'tensor' uses the spectra provided in primaries_spectrum.
+        device                      : torch.device
+                                      Device to run the code on. Default is torch.device('cpu').
+
+        '''
+        self.device = device
+        self.read_spectrum = read_spectrum
+        self.primaries_spectrum = primaries_spectrum.to(self.device)
+        self.resolution = resolution
+        self.distance_from_screen = distance_from_screen
+        self.pixel_pitch = pixel_pitch
+        self.l_normalized, self.m_normalized, self.s_normalized = self.initialize_cones_normalized()
+        self.lms_tensor = self.construct_matrix_lms(
+                                                    self.l_normalized,
+                                                    self.m_normalized,
+                                                    self.s_normalized
+                                                   )   
+        self.primaries_tensor = self.construct_matrix_primaries(
+                                                                self.l_normalized,
+                                                                self.m_normalized,
+                                                                self.s_normalized
+                                                               )   
+        return
+
+
+    def __call__(self, input_image, ground_truth, gaze=None):
+        """
+        Evaluating an input image against a target ground truth image for a given gaze of a viewer.
+        """
+        lms_image_second = self.primaries_to_lms(input_image.to(self.device))
+        lms_ground_truth_second = self.primaries_to_lms(ground_truth.to(self.device))
+        lms_image_third = self.second_to_third_stage(lms_image_second)
+        lms_ground_truth_third = self.second_to_third_stage(lms_ground_truth_second)
+        loss_metamer_color = torch.mean((lms_ground_truth_third - lms_image_third) ** 2)
+        return loss_metamer_color
+
+
+    def initialize_cones_normalized(self):
+        """
+        Internal function to initialize normalized L,M,S cones as normal distribution with given sigma, and mu values. 
+
+        Returns
+        -------
+        l_cone_n                     : torch.tensor
+                                       Normalised L cone distribution.
+        m_cone_n                     : torch.tensor
+                                       Normalised M cone distribution.
+        s_cone_n                     : torch.tensor
+                                       Normalised S cone distribution.
+        """
+        wavelength_range = torch.linspace(400, 700, steps = 301, device = self.device)
+        dist_l = 1 / (32.5 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 567.5) ** 2 / (2 * 32.5 ** 2))
+        dist_m = 1 / (27.5 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 545.0) ** 2 / (2 * 27.5 ** 2))
+        dist_s = 1 / (17.0 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 447.5) ** 2 / (2 * 17.0 ** 2))
+
+        l_cone_n = dist_l / dist_l.max()
+        m_cone_n = dist_m / dist_m.max()
+        s_cone_n = dist_s / dist_s.max()
+        return l_cone_n, m_cone_n, s_cone_n
+
+
+    def initialize_rgb_backlight_spectrum(self):
+        """
+        Internal function to initialize backlight spectrum for color primaries. 
+
+        Returns
+        -------
+        red_spectrum                 : torch.tensor
+                                       Normalised backlight spectrum for red color primary.
+        green_spectrum               : torch.tensor
+                                       Normalised backlight spectrum for green color primary.
+        blue_spectrum                : torch.tensor
+                                       Normalised backlight spectrum for blue color primary.
+        """
+        wavelength_range = torch.linspace(400, 700, steps = 301, device = self.device)
+        red_spectrum = 1 / (14.5 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 650) ** 2 / (2 * 14.5 ** 2))
+        green_spectrum = 1 / (12 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 550) ** 2 / (2 * 12.0 ** 2))
+        blue_spectrum = 1 / (12 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 450) ** 2 / (2 * 12.0 ** 2))
+
+        red_spectrum = red_spectrum / red_spectrum.max()
+        green_spectrum = green_spectrum / green_spectrum.max()
+        blue_spectrum = blue_spectrum / blue_spectrum.max()
+
+        return red_spectrum, green_spectrum, blue_spectrum
+
+
+    def initialize_random_spectrum_normalized(self, dataset):
+        """
+        Initialize normalized light spectrum via combination of 3 gaussian distribution curve fitting [L-BFGS]. 
+
+        Parameters
+        ----------
+        dataset                                : torch.tensor 
+                                                 spectrum value against wavelength 
+        """
+        dataset = torch.swapaxes(dataset, 0, 1)
+        x_spectrum = torch.linspace(400, 700, steps = 301) - 550
+        y_spectrum = torch.from_numpy(np_cpu.interp(x_spectrum, dataset[0].numpy(), dataset[1].numpy()))
+        max_spectrum = torch.max(y_spectrum)
+        y_spectrum /= max_spectrum
+
+        def gaussian(x, A = 1, sigma = 1, centre = 0): return A * \
+            torch.exp(-(x - centre) ** 2 / (2 * sigma ** 2))
+
+        def function(x, weights): 
+            return gaussian(x, *weights[:3]) + gaussian(x, *weights[3:6]) + gaussian(x, *weights[6:9])
+
+        weights = torch.tensor([1.0, 1.0, -0.2, 1.0, 1.0, 0.0, 1.0, 1.0, 0.2], requires_grad = True)
+        optimizer = torch.optim.LBFGS([weights], max_iter = 1000, lr = 0.1, line_search_fn = None)
+
+        def closure():
+            optimizer.zero_grad()
+            output = function(x_spectrum, weights)
+            loss = F.mse_loss(output, y_spectrum)
+            loss.backward()
+            return loss
+        optimizer.step(closure)
+        spectrum = function(x_spectrum, weights)
+        return spectrum.detach().to(self.device)
+
+
+    def display_spectrum_response(wavelength, function):
+        """
+        Internal function to provide light spectrum response at particular wavelength
+
+        Parameters
+        ----------
+        wavelength                          : torch.tensor
+                                              Wavelength in nm [400...700]
+        function                            : torch.tensor
+                                              Display light spectrum distribution function
+
+        Returns
+        -------
+        light_response                       : float
+                                               Display light spectrum response value
+        """
+        wavelength = int(round(wavelength, 0))
+        if wavelength >= 400 and wavelength <= 700:
+            return function[wavelength - 400].item()
+        elif wavelength < 400:
+            return function[0].item()
+        else:
+            return function[300].item()
+
+
+    def cone_response_to_spectrum(self, cone_spectrum, light_spectrum):
+        """
+        Internal function to calculate cone response at particular light spectrum. 
+
+        Parameters
+        ----------
+        cone_spectrum                         : torch.tensor
+                                                Spectrum, Wavelength [2,300] tensor 
+        light_spectrum                        : torch.tensor
+                                                Spectrum, Wavelength [2,300] tensor 
+
+
+        Returns
+        -------
+        response_to_spectrum                  : float
+                                                Response of cone to light spectrum [1x1] 
+        """
+        response_to_spectrum = torch.mul(cone_spectrum, light_spectrum)
+        response_to_spectrum = torch.sum(response_to_spectrum)
+        return response_to_spectrum.item()
+
+
+    def construct_matrix_lms(self, l_response, m_response, s_response):
+        '''
+        Internal function to calculate cone  response at particular light spectrum. 
+
+        Parameters
+        ----------
+        l_response                             : torch.tensor
+                                                 Cone response spectrum tensor (normalized response vs wavelength)
+        m_response                             : torch.tensor
+                                                 Cone response spectrum tensor (normalized response vs wavelength)
+        s_response                             : torch.tensor
+                                                 Cone response spectrum tensor (normalized response vs wavelength)
+
+
+
+        Returns
+        -------
+        lms_image_tensor                      : torch.tensor
+                                                3x3 LMSrgb tensor
+
+        '''
+        if self.read_spectrum == 'tensor':
+            logging.warning('Tensor primary spectrum is used')
+            logging.warning('The number of primaries used is {}'.format(self.primaries_spectrum.shape[0]))
+        else:
+            logging.warning("No Spectrum data is provided")
+
+        self.lms_tensor = torch.zeros(self.primaries_spectrum.shape[0], 3).to(self.device)
+        for i in range(self.primaries_spectrum.shape[0]):
+            self.lms_tensor[i, 0] = self.cone_response_to_spectrum(l_response, self.primaries_spectrum[i])
+            self.lms_tensor[i, 1] = self.cone_response_to_spectrum(m_response, self.primaries_spectrum[i])
+            self.lms_tensor[i, 2] = self.cone_response_to_spectrum(s_response, self.primaries_spectrum[i]) 
+        return self.lms_tensor    
+
+
+    def construct_matrix_primaries(self, l_response, m_response, s_response):
+        '''
+        Internal function to calculate cone  response at particular light spectrum. 
+
+        Parameters
+        ----------
+        l_response                             : torch.tensor
+                                                 Cone response spectrum tensor (normalized response vs wavelength)
+        m_response                             : torch.tensor
+                                                 Cone response spectrum tensor (normalized response vs wavelength)
+        s_response                             : torch.tensor
+                                                 Cone response spectrum tensor (normalized response vs wavelength)
+
+
+
+        Returns
+        -------
+        lms_image_tensor                      : torch.tensor
+                                                3x3 LMSrgb tensor
+
+        '''
+        if self.read_spectrum == 'tensor':
+            logging.warning('Tensor primary spectrum is used')
+            logging.warning('The number of primaries used is {}'.format(self.primaries_spectrum.shape[0]))
+        else:
+            logging.warning("No Spectrum data is provided")
+
+        self.primaries_tensor = torch.zeros(3, self.primaries_spectrum.shape[0]).to(self.device)
+        for i in range(self.primaries_spectrum.shape[0]):
+            self.primaries_tensor[0, i] = self.cone_response_to_spectrum(
+                                                                         l_response,
+                                                                         self.primaries_spectrum[i]
+                                                                        )
+            self.primaries_tensor[1, i] = self.cone_response_to_spectrum(
+                                                                         m_response,
+                                                                         self.primaries_spectrum[i]
+                                                                        )
+            self.primaries_tensor[2, i] = self.cone_response_to_spectrum(
+                                                                         s_response,
+                                                                         self.primaries_spectrum[i]
+                                                                        ) 
+        return self.primaries_tensor    
+
+
+    def primaries_to_lms(self, primaries):
+        """
+        Internal function to convert primaries space to LMS space 
+
+        Parameters
+        ----------
+        primaries                              : torch.tensor
+                                                 Primaries data to be transformed to LMS space [BxPxHxW]
+
+
+        Returns
+        -------
+        lms_color                              : torch.tensor
+                                                 LMS data transformed from Primaries space [BxPxHxW]
+        """                
+        primaries_flatten = primaries.reshape(primaries.shape[0], primaries.shape[1], 1, -1)
+        lms = self.lms_tensor.unsqueeze(0).unsqueeze(-1)
+        lms_color = torch.sum(primaries_flatten * lms, axis = 1).reshape(primaries.shape)
+        return lms_color
+
+
+    def lms_to_primaries(self, lms_color_tensor):
+        """
+        Internal function to convert LMS image to primaries space
+
+        Parameters
+        ----------
+        lms_color_tensor                        : torch.tensor
+                                                  LMS data to be transformed to primaries space [Bx3xHxW]
+
+
+        Returns
+        -------
+        primaries                              : torch.tensor
+                                                 Primaries data transformed from LMS space [BxPxHxW]
+        """
+        lms_color_tensor = lms_color_tensor.permute(0, 2, 3, 1).to(self.device)
+        lms_color_flatten = torch.flatten(lms_color_tensor, start_dim=0, end_dim=1)
+        unflatten = torch.nn.Unflatten(0, (lms_color_tensor.size(0), lms_color_tensor.size(1)))
+        converted_unflatten = torch.matmul(lms_color_flatten.double(), self.lms_tensor.pinverse().double())
+        primaries = unflatten(converted_unflatten)     
+        primaries = primaries.permute(0, 3, 1, 2)   
+        return primaries
+
+
+    def second_to_third_stage(self, lms_image):
+        '''
+        This function turns second stage [L,M,S] values into third stage [(M+S)-L, (L+S)-M, L+M+S], 
+        See table 1 from Schmidt et al. "Neurobiological hypothesis of color appearance and hue perception," Optics Express 2014.
+
+        Parameters
+        ----------
+        lms_image                             : torch.tensor
+                                                 Image data at LMS space (second stage)
+
+        Returns
+        -------
+        third_stage                            : torch.tensor
+                                                 Image data at LMS space (third stage)
+
+        '''
+        third_stage = torch.zeros_like(lms_image)
+        third_stage[:, 0] = (lms_image[:, 1] + lms_image[:, 2]) - lms_image[:, 0]
+        third_stage[:, 1] = (lms_image[:, 0] + lms_image[:, 2]) - lms_image[:, 1]
+        third_stage[:, 2] = lms_image[:, 0] + lms_image[:, 1]  + lms_image[:, 2]
+        return third_stage
+
+
__call__(input_image, ground_truth, gaze=None)

Evaluating an input image against a target ground truth image for a given gaze of a viewer.

Source code in odak/learn/perception/color_conversion.py
def __call__(self, input_image, ground_truth, gaze=None):
+    """
+    Evaluating an input image against a target ground truth image for a given gaze of a viewer.
+    """
+    lms_image_second = self.primaries_to_lms(input_image.to(self.device))
+    lms_ground_truth_second = self.primaries_to_lms(ground_truth.to(self.device))
+    lms_image_third = self.second_to_third_stage(lms_image_second)
+    lms_ground_truth_third = self.second_to_third_stage(lms_ground_truth_second)
+    loss_metamer_color = torch.mean((lms_ground_truth_third - lms_image_third) ** 2)
+    return loss_metamer_color
+
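A minimal usage sketch of this loss follows; the class name display_color_hvs and its import are assumptions based on the module path shown above, and the tensors are random stand-ins.

import torch
from odak.learn.perception.color_conversion import display_color_hvs  # class name assumed

loss_function = display_color_hvs(primaries_spectrum = torch.rand(3, 301))
estimate = torch.rand(1, 3, 128, 128)         # estimated image in display primaries [B x P x H x W]
ground_truth = torch.rand(1, 3, 128, 128)     # target image in display primaries
loss = loss_function(estimate, ground_truth)  # mean squared error in the third-stage opponent space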
__init__(resolution=[1920, 1080], distance_from_screen=800, pixel_pitch=0.311, read_spectrum='tensor', primaries_spectrum=torch.rand(3, 301), device=torch.device('cpu'))

Parameters:

  • resolution            – Resolution of the display in pixels.
  • distance_from_screen  – Distance from the screen in mm.
  • pixel_pitch           – Pixel pitch of the display in mm.
  • read_spectrum         – Spectrum mode of the display. Default is 'tensor', meaning the primaries_spectrum tensor is used.
  • primaries_spectrum    – Spectrum of each display primary sampled over the 400-700 nm range [3 x 301 by default].
  • device                – Device to run the code on. Default is torch.device('cpu').

Source code in odak/learn/perception/color_conversion.py
def __init__(
+             self,
+             resolution = [1920, 1080],
+             distance_from_screen = 800,
+             pixel_pitch = 0.311,
+             read_spectrum = 'tensor',
+             primaries_spectrum = torch.rand(3, 301),
+             device = torch.device('cpu')):
+    '''
+    Parameters
+    ----------
+    resolution                  : list
+                                  Resolution of the display in pixels.
+    distance_from_screen        : int
+                                  Distance from the screen in mm.
+    pixel_pitch                 : float
+                                  Pixel pitch of the display in mm.
+    read_spectrum               : str
+                                  Spectrum mode of the display. Default is 'tensor', meaning the primaries_spectrum tensor is used.
+    primaries_spectrum          : torch.tensor
+                                  Spectrum of each display primary sampled over the 400-700 nm range [3 x 301 by default].
+    device                      : torch.device
+                                  Device to run the code on. Default is torch.device('cpu').
+
+    '''
+    self.device = device
+    self.read_spectrum = read_spectrum
+    self.primaries_spectrum = primaries_spectrum.to(self.device)
+    self.resolution = resolution
+    self.distance_from_screen = distance_from_screen
+    self.pixel_pitch = pixel_pitch
+    self.l_normalized, self.m_normalized, self.s_normalized = self.initialize_cones_normalized()
+    self.lms_tensor = self.construct_matrix_lms(
+                                                self.l_normalized,
+                                                self.m_normalized,
+                                                self.s_normalized
+                                               )   
+    self.primaries_tensor = self.construct_matrix_primaries(
+                                                            self.l_normalized,
+                                                            self.m_normalized,
+                                                            self.s_normalized
+                                                           )   
+    return
+
cone_response_to_spectrum(cone_spectrum, light_spectrum)

Internal function to calculate the cone response to a particular light spectrum.

Parameters:

  • cone_spectrum   – Spectrum, Wavelength [2,300] tensor
  • light_spectrum  – Spectrum, Wavelength [2,300] tensor

Returns:

  • response_to_spectrum (float) – Response of cone to light spectrum [1x1]

Source code in odak/learn/perception/color_conversion.py
def cone_response_to_spectrum(self, cone_spectrum, light_spectrum):
+    """
+    Internal function to calculate cone response at particular light spectrum. 
+
+    Parameters
+    ----------
+    cone_spectrum                         : torch.tensor
+                                            Spectrum, Wavelength [2,300] tensor 
+    light_spectrum                        : torch.tensor
+                                            Spectrum, Wavelength [2,300] tensor 
+
+
+    Returns
+    -------
+    response_to_spectrum                  : float
+                                            Response of cone to light spectrum [1x1] 
+    """
+    response_to_spectrum = torch.mul(cone_spectrum, light_spectrum)
+    response_to_spectrum = torch.sum(response_to_spectrum)
+    return response_to_spectrum.item()
+
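The response is just the element-wise product of the cone sensitivity and the light spectrum, summed over wavelength; a standalone sketch with synthetic 301-sample spectra (matching the 400-700 nm grid used elsewhere in this class):

import torch

cone_sensitivity = torch.rand(301)   # e.g. a normalized L cone curve sampled at 400...700 nm
light_spectrum = torch.rand(301)     # emission spectrum of one display primary on the same grid
response = torch.sum(cone_sensitivity * light_spectrum).item()  # scalar cone response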
construct_matrix_lms(l_response, m_response, s_response)

Internal function to construct the LMS conversion matrix of the display from the given cone response curves.

Parameters:

  • l_response  – Cone response spectrum tensor (normalized response vs wavelength)
  • m_response  – Cone response spectrum tensor (normalized response vs wavelength)
  • s_response  – Cone response spectrum tensor (normalized response vs wavelength)

Returns:

  • lms_image_tensor (torch.tensor) – 3x3 LMSrgb tensor

Source code in odak/learn/perception/color_conversion.py
def construct_matrix_lms(self, l_response, m_response, s_response):
+    '''
+    Internal function to construct the LMS conversion matrix of the display from the given cone response curves.
+
+    Parameters
+    ----------
+    l_response                             : torch.tensor
+                                             Cone response spectrum tensor (normalized response vs wavelength)
+    m_response                             : torch.tensor
+                                             Cone response spectrum tensor (normalized response vs wavelength)
+    s_response                             : torch.tensor
+                                             Cone response spectrum tensor (normalized response vs wavelength)
+
+
+
+    Returns
+    -------
+    lms_image_tensor                      : torch.tensor
+                                            3x3 LMSrgb tensor
+
+    '''
+    if self.read_spectrum == 'tensor':
+        logging.warning('Tensor primary spectrum is used')
+        logging.warning('The number of primaries used is {}'.format(self.primaries_spectrum.shape[0]))
+    else:
+        logging.warning("No Spectrum data is provided")
+
+    self.lms_tensor = torch.zeros(self.primaries_spectrum.shape[0], 3).to(self.device)
+    for i in range(self.primaries_spectrum.shape[0]):
+        self.lms_tensor[i, 0] = self.cone_response_to_spectrum(l_response, self.primaries_spectrum[i])
+        self.lms_tensor[i, 1] = self.cone_response_to_spectrum(m_response, self.primaries_spectrum[i])
+        self.lms_tensor[i, 2] = self.cone_response_to_spectrum(s_response, self.primaries_spectrum[i]) 
+    return self.lms_tensor    
+
construct_matrix_primaries(l_response, m_response, s_response)

Internal function to construct the primaries conversion matrix of the display from the given cone response curves.

Parameters:

  • l_response  – Cone response spectrum tensor (normalized response vs wavelength)
  • m_response  – Cone response spectrum tensor (normalized response vs wavelength)
  • s_response  – Cone response spectrum tensor (normalized response vs wavelength)

Returns:

  • lms_image_tensor (torch.tensor) – 3x3 LMSrgb tensor

Source code in odak/learn/perception/color_conversion.py
def construct_matrix_primaries(self, l_response, m_response, s_response):
+    '''
+    Internal function to construct the primaries conversion matrix of the display from the given cone response curves.
+
+    Parameters
+    ----------
+    l_response                             : torch.tensor
+                                             Cone response spectrum tensor (normalized response vs wavelength)
+    m_response                             : torch.tensor
+                                             Cone response spectrum tensor (normalized response vs wavelength)
+    s_response                             : torch.tensor
+                                             Cone response spectrum tensor (normalized response vs wavelength)
+
+
+
+    Returns
+    -------
+    lms_image_tensor                      : torch.tensor
+                                            3x3 LMSrgb tensor
+
+    '''
+    if self.read_spectrum == 'tensor':
+        logging.warning('Tensor primary spectrum is used')
+        logging.warning('The number of primaries used is {}'.format(self.primaries_spectrum.shape[0]))
+    else:
+        logging.warning("No Spectrum data is provided")
+
+    self.primaries_tensor = torch.zeros(3, self.primaries_spectrum.shape[0]).to(self.device)
+    for i in range(self.primaries_spectrum.shape[0]):
+        self.primaries_tensor[0, i] = self.cone_response_to_spectrum(
+                                                                     l_response,
+                                                                     self.primaries_spectrum[i]
+                                                                    )
+        self.primaries_tensor[1, i] = self.cone_response_to_spectrum(
+                                                                     m_response,
+                                                                     self.primaries_spectrum[i]
+                                                                    )
+        self.primaries_tensor[2, i] = self.cone_response_to_spectrum(
+                                                                     s_response,
+                                                                     self.primaries_spectrum[i]
+                                                                    ) 
+    return self.primaries_tensor    
+
display_spectrum_response(wavelength, function)

Internal function to provide the light spectrum response at a particular wavelength.

Parameters:

  • wavelength  – Wavelength in nm [400...700]
  • function    – Display light spectrum distribution function

Returns:

  • light_response (float) – Display light spectrum response value

Source code in odak/learn/perception/color_conversion.py
def display_spectrum_response(wavelength, function):
+    """
+    Internal function to provide light spectrum response at particular wavelength
+
+    Parameters
+    ----------
+    wavelength                          : torch.tensor
+                                          Wavelength in nm [400...700]
+    function                            : torch.tensor
+                                          Display light spectrum distribution function
+
+    Returns
+    -------
+    light_response                       : float
+                                           Display light spectrum response value
+    """
+    wavelength = int(round(wavelength, 0))
+    if wavelength >= 400 and wavelength <= 700:
+        return function[wavelength - 400].item()
+    elif wavelength < 400:
+        return function[0].item()
+    else:
+        return function[300].item()
+
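Since the spectrum tensor is sampled at one-nanometre steps from 400 to 700 nm, the lookup reduces to an index offset with clamping at both ends; a short sketch of the same behaviour:

import torch

spectrum = torch.rand(301)                 # spectral distribution sampled at 400...700 nm
value_550 = spectrum[550 - 400].item()     # equivalent to display_spectrum_response(550, spectrum)
value_380 = spectrum[0].item()             # wavelengths below 400 nm clamp to the first sample
value_720 = spectrum[300].item()           # wavelengths above 700 nm clamp to the last sample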
initialize_cones_normalized()

Internal function to initialize normalized L, M, S cones as normal distributions with given sigma and mu values.

Returns:

  • l_cone_n (torch.tensor) – Normalised L cone distribution.
  • m_cone_n (torch.tensor) – Normalised M cone distribution.
  • s_cone_n (torch.tensor) – Normalised S cone distribution.

Source code in odak/learn/perception/color_conversion.py
def initialize_cones_normalized(self):
+    """
+    Internal function to initialize normalized L,M,S cones as normal distribution with given sigma, and mu values. 
+
+    Returns
+    -------
+    l_cone_n                     : torch.tensor
+                                   Normalised L cone distribution.
+    m_cone_n                     : torch.tensor
+                                   Normalised M cone distribution.
+    s_cone_n                     : torch.tensor
+                                   Normalised S cone distribution.
+    """
+    wavelength_range = torch.linspace(400, 700, steps = 301, device = self.device)
+    dist_l = 1 / (32.5 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 567.5) ** 2 / (2 * 32.5 ** 2))
+    dist_m = 1 / (27.5 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 545.0) ** 2 / (2 * 27.5 ** 2))
+    dist_s = 1 / (17.0 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 447.5) ** 2 / (2 * 17.0 ** 2))
+
+    l_cone_n = dist_l / dist_l.max()
+    m_cone_n = dist_m / dist_m.max()
+    s_cone_n = dist_s / dist_s.max()
+    return l_cone_n, m_cone_n, s_cone_n
+
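Each cone curve is a Gaussian over the 400-700 nm grid scaled to a unit peak; a sketch reproducing the L cone curve with the constants used above (the leading 1 / (sigma * sqrt(2 pi)) amplitude cancels after normalization):

import torch

wavelengths = torch.linspace(400, 700, steps = 301)
sigma, mu = 32.5, 567.5                                                # L cone constants from above
l_cone = torch.exp(-0.5 * (wavelengths - mu) ** 2 / (2 * sigma ** 2))  # same exponent as the source
l_cone = l_cone / l_cone.max()                                         # normalize to a peak of one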
initialize_random_spectrum_normalized(dataset)

Initialize a normalized light spectrum via a combination of three Gaussian distributions fitted with L-BFGS.

Parameters:

  • dataset  – Spectrum value against wavelength.

Source code in odak/learn/perception/color_conversion.py
def initialize_random_spectrum_normalized(self, dataset):
+    """
+    Initialize normalized light spectrum via combination of 3 gaussian distribution curve fitting [L-BFGS]. 
+
+    Parameters
+    ----------
+    dataset                                : torch.tensor 
+                                             spectrum value against wavelength 
+    """
+    dataset = torch.swapaxes(dataset, 0, 1)
+    x_spectrum = torch.linspace(400, 700, steps = 301) - 550
+    y_spectrum = torch.from_numpy(np_cpu.interp(x_spectrum, dataset[0].numpy(), dataset[1].numpy()))
+    max_spectrum = torch.max(y_spectrum)
+    y_spectrum /= max_spectrum
+
+    def gaussian(x, A = 1, sigma = 1, centre = 0): return A * \
+        torch.exp(-(x - centre) ** 2 / (2 * sigma ** 2))
+
+    def function(x, weights): 
+        return gaussian(x, *weights[:3]) + gaussian(x, *weights[3:6]) + gaussian(x, *weights[6:9])
+
+    weights = torch.tensor([1.0, 1.0, -0.2, 1.0, 1.0, 0.0, 1.0, 1.0, 0.2], requires_grad = True)
+    optimizer = torch.optim.LBFGS([weights], max_iter = 1000, lr = 0.1, line_search_fn = None)
+
+    def closure():
+        optimizer.zero_grad()
+        output = function(x_spectrum, weights)
+        loss = F.mse_loss(output, y_spectrum)
+        loss.backward()
+        return loss
+    optimizer.step(closure)
+    spectrum = function(x_spectrum, weights)
+    return spectrum.detach().to(self.device)
initialize_rgb_backlight_spectrum()

Internal function to initialize the backlight spectrum for the color primaries.

Returns:

  • red_spectrum (torch.tensor) – Normalised backlight spectrum for the red color primary.
  • green_spectrum (torch.tensor) – Normalised backlight spectrum for the green color primary.
  • blue_spectrum (torch.tensor) – Normalised backlight spectrum for the blue color primary.

Source code in odak/learn/perception/color_conversion.py
def initialize_rgb_backlight_spectrum(self):
+    """
+    Internal function to initialize the backlight spectrum for the color primaries. 
+
+    Returns
+    -------
+    red_spectrum                 : torch.tensor
+                                   Normalised backlight spectrum for red color primary.
+    green_spectrum               : torch.tensor
+                                   Normalised backlight spectrum for green color primary.
+    blue_spectrum                : torch.tensor
+                                   Normalised backlight spectrum for blue color primary.
+    """
+    wavelength_range = torch.linspace(400, 700, steps = 301, device = self.device)
+    red_spectrum = 1 / (14.5 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 650) ** 2 / (2 * 14.5 ** 2))
+    green_spectrum = 1 / (12 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 550) ** 2 / (2 * 12.0 ** 2))
+    blue_spectrum = 1 / (12 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 450) ** 2 / (2 * 12.0 ** 2))
+
+    red_spectrum = red_spectrum / red_spectrum.max()
+    green_spectrum = green_spectrum / green_spectrum.max()
+    blue_spectrum = blue_spectrum / blue_spectrum.max()
+
+    return red_spectrum, green_spectrum, blue_spectrum
lms_to_primaries(lms_color_tensor)

Internal function to convert an LMS image to primaries space.

Parameters:

  • lms_color_tensor  – LMS data to be transformed to primaries space [Bx3xHxW]

Returns:

  • primaries (torch.tensor) – Primaries data transformed from LMS space [BxPxHxW]

Source code in odak/learn/perception/color_conversion.py
def lms_to_primaries(self, lms_color_tensor):
+    """
+    Internal function to convert LMS image to primaries space
+
+    Parameters
+    ----------
+    lms_color_tensor                        : torch.tensor
+                                              LMS data to be transformed to primaries space [Bx3xHxW]
+
+
+    Returns
+    -------
+    primaries                              : torch.tensor
+                                             Primaries data transformed from LMS space [BxPxHxW]
+    """
+    lms_color_tensor = lms_color_tensor.permute(0, 2, 3, 1).to(self.device)
+    lms_color_flatten = torch.flatten(lms_color_tensor, start_dim=0, end_dim=1)
+    unflatten = torch.nn.Unflatten(0, (lms_color_tensor.size(0), lms_color_tensor.size(1)))
+    converted_unflatten = torch.matmul(lms_color_flatten.double(), self.lms_tensor.pinverse().double())
+    primaries = unflatten(converted_unflatten)     
+    primaries = primaries.permute(0, 3, 1, 2)   
+    return primaries
+
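lms_to_primaries inverts primaries_to_lms through the pseudo-inverse of the LMS matrix, so a primaries image should approximately survive a round trip; a sketch, where color_model is a placeholder for an instance of the class documented on this page:

import torch

image_primaries = torch.rand(1, 3, 64, 64)                  # [B x P x H x W] image in display primaries
# color_model stands in for an instance of this class:
# lms_image = color_model.primaries_to_lms(image_primaries)
# recovered = color_model.lms_to_primaries(lms_image)       # back to primaries through lms_tensor.pinverse()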
primaries_to_lms(primaries)

Internal function to convert primaries space to LMS space.

Parameters:

  • primaries  – Primaries data to be transformed to LMS space [BxPxHxW]

Returns:

  • lms_color (torch.tensor) – LMS data transformed from primaries space [BxPxHxW]

Source code in odak/learn/perception/color_conversion.py
def primaries_to_lms(self, primaries):
+    """
+    Internal function to convert primaries space to LMS space 
+
+    Parameters
+    ----------
+    primaries                              : torch.tensor
+                                             Primaries data to be transformed to LMS space [BxPxHxW]
+
+
+    Returns
+    -------
+    lms_color                              : torch.tensor
+                                             LMS data transformed from Primaries space [BxPxHxW]
+    """                
+    primaries_flatten = primaries.reshape(primaries.shape[0], primaries.shape[1], 1, -1)
+    lms = self.lms_tensor.unsqueeze(0).unsqueeze(-1)
+    lms_color = torch.sum(primaries_flatten * lms, axis = 1).reshape(primaries.shape)
+    return lms_color
+
second_to_third_stage(lms_image)

This function turns second stage [L,M,S] values into third stage [(M+S)-L, (L+S)-M, L+M+S].
See Table 1 from Schmidt et al., "Neurobiological hypothesis of color appearance and hue perception," Optics Express 2014.

Parameters:

  • lms_image  – Image data at LMS space (second stage)

Returns:

  • third_stage (torch.tensor) – Image data at LMS space (third stage)

Source code in odak/learn/perception/color_conversion.py
def second_to_third_stage(self, lms_image):
+    '''
+    This function turns second stage [L,M,S] values into third stage [(M+S)-L, (L+S)-M, L+M+S], 
+    See table 1 from Schmidt et al. "Neurobiological hypothesis of color appearance and hue perception," Optics Express 2014.
+
+    Parameters
+    ----------
+    lms_image                             : torch.tensor
+                                             Image data at LMS space (second stage)
+
+    Returns
+    -------
+    third_stage                            : torch.tensor
+                                             Image data at LMS space (third stage)
+
+    '''
+    third_stage = torch.zeros_like(lms_image)
+    third_stage[:, 0] = (lms_image[:, 1] + lms_image[:, 2]) - lms_image[:, 0]
+    third_stage[:, 1] = (lms_image[:, 0] + lms_image[:, 2]) - lms_image[:, 1]
+    third_stage[:, 2] = lms_image[:, 0] + lms_image[:, 1]  + lms_image[:, 2]
+    return third_stage
+
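The third stage is a fixed linear recombination of the LMS channels into two colour-opponent channels and a summed luminance-like channel; an equivalent standalone computation:

import torch

lms = torch.rand(1, 3, 8, 8)                     # [B x 3 x H x W] second-stage LMS image
l, m, s = lms[:, 0], lms[:, 1], lms[:, 2]
third_stage = torch.stack(((m + s) - l, (l + s) - m, l + m + s), dim = 1)   # [(M+S)-L, (L+S)-M, L+M+S]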
color_map(input_image, target_image, model='Lab Stats')

Internal function to map the color of an image to another image.
Reference: Color transfer between images, Reinhard et al., 2001.

Parameters:

  • input_image   – Input image in RGB color space [3 x m x n].
  • target_image  – Target image in RGB color space [3 x m x n] providing the color statistics to match.

Returns:

  • mapped_image (torch.Tensor) – Input image with the color distribution of the target image [3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def color_map(input_image, target_image, model = 'Lab Stats'):
+    """
+    Internal function to map the color of an image to another image.
+    Reference: Color transfer between images, Reinhard et al., 2001.
+
+    Parameters
+    ----------
+    input_image         : torch.Tensor
+                          Input image in RGB color space [3 x m x n].
+    target_image        : torch.Tensor
+                          Target image in RGB color space [3 x m x n] providing the color statistics to match.
+
+    Returns
+    -------
+    mapped_image           : torch.Tensor
+                             Input image with the color distribution of the target image [3 x m x n].
+    """
+    if model == 'Lab Stats':
+        lab_input = srgb_to_lab(input_image)
+        lab_target = srgb_to_lab(target_image)
+        input_mean_L = torch.mean(lab_input[0, :, :])
+        input_mean_a = torch.mean(lab_input[1, :, :])
+        input_mean_b = torch.mean(lab_input[2, :, :])
+        input_std_L = torch.std(lab_input[0, :, :])
+        input_std_a = torch.std(lab_input[1, :, :])
+        input_std_b = torch.std(lab_input[2, :, :])
+        target_mean_L = torch.mean(lab_target[0, :, :])
+        target_mean_a = torch.mean(lab_target[1, :, :])
+        target_mean_b = torch.mean(lab_target[2, :, :])
+        target_std_L = torch.std(lab_target[0, :, :])
+        target_std_a = torch.std(lab_target[1, :, :])
+        target_std_b = torch.std(lab_target[2, :, :])
+        lab_input[0, :, :] = (lab_input[0, :, :] - input_mean_L) * (target_std_L / input_std_L) + target_mean_L
+        lab_input[1, :, :] = (lab_input[1, :, :] - input_mean_a) * (target_std_a / input_std_a) + target_mean_a
+        lab_input[2, :, :] = (lab_input[2, :, :] - input_mean_b) * (target_std_b / input_std_b) + target_mean_b
+        mapped_image = lab_to_srgb(lab_input.permute(1, 2, 0))
+        return mapped_image
+
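A usage sketch for the statistics-based transfer follows; the import path is taken from the source listing above and the images are random stand-ins:

import torch
from odak.learn.perception.color_conversion import color_map

input_image = torch.rand(3, 256, 256)       # image whose colors will be remapped
target_image = torch.rand(3, 256, 256)      # image providing the target color statistics
mapped_image = color_map(input_image, target_image, model = 'Lab Stats')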
crop_steerable_pyramid_filters(filters, size)

Given the original 9x9 NYU filters, this crops them to the desired size.
The size must be an odd number >= 3.
Note this only crops the h0, l0 and band filters (not the l downsampling filter).

Parameters:

  • filters  – Filters to crop (should be in the format used by get_steerable_pyramid_filters).
  • size     – Size to crop to. For example, an input of 3 will crop the filters to a size of 3x3.

Returns:

  • filters (dict of torch.tensor) – The cropped filters.

Source code in odak/learn/perception/steerable_pyramid_filters.py
def crop_steerable_pyramid_filters(filters, size):
+    """
+    Given original 9x9 NYU filters, this crops them to the desired size.
+    The size must be an odd number >= 3
+    Note this only crops the h0, l0 and band filters (not the l downsampling filter)
+
+    Parameters
+    ----------
+    filters     : dict of torch.tensor
+                    Filters to crop (should be in the format used by get_steerable_pyramid_filters).
+    size        : int
+                    Size to crop to. For example, an input of 3 will crop the filters to a size of 3x3.
+
+    Returns
+    -------
+    filters     : dict of torch.tensor
+                    The cropped filters.
+    """
+    assert(size >= 3)
+    assert(size % 2 == 1)
+    r = (size-1) // 2
+
+    def crop_filter(filter, r, normalise=True):
+        r2 = (filter.size(-1)-1)//2
+        filter = filter[:, :, r2-r:r2+r+1, r2-r:r2+r+1]
+        if normalise:
+            filter -= torch.sum(filter)
+        return filter
+
+    filters["h0"] = crop_filter(filters["h0"], r, normalise=False)
+    sum_l = torch.sum(filters["l"])
+    filters["l"] = crop_filter(filters["l"], 6, normalise=False)
+    filters["l"] *= sum_l / torch.sum(filters["l"])
+    sum_l0 = torch.sum(filters["l0"])
+    filters["l0"] = crop_filter(filters["l0"], 2, normalise=False)
+    filters["l0"] *= sum_l0 / torch.sum(filters["l0"])
+    for b in range(len(filters["b"])):
+        filters["b"][b] = crop_filter(filters["b"][b], r, normalise=True)
+    return filters
+
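Cropping is normally applied to the full-size filters returned by get_steerable_pyramid_filters, documented below; a sketch with the module path taken from the source listing above:

from odak.learn.perception.steerable_pyramid_filters import (
    get_steerable_pyramid_filters,
    crop_steerable_pyramid_filters,
)

filters = get_steerable_pyramid_filters(size = 9, n_orientations = 2, filter_type = 'full')
filters = crop_steerable_pyramid_filters(filters, size = 5)   # h0, l0 and band filters become 5x5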
get_steerable_pyramid_filters(size, n_orientations, filter_type)

This returns filters for a real-valued steerable pyramid.

Parameters:

  • size            – Width of the filters (e.g. 3 will return 3x3 filters).
  • n_orientations  – Number of oriented band filters.
  • filter_type     – This can be used to select between the original NYU filters and cropped or trained alternatives.
                      full: Original NYU filters from https://github.com/LabForComputationalVision/pyrtools/blob/master/pyrtools/pyramids/filters.py
                      cropped: Some filters are cut back in size by extracting the centre and scaling as appropriate.
                      trained: Same as reduced, but the oriented kernels are replaced by learned 5x5 kernels.

Returns:

  • filters (dict of torch.tensor) – The steerable pyramid filters, returned as a dict with the following keys:
                                     "l"  The lowpass downsampling filter
                                     "l0" The lowpass residual filter
                                     "h0" The highpass residual filter
                                     "b"  The band filters (a list of torch.tensor filters, one for each orientation).

Source code in odak/learn/perception/steerable_pyramid_filters.py
def get_steerable_pyramid_filters(size, n_orientations, filter_type):
+    """
+    This returns filters for a real-valued steerable pyramid.
+
+    Parameters
+    ----------
+
+    size            : int
+                        Width of the filters (e.g. 3 will return 3x3 filters)
+    n_orientations  : int
+                        Number of oriented band filters
+    filter_type     :  str
+                        This can be used to select between the original NYU filters and cropped or trained alternatives.
+                        full: Original NYU filters from https://github.com/LabForComputationalVision/pyrtools/blob/master/pyrtools/pyramids/filters.py
+                        cropped: Some filters are cut back in size by extracting the centre and scaling as appropriate.
+                        trained: Same as reduced, but the oriented kernels are replaced by learned 5x5 kernels.
+
+    Returns
+    -------
+    filters         : dict of torch.tensor
+                        The steerable pyramid filters. Returned as a dict with the following keys:
+                        "l" The lowpass downsampling filter
+                        "l0" The lowpass residual filter
+                        "h0" The highpass residual filter
+                        "b" The band filters (a list of torch.tensor filters, one for each orientation).
+    """
+
+    if filter_type != "full" and filter_type != "cropped" and filter_type != "trained":
+        raise Exception(
+            "Unknown filter type %s! Only filter types are full, cropped or trained." % filter_type)
+
+    filters = {}
+    if n_orientations == 1:
+        filters["l"] = torch.tensor([
+            [-2.257000e-04, -8.064400e-04, -5.686000e-05, 8.741400e-04, -1.862800e-04, -1.031640e-03, -
+                1.871920e-03, -1.031640e-03, -1.862800e-04, 8.741400e-04, -5.686000e-05, -8.064400e-04, -2.257000e-04],
+            [-8.064400e-04, 1.417620e-03, -1.903800e-04, -2.449060e-03, -4.596420e-03, -7.006740e-03, -
+                6.948900e-03, -7.006740e-03, -4.596420e-03, -2.449060e-03, -1.903800e-04, 1.417620e-03, -8.064400e-04],
+            [-5.686000e-05, -1.903800e-04, -3.059760e-03, -6.401000e-03, -6.720800e-03, -5.236180e-03, -
+                3.781600e-03, -5.236180e-03, -6.720800e-03, -6.401000e-03, -3.059760e-03, -1.903800e-04, -5.686000e-05],
+            [8.741400e-04, -2.449060e-03, -6.401000e-03, -5.260020e-03, 3.938620e-03, 1.722078e-02, 2.449600e-02,
+                1.722078e-02, 3.938620e-03, -5.260020e-03, -6.401000e-03, -2.449060e-03, 8.741400e-04],
+            [-1.862800e-04, -4.596420e-03, -6.720800e-03, 3.938620e-03, 3.220744e-02, 6.306262e-02, 7.624674e-02,
+                6.306262e-02, 3.220744e-02, 3.938620e-03, -6.720800e-03, -4.596420e-03, -1.862800e-04],
+            [-1.031640e-03, -7.006740e-03, -5.236180e-03, 1.722078e-02, 6.306262e-02, 1.116388e-01, 1.348999e-01,
+                1.116388e-01, 6.306262e-02, 1.722078e-02, -5.236180e-03, -7.006740e-03, -1.031640e-03],
+            [-1.871920e-03, -6.948900e-03, -3.781600e-03, 2.449600e-02, 7.624674e-02, 1.348999e-01, 1.576508e-01,
+                1.348999e-01, 7.624674e-02, 2.449600e-02, -3.781600e-03, -6.948900e-03, -1.871920e-03],
+            [-1.031640e-03, -7.006740e-03, -5.236180e-03, 1.722078e-02, 6.306262e-02, 1.116388e-01, 1.348999e-01,
+                1.116388e-01, 6.306262e-02, 1.722078e-02, -5.236180e-03, -7.006740e-03, -1.031640e-03],
+            [-1.862800e-04, -4.596420e-03, -6.720800e-03, 3.938620e-03, 3.220744e-02, 6.306262e-02, 7.624674e-02,
+                6.306262e-02, 3.220744e-02, 3.938620e-03, -6.720800e-03, -4.596420e-03, -1.862800e-04],
+            [8.741400e-04, -2.449060e-03, -6.401000e-03, -5.260020e-03, 3.938620e-03, 1.722078e-02, 2.449600e-02,
+                1.722078e-02, 3.938620e-03, -5.260020e-03, -6.401000e-03, -2.449060e-03, 8.741400e-04],
+            [-5.686000e-05, -1.903800e-04, -3.059760e-03, -6.401000e-03, -6.720800e-03, -5.236180e-03, -
+                3.781600e-03, -5.236180e-03, -6.720800e-03, -6.401000e-03, -3.059760e-03, -1.903800e-04, -5.686000e-05],
+            [-8.064400e-04, 1.417620e-03, -1.903800e-04, -2.449060e-03, -4.596420e-03, -7.006740e-03, -
+                6.948900e-03, -7.006740e-03, -4.596420e-03, -2.449060e-03, -1.903800e-04, 1.417620e-03, -8.064400e-04],
+            [-2.257000e-04, -8.064400e-04, -5.686000e-05, 8.741400e-04, -1.862800e-04, -1.031640e-03, -1.871920e-03, -1.031640e-03, -1.862800e-04, 8.741400e-04, -5.686000e-05, -8.064400e-04, -2.257000e-04]]
+        ).reshape(1, 1, 13, 13)
+        filters["l0"] = torch.tensor([
+            [-4.514000e-04, -1.137100e-04, -3.725800e-04, -
+                3.743860e-03, -3.725800e-04, -1.137100e-04, -4.514000e-04],
+            [-1.137100e-04, -6.119520e-03, -1.344160e-02, -
+                7.563200e-03, -1.344160e-02, -6.119520e-03, -1.137100e-04],
+            [-3.725800e-04, -1.344160e-02, 6.441488e-02, 1.524935e-01,
+                6.441488e-02, -1.344160e-02, -3.725800e-04],
+            [-3.743860e-03, -7.563200e-03, 1.524935e-01, 3.153017e-01,
+                1.524935e-01, -7.563200e-03, -3.743860e-03],
+            [-3.725800e-04, -1.344160e-02, 6.441488e-02, 1.524935e-01,
+                6.441488e-02, -1.344160e-02, -3.725800e-04],
+            [-1.137100e-04, -6.119520e-03, -1.344160e-02, -
+                7.563200e-03, -1.344160e-02, -6.119520e-03, -1.137100e-04],
+            [-4.514000e-04, -1.137100e-04, -3.725800e-04, -3.743860e-03, -3.725800e-04, -1.137100e-04, -4.514000e-04]]
+        ).reshape(1, 1, 7, 7)
+        filters["h0"] = torch.tensor([
+            [5.997200e-04, -6.068000e-05, -3.324900e-04, -3.325600e-04, -
+                2.406600e-04, -3.325600e-04, -3.324900e-04, -6.068000e-05, 5.997200e-04],
+            [-6.068000e-05, 1.263100e-04, 4.927100e-04, 1.459700e-04, -
+                3.732100e-04, 1.459700e-04, 4.927100e-04, 1.263100e-04, -6.068000e-05],
+            [-3.324900e-04, 4.927100e-04, -1.616650e-03, -1.437358e-02, -
+                2.420138e-02, -1.437358e-02, -1.616650e-03, 4.927100e-04, -3.324900e-04],
+            [-3.325600e-04, 1.459700e-04, -1.437358e-02, -6.300923e-02, -
+                9.623594e-02, -6.300923e-02, -1.437358e-02, 1.459700e-04, -3.325600e-04],
+            [-2.406600e-04, -3.732100e-04, -2.420138e-02, -9.623594e-02,
+                8.554893e-01, -9.623594e-02, -2.420138e-02, -3.732100e-04, -2.406600e-04],
+            [-3.325600e-04, 1.459700e-04, -1.437358e-02, -6.300923e-02, -
+                9.623594e-02, -6.300923e-02, -1.437358e-02, 1.459700e-04, -3.325600e-04],
+            [-3.324900e-04, 4.927100e-04, -1.616650e-03, -1.437358e-02, -
+                2.420138e-02, -1.437358e-02, -1.616650e-03, 4.927100e-04, -3.324900e-04],
+            [-6.068000e-05, 1.263100e-04, 4.927100e-04, 1.459700e-04, -
+                3.732100e-04, 1.459700e-04, 4.927100e-04, 1.263100e-04, -6.068000e-05],
+            [5.997200e-04, -6.068000e-05, -3.324900e-04, -3.325600e-04, -2.406600e-04, -3.325600e-04, -3.324900e-04, -6.068000e-05, 5.997200e-04]]
+        ).reshape(1, 1, 9, 9)
+        filters["b"] = []
+        filters["b"].append(torch.tensor([
+            -9.066000e-05, -1.738640e-03, -4.942500e-03, -7.889390e-03, -
+            1.009473e-02, -7.889390e-03, -4.942500e-03, -1.738640e-03, -9.066000e-05,
+            -1.738640e-03, -4.625150e-03, -7.272540e-03, -7.623410e-03, -
+            9.091950e-03, -7.623410e-03, -7.272540e-03, -4.625150e-03, -1.738640e-03,
+            -4.942500e-03, -7.272540e-03, -2.129540e-02, -2.435662e-02, -
+            3.487008e-02, -2.435662e-02, -2.129540e-02, -7.272540e-03, -4.942500e-03,
+            -7.889390e-03, -7.623410e-03, -2.435662e-02, -1.730466e-02, -
+            3.158605e-02, -1.730466e-02, -2.435662e-02, -7.623410e-03, -7.889390e-03,
+            -1.009473e-02, -9.091950e-03, -3.487008e-02, -3.158605e-02, 9.464195e-01, -
+            3.158605e-02, -3.487008e-02, -9.091950e-03, -1.009473e-02,
+            -7.889390e-03, -7.623410e-03, -2.435662e-02, -1.730466e-02, -
+            3.158605e-02, -1.730466e-02, -2.435662e-02, -7.623410e-03, -7.889390e-03,
+            -4.942500e-03, -7.272540e-03, -2.129540e-02, -2.435662e-02, -
+            3.487008e-02, -2.435662e-02, -2.129540e-02, -7.272540e-03, -4.942500e-03,
+            -1.738640e-03, -4.625150e-03, -7.272540e-03, -7.623410e-03, -
+            9.091950e-03, -7.623410e-03, -7.272540e-03, -4.625150e-03, -1.738640e-03,
+            -9.066000e-05, -1.738640e-03, -4.942500e-03, -7.889390e-03, -1.009473e-02, -7.889390e-03, -4.942500e-03, -1.738640e-03, -9.066000e-05]
+        ).reshape(1, 1, 9, 9).permute(0, 1, 3, 2))
+
+    elif n_orientations == 2:
+        filters["l"] = torch.tensor(
+            [[-4.350000e-05, 1.207800e-04, -6.771400e-04, -1.243400e-04, -8.006400e-04, -1.597040e-03, -2.516800e-04, -4.202000e-04, 1.262000e-03, -4.202000e-04, -2.516800e-04, -1.597040e-03, -8.006400e-04, -1.243400e-04, -6.771400e-04, 1.207800e-04, -4.350000e-05],
+             [1.207800e-04, 4.460600e-04, -5.814600e-04, 5.621600e-04, -1.368800e-04, 2.325540e-03, 2.889860e-03, 4.287280e-03, 5.589400e-03,
+                 4.287280e-03, 2.889860e-03, 2.325540e-03, -1.368800e-04, 5.621600e-04, -5.814600e-04, 4.460600e-04, 1.207800e-04],
+             [-6.771400e-04, -5.814600e-04, 1.460780e-03, 2.160540e-03, 3.761360e-03, 3.080980e-03, 4.112200e-03, 2.221220e-03, 5.538200e-04,
+                 2.221220e-03, 4.112200e-03, 3.080980e-03, 3.761360e-03, 2.160540e-03, 1.460780e-03, -5.814600e-04, -6.771400e-04],
+             [-1.243400e-04, 5.621600e-04, 2.160540e-03, 3.175780e-03, 3.184680e-03, -1.777480e-03, -7.431700e-03, -9.056920e-03, -
+                 9.637220e-03, -9.056920e-03, -7.431700e-03, -1.777480e-03, 3.184680e-03, 3.175780e-03, 2.160540e-03, 5.621600e-04, -1.243400e-04],
+             [-8.006400e-04, -1.368800e-04, 3.761360e-03, 3.184680e-03, -3.530640e-03, -1.260420e-02, -1.884744e-02, -1.750818e-02, -
+                 1.648568e-02, -1.750818e-02, -1.884744e-02, -1.260420e-02, -3.530640e-03, 3.184680e-03, 3.761360e-03, -1.368800e-04, -8.006400e-04],
+             [-1.597040e-03, 2.325540e-03, 3.080980e-03, -1.777480e-03, -1.260420e-02, -2.022938e-02, -1.109170e-02, 3.955660e-03, 1.438512e-02,
+                 3.955660e-03, -1.109170e-02, -2.022938e-02, -1.260420e-02, -1.777480e-03, 3.080980e-03, 2.325540e-03, -1.597040e-03],
+             [-2.516800e-04, 2.889860e-03, 4.112200e-03, -7.431700e-03, -1.884744e-02, -1.109170e-02, 2.190660e-02, 6.806584e-02, 9.058014e-02,
+                 6.806584e-02, 2.190660e-02, -1.109170e-02, -1.884744e-02, -7.431700e-03, 4.112200e-03, 2.889860e-03, -2.516800e-04],
+             [-4.202000e-04, 4.287280e-03, 2.221220e-03, -9.056920e-03, -1.750818e-02, 3.955660e-03, 6.806584e-02, 1.445500e-01, 1.773651e-01,
+                 1.445500e-01, 6.806584e-02, 3.955660e-03, -1.750818e-02, -9.056920e-03, 2.221220e-03, 4.287280e-03, -4.202000e-04],
+             [1.262000e-03, 5.589400e-03, 5.538200e-04, -9.637220e-03, -1.648568e-02, 1.438512e-02, 9.058014e-02, 1.773651e-01, 2.120374e-01,
+                 1.773651e-01, 9.058014e-02, 1.438512e-02, -1.648568e-02, -9.637220e-03, 5.538200e-04, 5.589400e-03, 1.262000e-03],
+             [-4.202000e-04, 4.287280e-03, 2.221220e-03, -9.056920e-03, -1.750818e-02, 3.955660e-03, 6.806584e-02, 1.445500e-01, 1.773651e-01,
+                 1.445500e-01, 6.806584e-02, 3.955660e-03, -1.750818e-02, -9.056920e-03, 2.221220e-03, 4.287280e-03, -4.202000e-04],
+             [-2.516800e-04, 2.889860e-03, 4.112200e-03, -7.431700e-03, -1.884744e-02, -1.109170e-02, 2.190660e-02, 6.806584e-02, 9.058014e-02,
+                 6.806584e-02, 2.190660e-02, -1.109170e-02, -1.884744e-02, -7.431700e-03, 4.112200e-03, 2.889860e-03, -2.516800e-04],
+             [-1.597040e-03, 2.325540e-03, 3.080980e-03, -1.777480e-03, -1.260420e-02, -2.022938e-02, -1.109170e-02, 3.955660e-03, 1.438512e-02,
+                 3.955660e-03, -1.109170e-02, -2.022938e-02, -1.260420e-02, -1.777480e-03, 3.080980e-03, 2.325540e-03, -1.597040e-03],
+             [-8.006400e-04, -1.368800e-04, 3.761360e-03, 3.184680e-03, -3.530640e-03, -1.260420e-02, -1.884744e-02, -1.750818e-02, -
+                 1.648568e-02, -1.750818e-02, -1.884744e-02, -1.260420e-02, -3.530640e-03, 3.184680e-03, 3.761360e-03, -1.368800e-04, -8.006400e-04],
+             [-1.243400e-04, 5.621600e-04, 2.160540e-03, 3.175780e-03, 3.184680e-03, -1.777480e-03, -7.431700e-03, -9.056920e-03, -
+                 9.637220e-03, -9.056920e-03, -7.431700e-03, -1.777480e-03, 3.184680e-03, 3.175780e-03, 2.160540e-03, 5.621600e-04, -1.243400e-04],
+             [-6.771400e-04, -5.814600e-04, 1.460780e-03, 2.160540e-03, 3.761360e-03, 3.080980e-03, 4.112200e-03, 2.221220e-03, 5.538200e-04,
+                 2.221220e-03, 4.112200e-03, 3.080980e-03, 3.761360e-03, 2.160540e-03, 1.460780e-03, -5.814600e-04, -6.771400e-04],
+             [1.207800e-04, 4.460600e-04, -5.814600e-04, 5.621600e-04, -1.368800e-04, 2.325540e-03, 2.889860e-03, 4.287280e-03, 5.589400e-03,
+                 4.287280e-03, 2.889860e-03, 2.325540e-03, -1.368800e-04, 5.621600e-04, -5.814600e-04, 4.460600e-04, 1.207800e-04],
+             [-4.350000e-05, 1.207800e-04, -6.771400e-04, -1.243400e-04, -8.006400e-04, -1.597040e-03, -2.516800e-04, -4.202000e-04, 1.262000e-03, -4.202000e-04, -2.516800e-04, -1.597040e-03, -8.006400e-04, -1.243400e-04, -6.771400e-04, 1.207800e-04, -4.350000e-05]]
+        ).reshape(1, 1, 17, 17)
+        filters["l0"] = torch.tensor(
+            [[-8.701000e-05, -1.354280e-03, -1.601260e-03, -5.033700e-04, 2.524010e-03, -5.033700e-04, -1.601260e-03, -1.354280e-03, -8.701000e-05],
+             [-1.354280e-03, 2.921580e-03, 7.522720e-03, 8.224420e-03, 1.107620e-03,
+                 8.224420e-03, 7.522720e-03, 2.921580e-03, -1.354280e-03],
+             [-1.601260e-03, 7.522720e-03, -7.061290e-03, -3.769487e-02, -
+                 3.297137e-02, -3.769487e-02, -7.061290e-03, 7.522720e-03, -1.601260e-03],
+             [-5.033700e-04, 8.224420e-03, -3.769487e-02, 4.381320e-02, 1.811603e-01,
+                 4.381320e-02, -3.769487e-02, 8.224420e-03, -5.033700e-04],
+             [2.524010e-03, 1.107620e-03, -3.297137e-02, 1.811603e-01, 4.376250e-01,
+                 1.811603e-01, -3.297137e-02, 1.107620e-03, 2.524010e-03],
+             [-5.033700e-04, 8.224420e-03, -3.769487e-02, 4.381320e-02, 1.811603e-01,
+                 4.381320e-02, -3.769487e-02, 8.224420e-03, -5.033700e-04],
+             [-1.601260e-03, 7.522720e-03, -7.061290e-03, -3.769487e-02, -
+                 3.297137e-02, -3.769487e-02, -7.061290e-03, 7.522720e-03, -1.601260e-03],
+             [-1.354280e-03, 2.921580e-03, 7.522720e-03, 8.224420e-03, 1.107620e-03,
+                 8.224420e-03, 7.522720e-03, 2.921580e-03, -1.354280e-03],
+             [-8.701000e-05, -1.354280e-03, -1.601260e-03, -5.033700e-04, 2.524010e-03, -5.033700e-04, -1.601260e-03, -1.354280e-03, -8.701000e-05]]
+        ).reshape(1, 1, 9, 9)
+        filters["h0"] = torch.tensor(
+            [[-9.570000e-04, -2.424100e-04, -1.424720e-03, -8.742600e-04, -1.166810e-03, -8.742600e-04, -1.424720e-03, -2.424100e-04, -9.570000e-04],
+             [-2.424100e-04, -4.317530e-03, 8.998600e-04, 9.156420e-03, 1.098012e-02,
+                 9.156420e-03, 8.998600e-04, -4.317530e-03, -2.424100e-04],
+             [-1.424720e-03, 8.998600e-04, 1.706347e-02, 1.094866e-02, -
+                 5.897780e-03, 1.094866e-02, 1.706347e-02, 8.998600e-04, -1.424720e-03],
+             [-8.742600e-04, 9.156420e-03, 1.094866e-02, -7.841370e-02, -
+                 1.562827e-01, -7.841370e-02, 1.094866e-02, 9.156420e-03, -8.742600e-04],
+             [-1.166810e-03, 1.098012e-02, -5.897780e-03, -1.562827e-01,
+                 7.282593e-01, -1.562827e-01, -5.897780e-03, 1.098012e-02, -1.166810e-03],
+             [-8.742600e-04, 9.156420e-03, 1.094866e-02, -7.841370e-02, -
+                 1.562827e-01, -7.841370e-02, 1.094866e-02, 9.156420e-03, -8.742600e-04],
+             [-1.424720e-03, 8.998600e-04, 1.706347e-02, 1.094866e-02, -
+                 5.897780e-03, 1.094866e-02, 1.706347e-02, 8.998600e-04, -1.424720e-03],
+             [-2.424100e-04, -4.317530e-03, 8.998600e-04, 9.156420e-03, 1.098012e-02,
+                 9.156420e-03, 8.998600e-04, -4.317530e-03, -2.424100e-04],
+             [-9.570000e-04, -2.424100e-04, -1.424720e-03, -8.742600e-04, -1.166810e-03, -8.742600e-04, -1.424720e-03, -2.424100e-04, -9.570000e-04]]
+        ).reshape(1, 1, 9, 9)
+        filters["b"] = []
+        filters["b"].append(torch.tensor(
+            [6.125880e-03, -8.052600e-03, -2.103714e-02, -1.536890e-02, -1.851466e-02, -1.536890e-02, -2.103714e-02, -8.052600e-03, 6.125880e-03,
+             -1.287416e-02, -9.611520e-03, 1.023569e-02, 6.009450e-03, 1.872620e-03, 6.009450e-03, 1.023569e-02, -
+             9.611520e-03, -1.287416e-02,
+             -5.641530e-03, 4.168400e-03, -2.382180e-02, -5.375324e-02, -
+             2.076086e-02, -5.375324e-02, -2.382180e-02, 4.168400e-03, -5.641530e-03,
+             -8.957260e-03, -1.751170e-03, -1.836909e-02, 1.265655e-01, 2.996168e-01, 1.265655e-01, -
+             1.836909e-02, -1.751170e-03, -8.957260e-03,
+             0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
+             8.957260e-03, 1.751170e-03, 1.836909e-02, -1.265655e-01, -
+             2.996168e-01, -1.265655e-01, 1.836909e-02, 1.751170e-03, 8.957260e-03,
+             5.641530e-03, -4.168400e-03, 2.382180e-02, 5.375324e-02, 2.076086e-02, 5.375324e-02, 2.382180e-02, -
+             4.168400e-03, 5.641530e-03,
+             1.287416e-02, 9.611520e-03, -1.023569e-02, -6.009450e-03, -
+             1.872620e-03, -6.009450e-03, -1.023569e-02, 9.611520e-03, 1.287416e-02,
+             -6.125880e-03, 8.052600e-03, 2.103714e-02, 1.536890e-02, 1.851466e-02, 1.536890e-02, 2.103714e-02, 8.052600e-03, -6.125880e-03]).reshape(1, 1, 9, 9).permute(0, 1, 3, 2))
+        filters["b"].append(torch.tensor(
+            [-6.125880e-03, 1.287416e-02, 5.641530e-03, 8.957260e-03, 0.000000e+00, -8.957260e-03, -5.641530e-03, -1.287416e-02, 6.125880e-03,
+             8.052600e-03, 9.611520e-03, -4.168400e-03, 1.751170e-03, 0.000000e+00, -
+             1.751170e-03, 4.168400e-03, -9.611520e-03, -8.052600e-03,
+             2.103714e-02, -1.023569e-02, 2.382180e-02, 1.836909e-02, 0.000000e+00, -
+             1.836909e-02, -2.382180e-02, 1.023569e-02, -2.103714e-02,
+             1.536890e-02, -6.009450e-03, 5.375324e-02, -
+             1.265655e-01, 0.000000e+00, 1.265655e-01, -
+             5.375324e-02, 6.009450e-03, -1.536890e-02,
+             1.851466e-02, -1.872620e-03, 2.076086e-02, -
+             2.996168e-01, 0.000000e+00, 2.996168e-01, -
+             2.076086e-02, 1.872620e-03, -1.851466e-02,
+             1.536890e-02, -6.009450e-03, 5.375324e-02, -
+             1.265655e-01, 0.000000e+00, 1.265655e-01, -
+             5.375324e-02, 6.009450e-03, -1.536890e-02,
+             2.103714e-02, -1.023569e-02, 2.382180e-02, 1.836909e-02, 0.000000e+00, -
+             1.836909e-02, -2.382180e-02, 1.023569e-02, -2.103714e-02,
+             8.052600e-03, 9.611520e-03, -4.168400e-03, 1.751170e-03, 0.000000e+00, -
+             1.751170e-03, 4.168400e-03, -9.611520e-03, -8.052600e-03,
+             -6.125880e-03, 1.287416e-02, 5.641530e-03, 8.957260e-03, 0.000000e+00, -8.957260e-03, -5.641530e-03, -1.287416e-02, 6.125880e-03]).reshape(1, 1, 9, 9).permute(0, 1, 3, 2))
+
+    elif n_orientations == 4:
+        filters["l"] = torch.tensor([
+            [-4.3500000174E-5, 1.2078000145E-4, -6.7714002216E-4, -1.2434000382E-4, -8.0063997302E-4, -1.5970399836E-3, -2.5168000138E-4, -4.2019999819E-4,
+                1.2619999470E-3, -4.2019999819E-4, -2.5168000138E-4, -1.5970399836E-3, -8.0063997302E-4, -1.2434000382E-4, -6.7714002216E-4, 1.2078000145E-4, -4.3500000174E-5],
+            [1.2078000145E-4, 4.4606000301E-4, -5.8146001538E-4, 5.6215998484E-4, -1.3688000035E-4, 2.3255399428E-3, 2.8898599558E-3, 4.2872801423E-3, 5.5893999524E-3,
+                4.2872801423E-3, 2.8898599558E-3, 2.3255399428E-3, -1.3688000035E-4, 5.6215998484E-4, -5.8146001538E-4, 4.4606000301E-4, 1.2078000145E-4],
+            [-6.7714002216E-4, -5.8146001538E-4, 1.4607800404E-3, 2.1605400834E-3, 3.7613599561E-3, 3.0809799209E-3, 4.1121998802E-3, 2.2212199401E-3, 5.5381999118E-4,
+                2.2212199401E-3, 4.1121998802E-3, 3.0809799209E-3, 3.7613599561E-3, 2.1605400834E-3, 1.4607800404E-3, -5.8146001538E-4, -6.7714002216E-4],
+            [-1.2434000382E-4, 5.6215998484E-4, 2.1605400834E-3, 3.1757799443E-3, 3.1846798956E-3, -1.7774800071E-3, -7.4316998944E-3, -9.0569201857E-3, -
+                9.6372198313E-3, -9.0569201857E-3, -7.4316998944E-3, -1.7774800071E-3, 3.1846798956E-3, 3.1757799443E-3, 2.1605400834E-3, 5.6215998484E-4, -1.2434000382E-4],
+            [-8.0063997302E-4, -1.3688000035E-4, 3.7613599561E-3, 3.1846798956E-3, -3.5306399222E-3, -1.2604200281E-2, -1.8847439438E-2, -1.7508180812E-2, -
+                1.6485679895E-2, -1.7508180812E-2, -1.8847439438E-2, -1.2604200281E-2, -3.5306399222E-3, 3.1846798956E-3, 3.7613599561E-3, -1.3688000035E-4, -8.0063997302E-4],
+            [-1.5970399836E-3, 2.3255399428E-3, 3.0809799209E-3, -1.7774800071E-3, -1.2604200281E-2, -2.0229380578E-2, -1.1091699824E-2, 3.9556599222E-3, 1.4385120012E-2,
+                3.9556599222E-3, -1.1091699824E-2, -2.0229380578E-2, -1.2604200281E-2, -1.7774800071E-3, 3.0809799209E-3, 2.3255399428E-3, -1.5970399836E-3],
+            [-2.5168000138E-4, 2.8898599558E-3, 4.1121998802E-3, -7.4316998944E-3, -1.8847439438E-2, -1.1091699824E-2, 2.1906599402E-2, 6.8065837026E-2, 9.0580143034E-2,
+                6.8065837026E-2, 2.1906599402E-2, -1.1091699824E-2, -1.8847439438E-2, -7.4316998944E-3, 4.1121998802E-3, 2.8898599558E-3, -2.5168000138E-4],
+            [-4.2019999819E-4, 4.2872801423E-3, 2.2212199401E-3, -9.0569201857E-3, -1.7508180812E-2, 3.9556599222E-3, 6.8065837026E-2, 0.1445499808, 0.1773651242,
+                0.1445499808, 6.8065837026E-2, 3.9556599222E-3, -1.7508180812E-2, -9.0569201857E-3, 2.2212199401E-3, 4.2872801423E-3, -4.2019999819E-4],
+            [1.2619999470E-3, 5.5893999524E-3, 5.5381999118E-4, -9.6372198313E-3, -1.6485679895E-2, 1.4385120012E-2, 9.0580143034E-2, 0.1773651242, 0.2120374441,
+                0.1773651242, 9.0580143034E-2, 1.4385120012E-2, -1.6485679895E-2, -9.6372198313E-3, 5.5381999118E-4, 5.5893999524E-3, 1.2619999470E-3],
+            [-4.2019999819E-4, 4.2872801423E-3, 2.2212199401E-3, -9.0569201857E-3, -1.7508180812E-2, 3.9556599222E-3, 6.8065837026E-2, 0.1445499808, 0.1773651242,
+                0.1445499808, 6.8065837026E-2, 3.9556599222E-3, -1.7508180812E-2, -9.0569201857E-3, 2.2212199401E-3, 4.2872801423E-3, -4.2019999819E-4],
+            [-2.5168000138E-4, 2.8898599558E-3, 4.1121998802E-3, -7.4316998944E-3, -1.8847439438E-2, -1.1091699824E-2, 2.1906599402E-2, 6.8065837026E-2, 9.0580143034E-2,
+                6.8065837026E-2, 2.1906599402E-2, -1.1091699824E-2, -1.8847439438E-2, -7.4316998944E-3, 4.1121998802E-3, 2.8898599558E-3, -2.5168000138E-4],
+            [-1.5970399836E-3, 2.3255399428E-3, 3.0809799209E-3, -1.7774800071E-3, -1.2604200281E-2, -2.0229380578E-2, -1.1091699824E-2, 3.9556599222E-3, 1.4385120012E-2,
+                3.9556599222E-3, -1.1091699824E-2, -2.0229380578E-2, -1.2604200281E-2, -1.7774800071E-3, 3.0809799209E-3, 2.3255399428E-3, -1.5970399836E-3],
+            [-8.0063997302E-4, -1.3688000035E-4, 3.7613599561E-3, 3.1846798956E-3, -3.5306399222E-3, -1.2604200281E-2, -1.8847439438E-2, -1.7508180812E-2, -
+                1.6485679895E-2, -1.7508180812E-2, -1.8847439438E-2, -1.2604200281E-2, -3.5306399222E-3, 3.1846798956E-3, 3.7613599561E-3, -1.3688000035E-4, -8.0063997302E-4],
+            [-1.2434000382E-4, 5.6215998484E-4, 2.1605400834E-3, 3.1757799443E-3, 3.1846798956E-3, -1.7774800071E-3, -7.4316998944E-3, -9.0569201857E-3, -
+                9.6372198313E-3, -9.0569201857E-3, -7.4316998944E-3, -1.7774800071E-3, 3.1846798956E-3, 3.1757799443E-3, 2.1605400834E-3, 5.6215998484E-4, -1.2434000382E-4],
+            [-6.7714002216E-4, -5.8146001538E-4, 1.4607800404E-3, 2.1605400834E-3, 3.7613599561E-3, 3.0809799209E-3, 4.1121998802E-3, 2.2212199401E-3, 5.5381999118E-4,
+                2.2212199401E-3, 4.1121998802E-3, 3.0809799209E-3, 3.7613599561E-3, 2.1605400834E-3, 1.4607800404E-3, -5.8146001538E-4, -6.7714002216E-4],
+            [1.2078000145E-4, 4.4606000301E-4, -5.8146001538E-4, 5.6215998484E-4, -1.3688000035E-4, 2.3255399428E-3, 2.8898599558E-3, 4.2872801423E-3, 5.5893999524E-3,
+                4.2872801423E-3, 2.8898599558E-3, 2.3255399428E-3, -1.3688000035E-4, 5.6215998484E-4, -5.8146001538E-4, 4.4606000301E-4, 1.2078000145E-4],
+            [-4.3500000174E-5, 1.2078000145E-4, -6.7714002216E-4, -1.2434000382E-4, -8.0063997302E-4, -1.5970399836E-3, -2.5168000138E-4, -4.2019999819E-4, 1.2619999470E-3, -4.2019999819E-4, -2.5168000138E-4, -1.5970399836E-3, -8.0063997302E-4, -1.2434000382E-4, -6.7714002216E-4, 1.2078000145E-4, -4.3500000174E-5]]
+        ).reshape(1, 1, 17, 17)
+        filters["l0"] = torch.tensor([
+            [-8.7009997515E-5, -1.3542800443E-3, -1.6012600390E-3, -5.0337001448E-4,
+                2.5240099058E-3, -5.0337001448E-4, -1.6012600390E-3, -1.3542800443E-3, -8.7009997515E-5],
+            [-1.3542800443E-3, 2.9215801042E-3, 7.5227199122E-3, 8.2244202495E-3, 1.1076199589E-3,
+                8.2244202495E-3, 7.5227199122E-3, 2.9215801042E-3, -1.3542800443E-3],
+            [-1.6012600390E-3, 7.5227199122E-3, -7.0612900890E-3, -3.7694871426E-2, -
+                3.2971370965E-2, -3.7694871426E-2, -7.0612900890E-3, 7.5227199122E-3, -1.6012600390E-3],
+            [-5.0337001448E-4, 8.2244202495E-3, -3.7694871426E-2, 4.3813198805E-2, 0.1811603010,
+                4.3813198805E-2, -3.7694871426E-2, 8.2244202495E-3, -5.0337001448E-4],
+            [2.5240099058E-3, 1.1076199589E-3, -3.2971370965E-2, 0.1811603010, 0.4376249909,
+                0.1811603010, -3.2971370965E-2, 1.1076199589E-3, 2.5240099058E-3],
+            [-5.0337001448E-4, 8.2244202495E-3, -3.7694871426E-2, 4.3813198805E-2, 0.1811603010,
+                4.3813198805E-2, -3.7694871426E-2, 8.2244202495E-3, -5.0337001448E-4],
+            [-1.6012600390E-3, 7.5227199122E-3, -7.0612900890E-3, -3.7694871426E-2, -
+                3.2971370965E-2, -3.7694871426E-2, -7.0612900890E-3, 7.5227199122E-3, -1.6012600390E-3],
+            [-1.3542800443E-3, 2.9215801042E-3, 7.5227199122E-3, 8.2244202495E-3, 1.1076199589E-3,
+                8.2244202495E-3, 7.5227199122E-3, 2.9215801042E-3, -1.3542800443E-3],
+            [-8.7009997515E-5, -1.3542800443E-3, -1.6012600390E-3, -5.0337001448E-4, 2.5240099058E-3, -5.0337001448E-4, -1.6012600390E-3, -1.3542800443E-3, -8.7009997515E-5]]
+        ).reshape(1, 1, 9, 9)
+        filters["h0"] = torch.tensor([
+            [-4.0483998600E-4, -6.2596000498E-4, -3.7829999201E-5, 8.8387000142E-4, 1.5450799838E-3, 1.9235999789E-3, 2.0687500946E-3, 2.0898699295E-3,
+                2.0687500946E-3, 1.9235999789E-3, 1.5450799838E-3, 8.8387000142E-4, -3.7829999201E-5, -6.2596000498E-4, -4.0483998600E-4],
+            [-6.2596000498E-4, -3.2734998967E-4, 7.7435001731E-4, 1.5874400269E-3, 2.1750701126E-3, 2.5626500137E-3, 2.2892199922E-3, 1.9755100366E-3,
+                2.2892199922E-3, 2.5626500137E-3, 2.1750701126E-3, 1.5874400269E-3, 7.7435001731E-4, -3.2734998967E-4, -6.2596000498E-4],
+            [-3.7829999201E-5, 7.7435001731E-4, 1.1793200392E-3, 1.4050999889E-3, 2.2253401112E-3, 2.1145299543E-3, 3.3578000148E-4, -
+                8.3368999185E-4, 3.3578000148E-4, 2.1145299543E-3, 2.2253401112E-3, 1.4050999889E-3, 1.1793200392E-3, 7.7435001731E-4, -3.7829999201E-5],
+            [8.8387000142E-4, 1.5874400269E-3, 1.4050999889E-3, 1.2960999738E-3, -4.9274001503E-4, -3.1295299996E-3, -4.5751798898E-3, -
+                5.1014497876E-3, -4.5751798898E-3, -3.1295299996E-3, -4.9274001503E-4, 1.2960999738E-3, 1.4050999889E-3, 1.5874400269E-3, 8.8387000142E-4],
+            [1.5450799838E-3, 2.1750701126E-3, 2.2253401112E-3, -4.9274001503E-4, -6.3222697936E-3, -2.7556000277E-3, 5.3632198833E-3, 7.3032598011E-3,
+                5.3632198833E-3, -2.7556000277E-3, -6.3222697936E-3, -4.9274001503E-4, 2.2253401112E-3, 2.1750701126E-3, 1.5450799838E-3],
+            [1.9235999789E-3, 2.5626500137E-3, 2.1145299543E-3, -3.1295299996E-3, -2.7556000277E-3, 1.3962360099E-2, 7.8046298586E-3, -
+                9.3812197447E-3, 7.8046298586E-3, 1.3962360099E-2, -2.7556000277E-3, -3.1295299996E-3, 2.1145299543E-3, 2.5626500137E-3, 1.9235999789E-3],
+            [2.0687500946E-3, 2.2892199922E-3, 3.3578000148E-4, -4.5751798898E-3, 5.3632198833E-3, 7.8046298586E-3, -7.9501636326E-2, -
+                0.1554141641, -7.9501636326E-2, 7.8046298586E-3, 5.3632198833E-3, -4.5751798898E-3, 3.3578000148E-4, 2.2892199922E-3, 2.0687500946E-3],
+            [2.0898699295E-3, 1.9755100366E-3, -8.3368999185E-4, -5.1014497876E-3, 7.3032598011E-3, -9.3812197447E-3, -0.1554141641,
+                0.7303866148, -0.1554141641, -9.3812197447E-3, 7.3032598011E-3, -5.1014497876E-3, -8.3368999185E-4, 1.9755100366E-3, 2.0898699295E-3],
+            [2.0687500946E-3, 2.2892199922E-3, 3.3578000148E-4, -4.5751798898E-3, 5.3632198833E-3, 7.8046298586E-3, -7.9501636326E-2, -
+                0.1554141641, -7.9501636326E-2, 7.8046298586E-3, 5.3632198833E-3, -4.5751798898E-3, 3.3578000148E-4, 2.2892199922E-3, 2.0687500946E-3],
+            [1.9235999789E-3, 2.5626500137E-3, 2.1145299543E-3, -3.1295299996E-3, -2.7556000277E-3, 1.3962360099E-2, 7.8046298586E-3, -
+                9.3812197447E-3, 7.8046298586E-3, 1.3962360099E-2, -2.7556000277E-3, -3.1295299996E-3, 2.1145299543E-3, 2.5626500137E-3, 1.9235999789E-3],
+            [1.5450799838E-3, 2.1750701126E-3, 2.2253401112E-3, -4.9274001503E-4, -6.3222697936E-3, -2.7556000277E-3, 5.3632198833E-3, 7.3032598011E-3,
+                5.3632198833E-3, -2.7556000277E-3, -6.3222697936E-3, -4.9274001503E-4, 2.2253401112E-3, 2.1750701126E-3, 1.5450799838E-3],
+            [8.8387000142E-4, 1.5874400269E-3, 1.4050999889E-3, 1.2960999738E-3, -4.9274001503E-4, -3.1295299996E-3, -4.5751798898E-3, -
+                5.1014497876E-3, -4.5751798898E-3, -3.1295299996E-3, -4.9274001503E-4, 1.2960999738E-3, 1.4050999889E-3, 1.5874400269E-3, 8.8387000142E-4],
+            [-3.7829999201E-5, 7.7435001731E-4, 1.1793200392E-3, 1.4050999889E-3, 2.2253401112E-3, 2.1145299543E-3, 3.3578000148E-4, -
+                8.3368999185E-4, 3.3578000148E-4, 2.1145299543E-3, 2.2253401112E-3, 1.4050999889E-3, 1.1793200392E-3, 7.7435001731E-4, -3.7829999201E-5],
+            [-6.2596000498E-4, -3.2734998967E-4, 7.7435001731E-4, 1.5874400269E-3, 2.1750701126E-3, 2.5626500137E-3, 2.2892199922E-3, 1.9755100366E-3,
+                2.2892199922E-3, 2.5626500137E-3, 2.1750701126E-3, 1.5874400269E-3, 7.7435001731E-4, -3.2734998967E-4, -6.2596000498E-4],
+            [-4.0483998600E-4, -6.2596000498E-4, -3.7829999201E-5, 8.8387000142E-4, 1.5450799838E-3, 1.9235999789E-3, 2.0687500946E-3, 2.0898699295E-3, 2.0687500946E-3, 1.9235999789E-3, 1.5450799838E-3, 8.8387000142E-4, -3.7829999201E-5, -6.2596000498E-4, -4.0483998600E-4]]
+        ).reshape(1, 1, 15, 15)
+        filters["b"] = []
+        filters["b"].append(torch.tensor(
+            [-8.1125000725E-4, 4.4451598078E-3, 1.2316980399E-2, 1.3955879956E-2,  1.4179450460E-2, 1.3955879956E-2, 1.2316980399E-2, 4.4451598078E-3, -8.1125000725E-4,
+             3.9103501476E-3, 4.4565401040E-3, -5.8724298142E-3, -2.8760801069E-3, 8.5267601535E-3, -
+             2.8760801069E-3, -5.8724298142E-3, 4.4565401040E-3, 3.9103501476E-3,
+             1.3462699717E-3, -3.7740699481E-3, 8.2581602037E-3, 3.9442278445E-2, 5.3605638444E-2, 3.9442278445E-2, 8.2581602037E-3, -
+             3.7740699481E-3, 1.3462699717E-3,
+             7.4700999539E-4, -3.6522001028E-4, -2.2522680461E-2, -0.1105690673, -
+             0.1768419296, -0.1105690673, -2.2522680461E-2, -3.6522001028E-4, 7.4700999539E-4,
+             0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000,
+             -7.4700999539E-4, 3.6522001028E-4, 2.2522680461E-2, 0.1105690673, 0.1768419296, 0.1105690673, 2.2522680461E-2, 3.6522001028E-4, -7.4700999539E-4,
+             -1.3462699717E-3, 3.7740699481E-3, -8.2581602037E-3, -3.9442278445E-2, -
+             5.3605638444E-2, -3.9442278445E-2, -
+             8.2581602037E-3, 3.7740699481E-3, -1.3462699717E-3,
+             -3.9103501476E-3, -4.4565401040E-3, 5.8724298142E-3, 2.8760801069E-3, -
+             8.5267601535E-3, 2.8760801069E-3, 5.8724298142E-3, -
+             4.4565401040E-3, -3.9103501476E-3,
+             8.1125000725E-4, -4.4451598078E-3, -1.2316980399E-2, -1.3955879956E-2, -1.4179450460E-2, -1.3955879956E-2, -1.2316980399E-2, -4.4451598078E-3, 8.1125000725E-4]
+        ).reshape(1, 1, 9, 9).permute(0, 1, 3, 2))
+        filters["b"].append(torch.tensor(
+            [0.0000000000, -8.2846998703E-4, -5.7109999034E-5, 4.0110000555E-5, 4.6670897864E-3, 8.0871898681E-3, 1.4807609841E-2, 8.6204400286E-3, -3.1221499667E-3,
+             8.2846998703E-4, 0.0000000000, -9.7479997203E-4, -6.9718998857E-3, -
+             2.0865600090E-3, 2.3298799060E-3, -
+             4.4814897701E-3, 1.4917500317E-2, 8.6204400286E-3,
+             5.7109999034E-5, 9.7479997203E-4, 0.0000000000, -1.2145539746E-2, -
+             2.4427289143E-2, 5.0797060132E-2, 3.2785870135E-2, -
+             4.4814897701E-3, 1.4807609841E-2,
+             -4.0110000555E-5, 6.9718998857E-3, 1.2145539746E-2, 0.0000000000, -
+             0.1510555595, -8.2495503128E-2, 5.0797060132E-2, 2.3298799060E-3, 8.0871898681E-3,
+             -4.6670897864E-3, 2.0865600090E-3, 2.4427289143E-2, 0.1510555595, 0.0000000000, -
+             0.1510555595, -2.4427289143E-2, -2.0865600090E-3, 4.6670897864E-3,
+             -8.0871898681E-3, -2.3298799060E-3, -5.0797060132E-2, 8.2495503128E-2, 0.1510555595, 0.0000000000, -
+             1.2145539746E-2, -6.9718998857E-3, 4.0110000555E-5,
+             -1.4807609841E-2, 4.4814897701E-3, -3.2785870135E-2, -
+             5.0797060132E-2, 2.4427289143E-2, 1.2145539746E-2, 0.0000000000, -
+             9.7479997203E-4, -5.7109999034E-5,
+             -8.6204400286E-3, -1.4917500317E-2, 4.4814897701E-3, -
+             2.3298799060E-3, 2.0865600090E-3, 6.9718998857E-3, 9.7479997203E-4, 0.0000000000, -8.2846998703E-4,
+             3.1221499667E-3, -8.6204400286E-3, -1.4807609841E-2, -8.0871898681E-3, -4.6670897864E-3, -4.0110000555E-5, 5.7109999034E-5, 8.2846998703E-4, 0.0000000000]
+        ).reshape(1, 1, 9, 9).permute(0, 1, 3, 2))
+        filters["b"].append(torch.tensor(
+            [8.1125000725E-4, -3.9103501476E-3, -1.3462699717E-3, -7.4700999539E-4, 0.0000000000, 7.4700999539E-4, 1.3462699717E-3, 3.9103501476E-3, -8.1125000725E-4,
+             -4.4451598078E-3, -4.4565401040E-3, 3.7740699481E-3, 3.6522001028E-4, 0.0000000000, -
+             3.6522001028E-4, -3.7740699481E-3, 4.4565401040E-3, 4.4451598078E-3,
+             -1.2316980399E-2, 5.8724298142E-3, -8.2581602037E-3, 2.2522680461E-2, 0.0000000000, -
+             2.2522680461E-2, 8.2581602037E-3, -5.8724298142E-3, 1.2316980399E-2,
+             -1.3955879956E-2, 2.8760801069E-3, -3.9442278445E-2, 0.1105690673, 0.0000000000, -
+             0.1105690673, 3.9442278445E-2, -2.8760801069E-3, 1.3955879956E-2,
+             -1.4179450460E-2, -8.5267601535E-3, -5.3605638444E-2, 0.1768419296, 0.0000000000, -
+             0.1768419296, 5.3605638444E-2, 8.5267601535E-3, 1.4179450460E-2,
+             -1.3955879956E-2, 2.8760801069E-3, -3.9442278445E-2, 0.1105690673, 0.0000000000, -
+             0.1105690673, 3.9442278445E-2, -2.8760801069E-3, 1.3955879956E-2,
+             -1.2316980399E-2, 5.8724298142E-3, -8.2581602037E-3, 2.2522680461E-2, 0.0000000000, -
+             2.2522680461E-2, 8.2581602037E-3, -5.8724298142E-3, 1.2316980399E-2,
+             -4.4451598078E-3, -4.4565401040E-3, 3.7740699481E-3, 3.6522001028E-4, 0.0000000000, -
+             3.6522001028E-4, -3.7740699481E-3, 4.4565401040E-3, 4.4451598078E-3,
+             8.1125000725E-4, -3.9103501476E-3, -1.3462699717E-3, -7.4700999539E-4, 0.0000000000, 7.4700999539E-4, 1.3462699717E-3, 3.9103501476E-3, -8.1125000725E-4]
+        ).reshape(1, 1, 9, 9).permute(0, 1, 3, 2))
+        filters["b"].append(torch.tensor(
+            [3.1221499667E-3, -8.6204400286E-3, -1.4807609841E-2, -8.0871898681E-3, -4.6670897864E-3, -4.0110000555E-5, 5.7109999034E-5, 8.2846998703E-4, 0.0000000000,
+             -8.6204400286E-3, -1.4917500317E-2, 4.4814897701E-3, -
+             2.3298799060E-3, 2.0865600090E-3, 6.9718998857E-3, 9.7479997203E-4, -
+             0.0000000000, -8.2846998703E-4,
+             -1.4807609841E-2, 4.4814897701E-3, -3.2785870135E-2, -
+             5.0797060132E-2, 2.4427289143E-2, 1.2145539746E-2, 0.0000000000, -
+             9.7479997203E-4, -5.7109999034E-5,
+             -8.0871898681E-3, -2.3298799060E-3, -5.0797060132E-2, 8.2495503128E-2, 0.1510555595, -
+             0.0000000000, -1.2145539746E-2, -6.9718998857E-3, 4.0110000555E-5,
+             -4.6670897864E-3, 2.0865600090E-3, 2.4427289143E-2, 0.1510555595, 0.0000000000, -
+             0.1510555595, -2.4427289143E-2, -2.0865600090E-3, 4.6670897864E-3,
+             -4.0110000555E-5, 6.9718998857E-3, 1.2145539746E-2, 0.0000000000, -
+             0.1510555595, -8.2495503128E-2, 5.0797060132E-2, 2.3298799060E-3, 8.0871898681E-3,
+             5.7109999034E-5, 9.7479997203E-4, -0.0000000000, -1.2145539746E-2, -
+             2.4427289143E-2, 5.0797060132E-2, 3.2785870135E-2, -
+             4.4814897701E-3, 1.4807609841E-2,
+             8.2846998703E-4, -0.0000000000, -9.7479997203E-4, -6.9718998857E-3, -
+             2.0865600090E-3, 2.3298799060E-3, -
+             4.4814897701E-3, 1.4917500317E-2, 8.6204400286E-3,
+             0.0000000000, -8.2846998703E-4, -5.7109999034E-5, 4.0110000555E-5, 4.6670897864E-3, 8.0871898681E-3, 1.4807609841E-2, 8.6204400286E-3, -3.1221499667E-3]
+        ).reshape(1, 1, 9, 9).permute(0, 1, 3, 2))
+
+    elif n_orientations == 6:
+        filters["l"] = 2 * torch.tensor([
+            [0.00085404, -0.00244917, -0.00387812, -0.00944432, -
+                0.00962054, -0.00944432, -0.00387812, -0.00244917, 0.00085404],
+            [-0.00244917, -0.00523281, -0.00661117, 0.00410600, 0.01002988,
+                0.00410600, -0.00661117, -0.00523281, -0.00244917],
+            [-0.00387812, -0.00661117, 0.01396746, 0.03277038, 0.03981393,
+                0.03277038, 0.01396746, -0.00661117, -0.00387812],
+            [-0.00944432, 0.00410600, 0.03277038, 0.06426333, 0.08169618,
+                0.06426333, 0.03277038, 0.00410600, -0.00944432],
+            [-0.00962054, 0.01002988, 0.03981393, 0.08169618, 0.10096540,
+                0.08169618, 0.03981393, 0.01002988, -0.00962054],
+            [-0.00944432, 0.00410600, 0.03277038, 0.06426333, 0.08169618,
+                0.06426333, 0.03277038, 0.00410600, -0.00944432],
+            [-0.00387812, -0.00661117, 0.01396746, 0.03277038, 0.03981393,
+                0.03277038, 0.01396746, -0.00661117, -0.00387812],
+            [-0.00244917, -0.00523281, -0.00661117, 0.00410600, 0.01002988,
+                0.00410600, -0.00661117, -0.00523281, -0.00244917],
+            [0.00085404, -0.00244917, -0.00387812, -0.00944432, -0.00962054, -0.00944432, -0.00387812, -0.00244917, 0.00085404]]
+        ).reshape(1, 1, 9, 9)
+        filters["l0"] = torch.tensor([
+            [0.00341614, -0.01551246, -0.03848215, -0.01551246, 0.00341614],
+            [-0.01551246, 0.05586982, 0.15925570, 0.05586982, -0.01551246],
+            [-0.03848215, 0.15925570, 0.40304148, 0.15925570, -0.03848215],
+            [-0.01551246, 0.05586982, 0.15925570, 0.05586982, -0.01551246],
+            [0.00341614, -0.01551246, -0.03848215, -0.01551246, 0.00341614]]
+        ).reshape(1, 1, 5, 5)
+        filters["h0"] = torch.tensor([
+            [-0.00033429, -0.00113093, -0.00171484, -0.00133542, -
+                0.00080639, -0.00133542, -0.00171484, -0.00113093, -0.00033429],
+            [-0.00113093, -0.00350017, -0.00243812, 0.00631653, 0.01261227,
+                0.00631653, -0.00243812, -0.00350017, -0.00113093],
+            [-0.00171484, -0.00243812, -0.00290081, -0.00673482, -
+                0.00981051, -0.00673482, -0.00290081, -0.00243812, -0.00171484],
+            [-0.00133542, 0.00631653, -0.00673482, -0.07027679, -
+                0.11435863, -0.07027679, -0.00673482, 0.00631653, -0.00133542],
+            [-0.00080639, 0.01261227, -0.00981051, -0.11435863,
+                0.81380200, -0.11435863, -0.00981051, 0.01261227, -0.00080639],
+            [-0.00133542, 0.00631653, -0.00673482, -0.07027679, -
+                0.11435863, -0.07027679, -0.00673482, 0.00631653, -0.00133542],
+            [-0.00171484, -0.00243812, -0.00290081, -0.00673482, -
+                0.00981051, -0.00673482, -0.00290081, -0.00243812, -0.00171484],
+            [-0.00113093, -0.00350017, -0.00243812, 0.00631653, 0.01261227,
+                0.00631653, -0.00243812, -0.00350017, -0.00113093],
+            [-0.00033429, -0.00113093, -0.00171484, -0.00133542, -0.00080639, -0.00133542, -0.00171484, -0.00113093, -0.00033429]]
+        ).reshape(1, 1, 9, 9)
+        filters["b"] = []
+        filters["b"].append(torch.tensor([
+            0.00277643, 0.00496194, 0.01026699, 0.01455399, 0.01026699, 0.00496194, 0.00277643,
+            -0.00986904, -0.00893064, 0.01189859, 0.02755155, 0.01189859, -0.00893064, -0.00986904,
+            -0.01021852, -0.03075356, -0.08226445, -
+            0.11732297, -0.08226445, -0.03075356, -0.01021852,
+            0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
+            0.01021852, 0.03075356, 0.08226445, 0.11732297, 0.08226445, 0.03075356, 0.01021852,
+            0.00986904, 0.00893064, -0.01189859, -
+            0.02755155, -0.01189859, 0.00893064, 0.00986904,
+            -0.00277643, -0.00496194, -0.01026699, -0.01455399, -0.01026699, -0.00496194, -0.00277643]
+        ).reshape(1, 1, 7, 7).permute(0, 1, 3, 2))
+        filters["b"].append(torch.tensor([
+            -0.00343249, -0.00640815, -0.00073141, 0.01124321, 0.00182078, 0.00285723, 0.01166982,
+            -0.00358461, -0.01977507, -0.04084211, -
+            0.00228219, 0.03930573, 0.01161195, 0.00128000,
+            0.01047717, 0.01486305, -0.04819057, -
+            0.12227230, -0.05394139, 0.00853965, -0.00459034,
+            0.00790407, 0.04435647, 0.09454202, -0.00000000, -
+            0.09454202, -0.04435647, -0.00790407,
+            0.00459034, -0.00853965, 0.05394139, 0.12227230, 0.04819057, -0.01486305, -0.01047717,
+            -0.00128000, -0.01161195, -0.03930573, 0.00228219, 0.04084211, 0.01977507, 0.00358461,
+            -0.01166982, -0.00285723, -0.00182078, -0.01124321, 0.00073141, 0.00640815, 0.00343249]
+        ).reshape(1, 1, 7, 7).permute(0, 1, 3, 2))
+        filters["b"].append(torch.tensor([
+            0.00343249, 0.00358461, -0.01047717, -
+            0.00790407, -0.00459034, 0.00128000, 0.01166982,
+            0.00640815, 0.01977507, -0.01486305, -
+            0.04435647, 0.00853965, 0.01161195, 0.00285723,
+            0.00073141, 0.04084211, 0.04819057, -
+            0.09454202, -0.05394139, 0.03930573, 0.00182078,
+            -0.01124321, 0.00228219, 0.12227230, -
+            0.00000000, -0.12227230, -0.00228219, 0.01124321,
+            -0.00182078, -0.03930573, 0.05394139, 0.09454202, -
+            0.04819057, -0.04084211, -0.00073141,
+            -0.00285723, -0.01161195, -0.00853965, 0.04435647, 0.01486305, -0.01977507, -0.00640815,
+            -0.01166982, -0.00128000, 0.00459034, 0.00790407, 0.01047717, -0.00358461, -0.00343249]
+        ).reshape(1, 1, 7, 7).permute(0, 1, 3, 2))
+        filters["b"].append(torch.tensor(
+            [-0.00277643, 0.00986904, 0.01021852, -0.00000000, -0.01021852, -0.00986904, 0.00277643,
+             -0.00496194, 0.00893064, 0.03075356, -
+             0.00000000, -0.03075356, -0.00893064, 0.00496194,
+             -0.01026699, -0.01189859, 0.08226445, -
+             0.00000000, -0.08226445, 0.01189859, 0.01026699,
+             -0.01455399, -0.02755155, 0.11732297, -
+             0.00000000, -0.11732297, 0.02755155, 0.01455399,
+             -0.01026699, -0.01189859, 0.08226445, -
+             0.00000000, -0.08226445, 0.01189859, 0.01026699,
+             -0.00496194, 0.00893064, 0.03075356, -
+             0.00000000, -0.03075356, -0.00893064, 0.00496194,
+             -0.00277643, 0.00986904, 0.01021852, -0.00000000, -0.01021852, -0.00986904, 0.00277643]
+        ).reshape(1, 1, 7, 7).permute(0, 1, 3, 2))
+        filters["b"].append(torch.tensor([
+            -0.01166982, -0.00128000, 0.00459034, 0.00790407, 0.01047717, -0.00358461, -0.00343249,
+            -0.00285723, -0.01161195, -0.00853965, 0.04435647, 0.01486305, -0.01977507, -0.00640815,
+            -0.00182078, -0.03930573, 0.05394139, 0.09454202, -
+            0.04819057, -0.04084211, -0.00073141,
+            -0.01124321, 0.00228219, 0.12227230, -
+            0.00000000, -0.12227230, -0.00228219, 0.01124321,
+            0.00073141, 0.04084211, 0.04819057, -
+            0.09454202, -0.05394139, 0.03930573, 0.00182078,
+            0.00640815, 0.01977507, -0.01486305, -
+            0.04435647, 0.00853965, 0.01161195, 0.00285723,
+            0.00343249, 0.00358461, -0.01047717, -0.00790407, -0.00459034, 0.00128000, 0.01166982]
+        ).reshape(1, 1, 7, 7).permute(0, 1, 3, 2))
+        filters["b"].append(torch.tensor([
+            -0.01166982, -0.00285723, -0.00182078, -
+            0.01124321, 0.00073141, 0.00640815, 0.00343249,
+            -0.00128000, -0.01161195, -0.03930573, 0.00228219, 0.04084211, 0.01977507, 0.00358461,
+            0.00459034, -0.00853965, 0.05394139, 0.12227230, 0.04819057, -0.01486305, -0.01047717,
+            0.00790407, 0.04435647, 0.09454202, -0.00000000, -
+            0.09454202, -0.04435647, -0.00790407,
+            0.01047717, 0.01486305, -0.04819057, -
+            0.12227230, -0.05394139, 0.00853965, -0.00459034,
+            -0.00358461, -0.01977507, -0.04084211, -
+            0.00228219, 0.03930573, 0.01161195, 0.00128000,
+            -0.00343249, -0.00640815, -0.00073141, 0.01124321, 0.00182078, 0.00285723, 0.01166982]
+        ).reshape(1, 1, 7, 7).permute(0, 1, 3, 2))
+
+    else:
+        raise Exception(
+            "Steerable filters not implemented for %d orientations" % n_orientations)
+
+    if filter_type == "trained":
+        if size == 5:
+            # TODO maybe also train h0 and l0 filters
+            filters = crop_steerable_pyramid_filters(filters, 5)
+            filters["b"][0] = torch.tensor([
+                [-0.0356752239, -0.0223877281, -0.0009542659,
+                    0.0244821459, 0.0322226137],
+                [-0.0593218654,  0.1245803162, -
+                    0.0023863907, -0.1230178699, 0.0589442067],
+                [-0.0281576272,  0.2976626456, -
+                    0.0020888755, -0.2953369915, 0.0284542721],
+                [-0.0586092323,  0.1251581162, -
+                    0.0024624448, -0.1227868199, 0.0587830991],
+                [-0.0327464789, -0.0223652460, -
+                    0.0042342511,  0.0245472137, 0.0359398536]
+            ]).reshape(1, 1, 5, 5)
+            filters["b"][1] = torch.tensor([
+                [3.9758663625e-02,  6.0679119080e-02,  3.0146904290e-02,
+                    6.1198268086e-02,  3.6218870431e-02],
+                [2.3255519569e-02, -1.2505133450e-01, -
+                    2.9738345742e-01, -1.2518258393e-01,  2.3592948914e-02],
+                [-1.3602430699e-03, -1.2058277935e-04,  2.6399988565e-04, -
+                    2.3791544663e-04,  1.8450465286e-03],
+                [-2.1563466638e-02,  1.2572696805e-01,  2.9745018482e-01,
+                    1.2458638102e-01, -2.3847281933e-02],
+                [-3.7941932678e-02, -6.1060950160e-02, -
+                    2.9489086941e-02, -6.0411967337e-02, -3.8459088653e-02]
+            ]).reshape(1, 1, 5, 5)
+
+            # Below filters were optimised on 09/02/2021
+            # 20K iterations with multiple images at more scales.
+            filters["b"][0] = torch.tensor([
+                [-4.5508436859e-02, -2.1767273545e-02, -1.9399923622e-04,
+                    2.1200872958e-02,  4.5475799590e-02],
+                [-6.3554823399e-02,  1.2832683325e-01, -
+                    5.3858719184e-05, -1.2809979916e-01,  6.3842624426e-02],
+                [-3.4809380770e-02,  2.9954621196e-01,  2.9066693969e-05, -
+                    2.9957753420e-01,  3.4806568176e-02],
+                [-6.3934154809e-02,  1.2806062400e-01,  9.0917674243e-05, -
+                    1.2832444906e-01,  6.3572973013e-02],
+                [-4.5492250472e-02, -2.1125273779e-02,  4.2229349492e-04,
+                    2.1804777905e-02,  4.5236673206e-02]
+            ]).reshape(1, 1, 5, 5)
+            filters["b"][1] = torch.tensor([
+                [4.8947390169e-02,  6.3575074077e-02,  3.4955859184e-02,
+                    6.4085893333e-02,  4.9838040024e-02],
+                [2.2061849013e-02, -1.2936264277e-01, -
+                    3.0093491077e-01, -1.2997294962e-01,  2.0597217605e-02],
+                [-5.1290717238e-05, -1.7305796064e-05,  2.0256420612e-05, -
+                    1.1864109547e-04,  7.3973249528e-05],
+                [-2.0749464631e-02,  1.2988376617e-01,  3.0080935359e-01,
+                    1.2921217084e-01, -2.2159902379e-02],
+                [-4.9614857882e-02, -6.4021714032e-02, -
+                    3.4676689655e-02, -6.3446544111e-02, -4.8282280564e-02]
+            ]).reshape(1, 1, 5, 5)
+
+            # Trained on 17/02/2021 to match the Fourier pyramid in the spatial domain
+            filters["b"][0] = torch.tensor([
+                [3.3370e-02,  9.3934e-02, -3.5810e-04, -9.4038e-02, -3.3115e-02],
+                [1.7716e-01,  3.9378e-01,  6.8461e-05, -3.9343e-01, -1.7685e-01],
+                [2.9213e-01,  6.1042e-01,  7.0654e-04, -6.0939e-01, -2.9177e-01],
+                [1.7684e-01,  3.9392e-01,  1.0517e-03, -3.9268e-01, -1.7668e-01],
+                [3.3000e-02,  9.4029e-02,  7.3565e-04, -9.3366e-02, -3.3008e-02]
+            ]).reshape(1, 1, 5, 5) * 0.1
+
+            filters["b"][1] = torch.tensor([
+                [0.0331,  0.1763,  0.2907,  0.1753,  0.0325],
+                [0.0941,  0.3932,  0.6079,  0.3904,  0.0922],
+                [0.0008,  0.0009, -0.0010, -0.0025, -0.0015],
+                [-0.0929, -0.3919, -0.6097, -0.3944, -0.0946],
+                [-0.0328, -0.1760, -0.2915, -0.1768, -0.0333]
+            ]).reshape(1, 1, 5, 5) * 0.1
+
+        else:
+            raise Exception(
+                "Trained filters not implemented for size %d" % size)
+
+    if filter_type == "cropped":
+        filters = crop_steerable_pyramid_filters(filters, size)
+
+    return filters
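
The dictionary returned above bundles the lowpass ("l", "l0"), highpass ("h0") and oriented bandpass ("b") kernels as 1 x 1 x k x k tensors, so a single pyramid level can be sketched with plain 2D convolutions. The snippet below is only a minimal illustration, not the library's own pyramid implementation; `filters` stands for the dictionary returned by the filter-construction routine above, and the helper name is hypothetical.

import torch
import torch.nn.functional as F

def one_pyramid_level(image, filters):
    # image   : [N x 1 x H x W] single-channel tensor.
    # filters : dictionary with "l", "l0", "h0" and "b" entries, as built above.
    def convolve(x, kernel):
        # Same-size convolution for the odd-sized kernels returned above.
        return F.conv2d(x, kernel, padding = kernel.shape[-1] // 2)
    high_residual = convolve(image, filters["h0"])           # highpass residual
    low = convolve(image, filters["l0"])                     # initial lowpass
    bands = [convolve(low, b) for b in filters["b"]]         # oriented band responses
    low_next = convolve(low, filters["l"])[..., ::2, ::2]    # lowpass, downsampled for the next level
    return high_residual, bands, low_next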

hsv_to_rgb(image)

Definition to convert HSV color space to RGB color space. Mostly inspired by: https://kornia.readthedocs.io/en/latest/_modules/kornia/color/hsv.html

Parameters:

  • image – Input image in HSV color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.

Returns:

  • image_rgb (tensor) – Output image in RGB color space [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def hsv_to_rgb(image):
+
+    """
+    Definition to convert HSV space to  RGB color space. Mostly inspired from : https://kornia.readthedocs.io/en/latest/_modules/kornia/color/hsv.html
+
+    Parameters
+    ----------
+    image           : torch.tensor
+                      Input image in HSV color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.
+
+    Returns
+    -------
+    image_rgb       : torch.tensor
+                      Output image in  RGB  color space [k x 3 x m x n] or [1 x 3 x m x n].
+    """
+    if len(image.shape) == 3:
+        image = image.unsqueeze(0)
+    h = image[..., 0, :, :] / (2 * math.pi)
+    s = image[..., 1, :, :]
+    v = image[..., 2, :, :]
+    hi = torch.floor(h * 6) % 6
+    f = ((h * 6) % 6) - hi
+    one = torch.tensor(1.0)
+    p = v * (one - s)
+    q = v * (one - f * s)
+    t = v * (one - (one - f) * s)
+    hi = hi.long()
+    indices = torch.stack([hi, hi + 6, hi + 12], dim=-3)
+    image_rgb = torch.stack((v, q, p, p, t, v, t, v, v, q, p, p, p, p, t, v, v, q), dim=-3)
+    image_rgb = torch.gather(image_rgb, -3, indices)
+    return image_rgb
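
A short usage sketch follows. The import path is assumed from the "Source code in" note above, and the input values are illustrative; note that the implementation divides the hue channel by 2*pi, which suggests hue given in radians.

import math
import torch
from odak.learn.perception.color_conversion import hsv_to_rgb  # assumed import path

hsv = torch.zeros(1, 3, 4, 4)
hsv[:, 0] = math.pi / 3    # hue of 60 degrees in radians
hsv[:, 1] = 1.0            # saturation
hsv[:, 2] = 1.0            # value
rgb = hsv_to_rgb(hsv)      # -> [1 x 3 x 4 x 4], fully saturated yellow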

lab_to_srgb(image)

Definition to convert LAB color space to sRGB color space.

Parameters:

  • image – Input image in LAB color space [3 x m x n].

Returns:

  • image_srgb (tensor) – Output image in sRGB color space [3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def lab_to_srgb(image):
+    """
+    Definition to convert LAB space to SRGB color space. 
+
+    Parameters
+    ----------
+    image           : torch.tensor
+                      Input image in LAB color space[3 x m x n]
+    Returns
+    -------
+    image_srgb     : torch.tensor
+                      Output image in SRGB color space [3 x m x n].
+    """
+
+    if image.shape[-1] == 3:
+        input_color = image.permute(2, 0, 1)  # C(H*W)
+    else:
+        input_color = image
+    # lab ---> xyz
+    reference_illuminant = torch.tensor([[[0.950428545]], [[1.000000000]], [[1.088900371]]], dtype=torch.float32)
+    y = (input_color[0:1, :, :] + 16) / 116
+    a =  input_color[1:2, :, :] / 500
+    b =  input_color[2:3, :, :] / 200
+    x = y + a
+    z = y - b
+    xyz = torch.cat((x, y, z), 0)
+    delta = 6 / 29
+    factor = 3 * delta * delta
+    xyz = torch.where(xyz > delta,  xyz ** 3, factor * (xyz - 4 / 29))
+    xyz_color = xyz * reference_illuminant
+    # xyz ---> linear rgb
+    a11 = 3.241003275
+    a12 = -1.537398934
+    a13 = -0.498615861
+    a21 = -0.969224334
+    a22 = 1.875930071
+    a23 = 0.041554224
+    a31 = 0.055639423
+    a32 = -0.204011202
+    a33 = 1.057148933
+    A = torch.tensor([[a11, a12, a13],
+                  [a21, a22, a23],
+                  [a31, a32, a33]], dtype=torch.float32)
+
+    xyz_color = xyz_color.permute(2, 0, 1) # C(H*W)
+    linear_rgb_color = torch.matmul(A, xyz_color)
+    linear_rgb_color = linear_rgb_color.permute(1, 2, 0)
+    # linear rgb ---> srgb
+    limit = 0.0031308
+    image_srgb = torch.where(linear_rgb_color > limit, 1.055 * (linear_rgb_color ** (1.0 / 2.4)) - 0.055, 12.92 * linear_rgb_color)
+    return image_srgb
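
A minimal sketch of calling this conversion, assuming the import path shown above and illustrative LAB values (L in [0, 100]):

import torch
from odak.learn.perception.color_conversion import lab_to_srgb  # assumed import path

lab = torch.zeros(3, 8, 8)
lab[0] = 50.0    # L (lightness)
lab[1] = 20.0    # a
lab[2] = -10.0   # b
srgb = lab_to_srgb(lab)   # -> [3 x 8 x 8] in the sRGB encoding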

linear_rgb_to_rgb(image, threshold=0.0031308)

Definition to convert linear RGB images to RGB color space. Mostly inspired by: https://kornia.readthedocs.io/en/latest/_modules/kornia/color/rgb.html

Parameters:

  • image – Input image in linear RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.
  • threshold – Threshold used in calculations.

Returns:

  • image_linear (tensor) – Output image in RGB color space [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def linear_rgb_to_rgb(image, threshold = 0.0031308):
+    """
+    Definition to convert linear RGB images to RGB color space. Mostly inspired from: https://kornia.readthedocs.io/en/latest/_modules/kornia/color/rgb.html
+
+    Parameters
+    ----------
+    image           : torch.tensor
+                      Input image in linear RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.
+    threshold       : float
+                      Threshold used in calculations.
+
+    Returns
+    -------
+    image_linear    : torch.tensor
+                      Output image in RGB color space [k x 3 x m x n] or [1 x 3 x m x n].
+    """
+    if len(image.shape) == 3:
+        image = image.unsqueeze(0)
+    image_linear =  torch.where(image > threshold, 1.055 * torch.pow(image.clamp(min=threshold), 1 / 2.4) - 0.055, 12.92 * image)
+    return image_linear
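
For example, gamma-encoding a random linear image (import path assumed from the note above):

import torch
from odak.learn.perception.color_conversion import linear_rgb_to_rgb  # assumed import path

linear = torch.rand(1, 3, 16, 16)       # linear RGB intensities in [0, 1]
encoded = linear_rgb_to_rgb(linear)     # gamma-encoded RGB, same shape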

linear_rgb_to_xyz(image)

Definition to convert linear RGB color space to CIE XYZ color space. Mostly inspired by: Rochester IT Color Conversion Algorithms (https://www.cs.rit.edu/~ncs/color/)

Parameters:

  • image – Input image in linear RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.

Returns:

  • image_xyz (tensor) – Output image in XYZ (CIE 1931) color space [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def linear_rgb_to_xyz(image):
+    """
+    Definition to convert RGB space to CIE XYZ color space. Mostly inspired from : Rochester IT Color Conversion Algorithms (https://www.cs.rit.edu/~ncs/color/)
+
+    Parameters
+    ----------
+    image           : torch.tensor
+                      Input image in linear RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.
+
+    Returns
+    -------
+    image_xyz       : torch.tensor
+                      Output image in XYZ (CIE 1931) color space [k x 3 x m x n] or [1 x 3 x m x n].
+    """
+    if len(image.shape) == 3:
+        image = image.unsqueeze(0)
+    a11 = 0.412453
+    a12 = 0.357580
+    a13 = 0.180423
+    a21 = 0.212671
+    a22 = 0.715160
+    a23 = 0.072169
+    a31 = 0.019334
+    a32 = 0.119193
+    a33 = 0.950227
+    M = torch.tensor([[a11, a12, a13], 
+                      [a21, a22, a23],
+                      [a31, a32, a33]])
+    size = image.size()
+    image = image.reshape(size[0], size[1], size[2]*size[3])  # NC(HW)
+    image_xyz = torch.matmul(M, image)
+    image_xyz = image_xyz.reshape(size[0], size[1], size[2], size[3])
+    return image_xyz
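
A quick sanity check, assuming the import path above: for linear RGB white the Y (luminance) channel comes out close to 1.0, since the middle row of the matrix shown in the source sums to one.

import torch
from odak.learn.perception.color_conversion import linear_rgb_to_xyz  # assumed import path

white = torch.ones(1, 3, 2, 2)      # linear RGB white
xyz = linear_rgb_to_xyz(white)      # Y channel (xyz[:, 1]) is ~1.0 everywhere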

make_3d_location_map(image_pixel_size, real_image_width=0.3, real_viewing_distance=0.6)

Makes a map of the real 3D location that each pixel in an image corresponds to, when displayed to a user on a flat screen. Assumes the viewpoint is located at the centre of the image, and the screen is perpendicular to the viewing direction.

Parameters:

  • image_pixel_size – The size of the image in pixels, as a tuple of form (height, width).
  • real_image_width – The real width of the image as displayed. Units are not important, as long as they are the same as those used for real_viewing_distance.
  • real_viewing_distance – The real distance from the user's viewpoint to the screen.

Returns:

  • map (tensor) – The computed 3D location map, of size 3 x H x W.

Source code in odak/learn/perception/foveation.py
def make_3d_location_map(image_pixel_size, real_image_width=0.3, real_viewing_distance=0.6):
+    """ 
+    Makes a map of the real 3D location that each pixel in an image corresponds to, when displayed to
+    a user on a flat screen. Assumes the viewpoint is located at the centre of the image, and the screen is 
+    perpendicular to the viewing direction.
+
+    Parameters
+    ----------
+
+    image_pixel_size        : tuple of ints 
+                                The size of the image in pixels, as a tuple of form (height, width)
+    real_image_width        : float
+                                The real width of the image as displayed. Units not important, as long as they
+                                are the same as those used for real_viewing_distance
+    real_viewing_distance   : float 
+                                The real distance from the user's viewpoint to the screen.
+
+    Returns
+    -------
+
+    map                     : torch.tensor
+                                The computed 3D location map, of size 3xWxH.
+    """
+    real_image_height = (real_image_width /
+                         image_pixel_size[-1]) * image_pixel_size[-2]
+    x_coords = torch.linspace(-0.5, 0.5, image_pixel_size[-1])*real_image_width
+    x_coords = x_coords[None, None, :].repeat(1, image_pixel_size[-2], 1)
+    y_coords = torch.linspace(-0.5, 0.5,
+                              image_pixel_size[-2])*real_image_height
+    y_coords = y_coords[None, :, None].repeat(1, 1, image_pixel_size[-1])
+    z_coords = torch.ones(
+        (1, image_pixel_size[-2], image_pixel_size[-1])) * real_viewing_distance
+
+    return torch.cat([x_coords, y_coords, z_coords], dim=0)
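
A minimal call, with illustrative screen dimensions and the import path assumed from the note above:

from odak.learn.perception.foveation import make_3d_location_map  # assumed import path

location_map = make_3d_location_map(
    image_pixel_size = (1080, 1920),
    real_image_width = 0.3,
    real_viewing_distance = 0.6)
# location_map[0], location_map[1] and location_map[2] hold the x, y and z coordinates per pixel.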

make_eccentricity_distance_maps(gaze_location, image_pixel_size, real_image_width=0.3, real_viewing_distance=0.6)

Makes a map of the eccentricity of each pixel in an image for a given fixation point, when displayed to a user on a flat screen. Assumes the viewpoint is located at the centre of the image, and the screen is perpendicular to the viewing direction. Output is in radians.

Parameters:

  • gaze_location – User's gaze (fixation point) in the image. Should be given as a tuple with normalized image coordinates (ranging from 0 to 1).
  • image_pixel_size – The size of the image in pixels, as a tuple of form (height, width).
  • real_image_width – The real width of the image as displayed. Units are not important, as long as they are the same as those used for real_viewing_distance.
  • real_viewing_distance – The real distance from the user's viewpoint to the screen.

Returns:

  • eccentricity_map (tensor) – The computed eccentricity map, of size H x W.
  • distance_map (tensor) – The computed distance map, of size H x W.

Source code in odak/learn/perception/foveation.py
def make_eccentricity_distance_maps(gaze_location, image_pixel_size, real_image_width=0.3, real_viewing_distance=0.6):
+    """ 
+    Makes a map of the eccentricity of each pixel in an image for a given fixation point, when displayed to
+    a user on a flat screen. Assumes the viewpoint is located at the centre of the image, and the screen is 
+    perpendicular to the viewing direction. Output in radians.
+
+    Parameters
+    ----------
+
+    gaze_location           : tuple of floats
+                                User's gaze (fixation point) in the image. Should be given as a tuple with normalized
+                                image coordinates (ranging from 0 to 1)
+    image_pixel_size        : tuple of ints
+                                The size of the image in pixels, as a tuple of form (height, width)
+    real_image_width        : float
+                                The real width of the image as displayed. Units not important, as long as they
+                                are the same as those used for real_viewing_distance
+    real_viewing_distance   : float
+                                The real distance from the user's viewpoint to the screen.
+
+    Returns
+    -------
+
+    eccentricity_map        : torch.tensor
+                                The computed eccentricity map, of size WxH.
+    distance_map            : torch.tensor
+                                The computed distance map, of size WxH.
+    """
+    real_image_height = (real_image_width /
+                         image_pixel_size[-1]) * image_pixel_size[-2]
+    location_map = make_3d_location_map(
+        image_pixel_size, real_image_width, real_viewing_distance)
+    distance_map = torch.sqrt(torch.sum(location_map*location_map, dim=0))
+    direction_map = location_map / distance_map
+
+    gaze_location_3d = torch.tensor([
+        (gaze_location[0]*2 - 1)*real_image_width*0.5,
+        (gaze_location[1]*2 - 1)*real_image_height*0.5,
+        real_viewing_distance])
+    gaze_dir = gaze_location_3d / \
+        torch.sqrt(torch.sum(gaze_location_3d * gaze_location_3d))
+    gaze_dir = gaze_dir[:, None, None]
+
+    dot_prod_map = torch.sum(gaze_dir * direction_map, dim=0)
+    dot_prod_map = torch.clamp(dot_prod_map, min=-1.0, max=1.0)
+    eccentricity_map = torch.acos(dot_prod_map)
+
+    return eccentricity_map, distance_map
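
A brief usage sketch, assuming the import path shown above and a viewer fixating the image centre:

from odak.learn.perception.foveation import make_eccentricity_distance_maps  # assumed import path

eccentricity_map, distance_map = make_eccentricity_distance_maps(
    gaze_location = (0.5, 0.5),          # fixating the image centre
    image_pixel_size = (1080, 1920),
    real_image_width = 0.3,
    real_viewing_distance = 0.6)
# eccentricity_map is ~0 radians at the fixation point and grows towards the image borders.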

make_equi_pooling_size_map_lod(gaze_angles, image_pixel_size, alpha=0.3, mode='quadratic')

This function is similar to make_equi_pooling_size_map_pixels, but instead returns a map of LOD levels to sample from in order to achieve the correct pooling region areas.

Parameters:

  • gaze_angles – Gaze direction expressed as angles, in radians.
  • image_pixel_size – Dimensions of the image in pixels, as a tuple of (height, width).
  • alpha – Parameter controlling the extent of foveation.
  • mode – Foveation mode (how pooling size varies with eccentricity). Should be "quadratic" or "linear".

Returns:

  • pooling_size_map (tensor) – The computed pooling size map, of size H x W.

Source code in odak/learn/perception/foveation.py
def make_equi_pooling_size_map_lod(gaze_angles, image_pixel_size, alpha=0.3, mode="quadratic"):
+    """ 
+    This function is similar to make_equi_pooling_size_map_pixels, but instead returns a map of LOD levels to sample from
+    to achieve the correct pooling region areas.
+
+    Parameters
+    ----------
+
+    gaze_angles         : tuple of 2 floats
+                            Gaze direction expressed as angles, in radians.
+    image_pixel_size    : tuple of 2 ints
+                            Dimensions of the image in pixels, as a tuple of (height, width)
+    alpha               : float
+                            Parameter controlling extent of foveation
+    mode                : str
+                            Foveation mode (how pooling size varies with eccentricity). Should be "quadratic" or "linear"
+
+    Returns
+    -------
+
+    pooling_size_map        : torch.tensor
+                                The computed pooling size map, of size HxW.
+    """
+    pooling_pixel = make_equi_pooling_size_map_pixels(gaze_angles, image_pixel_size, alpha, mode)
+    pooling_lod = torch.log2(1e-6+pooling_pixel)
+    pooling_lod[pooling_lod < 0] = 0
+    return pooling_lod
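
An illustrative call, with the import path assumed from the note above and a gaze looking straight ahead:

from odak.learn.perception.foveation import make_equi_pooling_size_map_lod  # assumed import path

lod_map = make_equi_pooling_size_map_lod(
    gaze_angles = (0.0, 0.0),            # (yaw, pitch) in radians
    image_pixel_size = (512, 1024),      # (height, width) of the equirectangular image
    alpha = 0.3,
    mode = "quadratic")
# Each entry is the LOD (mip) level whose texel size matches the local pooling region.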

make_equi_pooling_size_map_pixels(gaze_angles, image_pixel_size, alpha=0.3, mode='quadratic')

This function makes a map of pooling sizes in pixels, similarly to make_pooling_size_map_pixels, but works on 360-degree equirectangular images. Input images are assumed to be in equirectangular form, i.e. if you consider a 3D viewing setup where y is the vertical axis, the x location in the image corresponds to rotation around the y axis (yaw), ranging from -pi to pi, and the y location in the image corresponds to pitch, ranging from -pi/2 to pi/2.

In this setup real_image_width and real_viewing_distance have no effect.

Note that rather than a 2D image gaze location in [0,1]^2, the gaze should be specified as gaze angles in [-pi,pi] x [-pi/2,pi/2] (yaw, then pitch).

Parameters:

  • gaze_angles – Gaze direction expressed as angles, in radians.
  • image_pixel_size – Dimensions of the image in pixels, as a tuple of (height, width).
  • alpha – Parameter controlling the extent of foveation.
  • mode – Foveation mode (how pooling size varies with eccentricity). Should be "quadratic" or "linear".

Returns:

  • pooling_size_map (tensor) – The computed pooling size map, of size H x W; each value is the side length, in pixels, of a square with the same area as the elliptical pooling region at that pixel.

Source code in odak/learn/perception/foveation.py
def make_equi_pooling_size_map_pixels(gaze_angles, image_pixel_size, alpha=0.3, mode="quadratic"):
+    """
+    This function makes a map of pooling sizes in pixels, similarly to make_pooling_size_map_pixels, but works on 360 equirectangular images.
+    Input images are assumed to be in equirectangular form - i.e. if you consider a 3D viewing setup where y is the vertical axis, 
+    the x location in the image corresponds to rotation around the y axis (yaw), ranging from -pi to pi. The y location in the image
+    corresponds to pitch, ranging from -pi/2 to pi/2.
+
+    In this setup real_image_width and real_viewing_distance have no effect.
+
+    Note that rather than a 2D image gaze location in [0,1]^2, the gaze should be specified as gaze angles in [-pi,pi]x[-pi/2,pi/2] (yaw, then pitch).
+
+    Parameters
+    ----------
+
+    gaze_angles         : tuple of 2 floats
+                            Gaze direction expressed as angles, in radians.
+    image_pixel_size    : tuple of 2 ints
+                            Dimensions of the image in pixels, as a tuple of (height, width)
+    alpha               : float
+                            Parameter controlling extent of foveation
+    mode                : str
+                            Foveation mode (how pooling size varies with eccentricity). Should be "quadratic" or "linear"
+    """
+    view_direction = torch.tensor([math.sin(gaze_angles[0])*math.cos(gaze_angles[1]), math.sin(gaze_angles[1]), math.cos(gaze_angles[0])*math.cos(gaze_angles[1])])
+
+    yaw_angle_map = torch.linspace(-torch.pi, torch.pi, image_pixel_size[1])
+    yaw_angle_map = yaw_angle_map[None,:].repeat(image_pixel_size[0], 1)[None,...]
+    pitch_angle_map = torch.linspace(-torch.pi*0.5, torch.pi*0.5, image_pixel_size[0])
+    pitch_angle_map = pitch_angle_map[:,None].repeat(1, image_pixel_size[1])[None,...]
+
+    dir_map = torch.cat([torch.sin(yaw_angle_map)*torch.cos(pitch_angle_map), torch.sin(pitch_angle_map), torch.cos(yaw_angle_map)*torch.cos(pitch_angle_map)])
+
+    # Work out the pooling region diameter in radians
+    view_dot_dir = torch.sum(view_direction[:,None,None] * dir_map, dim=0)
+    eccentricity = torch.acos(view_dot_dir)
+    pooling_rad = alpha * eccentricity
+    if mode == "quadratic":
+        pooling_rad *= eccentricity
+
+    # The actual pooling region will be an ellipse in the equirectangular image - the length of the major & minor axes
+    # depend on the x & y resolution of the image. We find these two axis lengths (in pixels) and then the area of the ellipse
+    pixels_per_rad_x = image_pixel_size[1] / (2*torch.pi)
+    pixels_per_rad_y = image_pixel_size[0] / (torch.pi)
+    pooling_axis_x = pooling_rad * pixels_per_rad_x
+    pooling_axis_y = pooling_rad * pixels_per_rad_y
+    area = torch.pi * pooling_axis_x * pooling_axis_y * 0.25
+
+    # Now finally find the length of the side of a square of the same area.
+    size = torch.sqrt(torch.abs(area))
+    return size
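
A short usage sketch, with the import path assumed from the note above and illustrative image dimensions:

from odak.learn.perception.foveation import make_equi_pooling_size_map_pixels  # assumed import path

pooling_map = make_equi_pooling_size_map_pixels(
    gaze_angles = (0.0, 0.0),            # (yaw, pitch) in radians
    image_pixel_size = (512, 1024),      # (height, width) of the equirectangular image
    alpha = 0.3,
    mode = "quadratic")
# pooling_map[y, x] is the pooling region size, in pixels, at that image location.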

make_pooling_size_map_lod(gaze_location, image_pixel_size, alpha=0.3, real_image_width=0.3, real_viewing_distance=0.6, mode='quadratic')

This function is similar to make_pooling_size_map_pixels, but instead returns a map of LOD levels to sample from in order to achieve the correct pooling region areas.

Parameters:

  • gaze_location – User's gaze (fixation point) in the image. Should be given as a tuple with normalized image coordinates (ranging from 0 to 1).
  • image_pixel_size – The size of the image in pixels, as a tuple of form (height, width).
  • alpha – Parameter controlling the extent of foveation.
  • real_image_width – The real width of the image as displayed. Units are not important, as long as they are the same as those used for real_viewing_distance.
  • real_viewing_distance – The real distance from the user's viewpoint to the screen.
  • mode – Foveation mode (how pooling size varies with eccentricity). Should be "quadratic" or "linear".

Returns:

  • pooling_size_map (tensor) – The computed pooling size map, of size H x W.

Source code in odak/learn/perception/foveation.py
def make_pooling_size_map_lod(gaze_location, image_pixel_size, alpha=0.3, real_image_width=0.3, real_viewing_distance=0.6, mode="quadratic"):
+    """ 
+    This function is similar to make_pooling_size_map_pixels, but instead returns a map of LOD levels to sample from
+    to achieve the correct pooling region areas.
+
+    Parameters
+    ----------
+
+    gaze_location           : tuple of floats
+                                User's gaze (fixation point) in the image. Should be given as a tuple with normalized
+                                image coordinates (ranging from 0 to 1)
+    image_pixel_size        : tuple of ints
+                                The size of the image in pixels, as a tuple of form (height, width)
+    real_image_width        : float
+                                The real width of the image as displayed. Units not important, as long as they
+                                are the same as those used for real_viewing_distance
+    real_viewing_distance   : float
+                                The real distance from the user's viewpoint to the screen.
+
+    Returns
+    -------
+
+    pooling_size_map        : torch.tensor
+                                The computed pooling size map, of size WxH.
+    """
+    pooling_pixel = make_pooling_size_map_pixels(
+        gaze_location, image_pixel_size, alpha, real_image_width, real_viewing_distance, mode)
+    pooling_lod = torch.log2(1e-6+pooling_pixel)
+    pooling_lod[pooling_lod < 0] = 0
+    return pooling_lod
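
A minimal call, assuming the import path above; the gaze location and display geometry are illustrative:

from odak.learn.perception.foveation import make_pooling_size_map_lod  # assumed import path

lod_map = make_pooling_size_map_lod(
    gaze_location = (0.25, 0.5),
    image_pixel_size = (1080, 1920),
    alpha = 0.3,
    real_image_width = 0.3,
    real_viewing_distance = 0.6,
    mode = "quadratic")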

make_pooling_size_map_pixels(gaze_location, image_pixel_size, alpha=0.3, real_image_width=0.3, real_viewing_distance=0.6, mode='quadratic')

Makes a map of the pooling size associated with each pixel in an image for a given fixation point, when displayed to a user on a flat screen. Follows the idea that pooling size (in radians) should be directly proportional to eccentricity (also in radians).

Assumes the viewpoint is located at the centre of the image, and the screen is perpendicular to the viewing direction. Output is the width of the pooling region in pixels.

Parameters:

  • gaze_location – User's gaze (fixation point) in the image. Should be given as a tuple with normalized image coordinates (ranging from 0 to 1).
  • image_pixel_size – The size of the image in pixels, as a tuple of form (height, width).
  • alpha – Parameter controlling the extent of foveation.
  • real_image_width – The real width of the image as displayed. Units are not important, as long as they are the same as those used for real_viewing_distance.
  • real_viewing_distance – The real distance from the user's viewpoint to the screen.
  • mode – Foveation mode (how pooling size varies with eccentricity). Should be "quadratic" or "linear".

Returns:

  • pooling_size_map (tensor) – The computed pooling size map, of size H x W.

Source code in odak/learn/perception/foveation.py
def make_pooling_size_map_pixels(gaze_location, image_pixel_size, alpha=0.3, real_image_width=0.3, real_viewing_distance=0.6, mode="quadratic"):
+    """ 
+    Makes a map of the pooling size associated with each pixel in an image for a given fixation point, when displayed to
+    a user on a flat screen. Follows the idea that pooling size (in radians) should be directly proportional to eccentricity
+    (also in radians). 
+
+    Assumes the viewpoint is located at the centre of the image, and the screen is 
+    perpendicular to the viewing direction. Output is the width of the pooling region in pixels.
+
+    Parameters
+    ----------
+
+    gaze_location           : tuple of floats
+                                User's gaze (fixation point) in the image. Should be given as a tuple with normalized
+                                image coordinates (ranging from 0 to 1)
+    image_pixel_size        : tuple of ints
+                                The size of the image in pixels, as a tuple of form (height, width)
+    real_image_width        : float
+                                The real width of the image as displayed. Units not important, as long as they
+                                are the same as those used for real_viewing_distance
+    real_viewing_distance   : float
+                                The real distance from the user's viewpoint to the screen.
+
+    Returns
+    -------
+
+    pooling_size_map        : torch.tensor
+                                The computed pooling size map, of size WxH.
+    """
+    eccentricity, distance_to_pixel = make_eccentricity_distance_maps(
+        gaze_location, image_pixel_size, real_image_width, real_viewing_distance)
+    eccentricity_centre, _ = make_eccentricity_distance_maps(
+        [0.5, 0.5], image_pixel_size, real_image_width, real_viewing_distance)
+    pooling_rad = alpha * eccentricity
+    if mode == "quadratic":
+        pooling_rad *= eccentricity
+    angle_min = eccentricity_centre - pooling_rad*0.5
+    angle_max = eccentricity_centre + pooling_rad*0.5
+    major_axis = (torch.tan(angle_max) - torch.tan(angle_min)) * \
+        real_viewing_distance
+    minor_axis = 2 * distance_to_pixel * torch.tan(pooling_rad*0.5)
+    area = math.pi * major_axis * minor_axis * 0.25
+    # Should be +ve anyway, but check to ensure we don't take sqrt of negative number
+    area = torch.abs(area)
+    pooling_real = torch.sqrt(area)
+    pooling_pixel = (pooling_real / real_image_width) * image_pixel_size[1]
+    return pooling_pixel
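
A minimal usage sketch (not part of the library documentation): assuming the function can be imported from the odak/learn/perception/foveation.py module named above, this builds a pooling-size map in pixels for a centred gaze on a 1080p image with the default viewing geometry. The gaze point and image size are illustrative values only.

import torch
from odak.learn.perception.foveation import make_pooling_size_map_pixels

# Fixation at the image centre, for a 1080p image viewed from 60 cm
# on a screen that is 30 cm wide (the function defaults).
pooling_pixels = make_pooling_size_map_pixels(
    gaze_location = [0.5, 0.5],
    image_pixel_size = [1080, 1920],
    alpha = 0.3,
    real_image_width = 0.3,
    real_viewing_distance = 0.6,
    mode = 'quadratic'
)
print(pooling_pixels.shape)  # one pooling width (in pixels) per pixel location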

make_radial_map(size, gaze)

Makes a simple radial map where each pixel contains distance in pixels from the chosen gaze location.

Parameters:

• size – Dimensions of the image.
• gaze – User's gaze (fixation point) in the image. Should be given as a tuple with normalized image coordinates (ranging from 0 to 1).

Source code in odak/learn/perception/foveation.py
def make_radial_map(size, gaze):
+    """ 
+    Makes a simple radial map where each pixel contains distance in pixels from the chosen gaze location.
+
+    Parameters
+    ----------
+
+    size    : tuple of ints
+                Dimensions of the image
+    gaze    : tuple of floats
+                User's gaze (fixation point) in the image. Should be given as a tuple with normalized
+                image coordinates (ranging from 0 to 1)
+    """
+    pix_gaze = [gaze[0]*size[0], gaze[1]*size[1]]
+    rows = torch.linspace(0, size[0], size[0])
+    rows = rows[:, None].repeat(1, size[1])
+    cols = torch.linspace(0, size[1], size[1])
+    cols = cols[None, :].repeat(size[0], 1)
+    dist_sq = torch.pow(rows - pix_gaze[0], 2) + \
+        torch.pow(cols - pix_gaze[1], 2)
+    radii = torch.sqrt(dist_sq)
+    return radii/torch.max(radii)
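
A short illustrative sketch, again assuming the import path matches the module named above. Note from the source that the returned radii are divided by their maximum, so the map is normalised to [0, 1] rather than raw pixel distances.

import torch
from odak.learn.perception.foveation import make_radial_map

# Distance map for a 256 x 256 image with the gaze slightly off-centre.
radial = make_radial_map(size = [256, 256], gaze = [0.4, 0.6])
print(radial.shape, radial.max())  # maximum is 1.0 after normalisation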

pad_image_for_pyramid(image, n_pyramid_levels)

Pads an image to the extent necessary to compute a steerable pyramid of the input image. This involves padding so both height and width are divisible by 2**n_pyramid_levels. Uses reflection padding.

Parameters:

• image – Image to pad, in NCHW format.
• n_pyramid_levels – Number of levels in the pyramid you plan to construct.

Source code in odak/learn/perception/spatial_steerable_pyramid.py
def pad_image_for_pyramid(image, n_pyramid_levels):
+    """
+    Pads an image to the extent necessary to compute a steerable pyramid of the input image.
+    This involves padding so both height and width are divisible by 2**n_pyramid_levels.
+    Uses reflection padding.
+
+    Parameters
+    ----------
+
+    image: torch.tensor
+        Image to pad, in NCHW format
+    n_pyramid_levels: int
+        Number of levels in the pyramid you plan to construct.
+    """
+    min_divisor = 2 ** n_pyramid_levels
+    height = image.size(2)
+    width = image.size(3)
+    required_height = math.ceil(height / min_divisor) * min_divisor
+    required_width = math.ceil(width / min_divisor) * min_divisor
+    if required_height > height or required_width > width:
+        # We need to pad!
+        # torch.nn.ReflectionPad2d takes padding as (left, right, top, bottom),
+        # so pad the right edge up to required_width and the bottom edge up to required_height.
+        pad = torch.nn.ReflectionPad2d(
+            (0, required_width - width, 0, required_height - height))
+        return pad(image)
+    return image
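
A small sketch of the padding behaviour described above, with an illustrative 250 x 300 input and four pyramid levels; the import path is assumed from the module named above.

import torch
from odak.learn.perception.spatial_steerable_pyramid import pad_image_for_pyramid

# A 1 x 3 x 250 x 300 image padded so that both spatial dimensions become
# divisible by 2 ** 4 = 16 (required height 256, required width 304).
image = torch.rand(1, 3, 250, 300)
padded = pad_image_for_pyramid(image, n_pyramid_levels = 4)
print(padded.shape)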

rgb_2_ycrcb(image)

Converts an image from RGB colourspace to YCrCb colourspace.

Parameters:

• image – Input image. Should be an RGB floating-point image with values in the range [0, 1]. Should be in NCHW format [3 x m x n] or [k x 3 x m x n].

Returns:

• ycrcb (tensor) – Image converted to YCrCb colourspace [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def rgb_2_ycrcb(image):
+    """
+    Converts an image from RGB colourspace to YCrCb colourspace.
+
+    Parameters
+    ----------
+    image   : torch.tensor
+              Input image. Should be an RGB floating-point image with values in the range [0, 1]. Should be in NCHW format [3 x m x n] or [k x 3 x m x n].
+
+    Returns
+    -------
+
+    ycrcb   : torch.tensor
+              Image converted to YCrCb colourspace [k x 3 x m x n] or [1 x 3 x m x n].
+    """
+    if len(image.shape) == 3:
+       image = image.unsqueeze(0)
+    ycrcb = torch.zeros(image.size()).to(image.device)
+    ycrcb[:, 0, :, :] = 0.299 * image[:, 0, :, :] + 0.587 * \
+        image[:, 1, :, :] + 0.114 * image[:, 2, :, :]
+    ycrcb[:, 1, :, :] = 0.5 + 0.713 * (image[:, 0, :, :] - ycrcb[:, 0, :, :])
+    ycrcb[:, 2, :, :] = 0.5 + 0.564 * (image[:, 2, :, :] - ycrcb[:, 0, :, :])
+    return ycrcb
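
An illustrative sketch assuming the import path above; a mid-grey input makes the channel layout easy to check.

import torch
from odak.learn.perception.color_conversion import rgb_2_ycrcb

# A mid-grey RGB batch: luma (Y) stays at 0.5 and both chroma channels sit
# at their neutral value of 0.5.
rgb = torch.full((2, 3, 32, 32), 0.5)
ycrcb = rgb_2_ycrcb(rgb)
print(ycrcb[0, :, 0, 0])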

rgb_to_hsv(image, eps=1e-08)

Definition to convert RGB space to HSV color space. Mostly inspired from: https://kornia.readthedocs.io/en/latest/_modules/kornia/color/hsv.html

Parameters:

• image – Input image in RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.

Returns:

• image_hsv (tensor) – Output image in HSV color space [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def rgb_to_hsv(image, eps: float = 1e-8):
+
+    """
+    Definition to convert RGB space to HSV color space. Mostly inspired from : https://kornia.readthedocs.io/en/latest/_modules/kornia/color/hsv.html
+
+    Parameters
+    ----------
+    image           : torch.tensor
+                      Input image in RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.
+
+    Returns
+    -------
+    image_hsv       : torch.tensor
+                      Output image in HSV color space [k x 3 x m x n] or [1 x 3 x m x n].
+    """
+    if len(image.shape) == 3:
+        image = image.unsqueeze(0)
+    max_rgb, argmax_rgb = image.max(-3)
+    min_rgb, argmin_rgb = image.min(-3)
+    deltac = max_rgb - min_rgb
+    v = max_rgb
+    s = deltac / (max_rgb + eps)
+    deltac = torch.where(deltac == 0, torch.ones_like(deltac), deltac)
+    rc, gc, bc = torch.unbind((max_rgb.unsqueeze(-3) - image), dim=-3)
+    h1 = bc - gc
+    h2 = (rc - bc) + 2.0 * deltac
+    h3 = (gc - rc) + 4.0 * deltac
+    h = torch.stack((h1, h2, h3), dim=-3) / deltac.unsqueeze(-3)
+    h = torch.gather(h, dim=-3, index=argmax_rgb.unsqueeze(-3)).squeeze(-3)
+    h = (h / 6.0) % 1.0
+    h = 2.0 * math.pi * h 
+    image_hsv = torch.stack((h, s, v), dim=-3)
+    return image_hsv
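
A quick sketch (import path assumed from the module above) showing the hue convention used by the source: hue is returned in radians in [0, 2*pi), so pure red maps to a hue of zero.

import torch
from odak.learn.perception.color_conversion import rgb_to_hsv

# Pure red in a 1 x 3 x 2 x 2 image.
rgb = torch.zeros(1, 3, 2, 2)
rgb[:, 0] = 1.0
hsv = rgb_to_hsv(rgb)
print(hsv[0, :, 0, 0])  # hue close to 0, saturation and value close to 1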

rgb_to_linear_rgb(image, threshold=0.0031308)

Definition to convert RGB images to linear RGB color space. Mostly inspired from: https://kornia.readthedocs.io/en/latest/_modules/kornia/color/rgb.html

Parameters:

• image – Input image in RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.
• threshold – Threshold used in calculations.

Returns:

• image_linear (tensor) – Output image in linear RGB color space [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def rgb_to_linear_rgb(image, threshold = 0.0031308):
+    """
+    Definition to convert RGB images to linear RGB color space. Mostly inspired from: https://kornia.readthedocs.io/en/latest/_modules/kornia/color/rgb.html
+
+    Parameters
+    ----------
+    image           : torch.tensor
+                      Input image in RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.
+    threshold       : float
+                      Threshold used in calculations.
+
+    Returns
+    -------
+    image_linear    : torch.tensor
+                      Output image in linear RGB color space [k x 3 x m x n] or [1 x 3 x m x n].
+    """
+    if len(image.shape) == 3:
+        image = image.unsqueeze(0)
+    image_linear = torch.where(image > 0.04045, torch.pow(((image + 0.055) / 1.055), 2.4), image / 12.92)
+    return image_linear
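
A brief sketch, import path assumed as above, showing the sRGB-to-linear decoding of mid-grey.

import torch
from odak.learn.perception.color_conversion import rgb_to_linear_rgb

# Convert an sRGB-encoded image to linear RGB. Mid-grey (0.5) maps to
# roughly 0.214 in linear light.
srgb = torch.full((1, 3, 4, 4), 0.5)
linear = rgb_to_linear_rgb(srgb)
print(linear[0, 0, 0, 0])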

srgb_to_lab(image)

Definition to convert SRGB space to LAB color space.

Parameters:

• image – Input image in SRGB color space [3 x m x n].

Returns:

• image_lab (tensor) – Output image in LAB color space [3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def srgb_to_lab(image):    
+    """
+    Definition to convert SRGB space to LAB color space. 
+
+    Parameters
+    ----------
+    image           : torch.tensor
+                      Input image in SRGB color space[3 x m x n]
+    Returns
+    -------
+    image_lab       : torch.tensor
+                      Output image in LAB color space [3 x m x n].
+    """
+    if image.shape[-1] == 3:
+        input_color = image.permute(2, 0, 1)  # C(H*W)
+    else:
+        input_color = image
+    # rgb ---> linear rgb
+    limit = 0.04045        
+    # linear rgb ---> xyz
+    linrgb_color = torch.where(input_color > limit, torch.pow((input_color + 0.055) / 1.055, 2.4), input_color / 12.92)
+
+    a11 = 10135552 / 24577794
+    a12 = 8788810  / 24577794
+    a13 = 4435075  / 24577794
+    a21 = 2613072  / 12288897
+    a22 = 8788810  / 12288897
+    a23 = 887015   / 12288897
+    a31 = 1425312  / 73733382
+    a32 = 8788810  / 73733382
+    a33 = 70074185 / 73733382
+
+    A = torch.tensor([[a11, a12, a13],
+                    [a21, a22, a23],
+                    [a31, a32, a33]], dtype=torch.float32)
+
+    linrgb_color = linrgb_color.permute(2, 0, 1) # C(H*W)
+    xyz_color = torch.matmul(A, linrgb_color)
+    xyz_color = xyz_color.permute(1, 2, 0)
+    # xyz ---> lab
+    inv_reference_illuminant = torch.tensor([[[1.052156925]], [[1.000000000]], [[0.918357670]]], dtype=torch.float32)
+    input_color = xyz_color * inv_reference_illuminant
+    delta = 6 / 29
+    delta_square = delta * delta
+    delta_cube = delta * delta_square
+    factor = 1 / (3 * delta_square)
+
+    input_color = torch.where(input_color > delta_cube, torch.pow(input_color, 1 / 3), (factor * input_color + 4 / 29))
+
+    l = 116 * input_color[1:2, :, :] - 16
+    a = 500 * (input_color[0:1,:, :] - input_color[1:2, :, :])
+    b = 200 * (input_color[1:2, :, :] - input_color[2:3, :, :])
+
+    image_lab = torch.cat((l, a, b), 0)
+    return image_lab    
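
A short sketch with an assumed import path; feeding a white [3 x m x n] image gives an L channel near 100 and a, b near zero, which is a handy sanity check.

import torch
from odak.learn.perception.color_conversion import srgb_to_lab

# Convert a single channel-first sRGB image to LAB.
srgb = torch.ones(3, 8, 8)
lab = srgb_to_lab(srgb)
print(lab[:, 0, 0])  # approximately (100, 0, 0) for white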

xyz_to_linear_rgb(image)

Definition to convert CIE XYZ space to linear RGB color space. Mostly inspired from: Rochester IT Color Conversion Algorithms (https://www.cs.rit.edu/~ncs/color/)

Parameters:

• image – Input image in XYZ (CIE 1931) color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.

Returns:

• image_linear_rgb (tensor) – Output image in linear RGB color space [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def xyz_to_linear_rgb(image):
+    """
+    Definition to convert CIE XYZ space to linear RGB color space. Mostly inspired from : Rochester IT Color Conversion Algorithms (https://www.cs.rit.edu/~ncs/color/)
+
+    Parameters
+    ----------
+    image            : torch.tensor
+                       Input image in XYZ (CIE 1931) color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.
+
+    Returns
+    -------
+    image_linear_rgb : torch.tensor
+                       Output image in linear RGB  color space [k x 3 x m x n] or [1 x 3 x m x n].
+    """
+    if len(image.shape) == 3:
+        image = image.unsqueeze(0)
+    a11 = 3.240479
+    a12 = -1.537150
+    a13 = -0.498535
+    a21 = -0.969256 
+    a22 = 1.875992 
+    a23 = 0.041556
+    a31 = 0.055648
+    a32 = -0.204043
+    a33 = 1.057311
+    M = torch.tensor([[a11, a12, a13], 
+                      [a21, a22, a23],
+                      [a31, a32, a33]])
+    size = image.size()
+    image = image.reshape(size[0], size[1], size[2]*size[3])
+    image_linear_rgb = torch.matmul(M, image)
+    image_linear_rgb = image_linear_rgb.reshape(size[0], size[1], size[2], size[3])
+    return image_linear_rgb
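
A minimal sketch (import path assumed); a D65-like white point should land close to (1, 1, 1) in linear RGB.

import torch
from odak.learn.perception.color_conversion import xyz_to_linear_rgb

# D65-like white point in XYZ, shaped as a 1 x 3 x 1 x 1 image.
xyz = torch.tensor([0.9505, 1.0000, 1.0890]).reshape(1, 3, 1, 1)
image_linear_rgb = xyz_to_linear_rgb(xyz)
print(image_linear_rgb.flatten())  # close to (1, 1, 1)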

ycrcb_2_rgb(image)

Converts an image from YCrCb colourspace to RGB colourspace.

Parameters:

• image – Input image. Should be a YCrCb floating-point image with values in the range [0, 1]. Should be in NCHW format [3 x m x n] or [k x 3 x m x n].

Returns:

• rgb (tensor) – Image converted to RGB colourspace [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def ycrcb_2_rgb(image):
+    """
+    Converts an image from YCrCb colourspace to RGB colourspace.
+
+    Parameters
+    ----------
+    image   : torch.tensor
+              Input image. Should be a YCrCb floating-point image with values in the range [0, 1]. Should be in NCHW format [3 x m x n] or [k x 3 x m x n].
+
+    Returns
+    -------
+    rgb     : torch.tensor
+              Image converted to RGB colourspace [k x 3 x m x n] or [1 x 3 x m x n].
+    """
+    if len(image.shape) == 3:
+       image = image.unsqueeze(0)
+    rgb = torch.zeros(image.size(), device=image.device)
+    rgb[:, 0, :, :] = image[:, 0, :, :] + 1.403 * (image[:, 1, :, :] - 0.5)
+    rgb[:, 1, :, :] = image[:, 0, :, :] - 0.714 * \
+        (image[:, 1, :, :] - 0.5) - 0.344 * (image[:, 2, :, :] - 0.5)
+    rgb[:, 2, :, :] = image[:, 0, :, :] + 1.773 * (image[:, 2, :, :] - 0.5)
+    return rgb
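
A small round-trip sketch combining this function with rgb_2_ycrcb above; the import path is assumed from the module named above. The two conversion matrices are approximate inverses, so the reconstruction error is small but not exactly zero.

import torch
from odak.learn.perception.color_conversion import rgb_2_ycrcb, ycrcb_2_rgb

# Round trip a random RGB batch through YCrCb and back.
rgb = torch.rand(2, 3, 64, 64)
recovered = ycrcb_2_rgb(rgb_2_ycrcb(rgb))
print(torch.abs(rgb - recovered).max())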

BlurLoss

BlurLoss implements two different blur losses. When blur_source is set to False, it implements blur_match, trying to match the input image to the blurred target image. This tries to match the source input image to a blurred version of the target.

When blur_source is set to True, it implements blur_lowpass, matching the blurred version of the input image to the blurred target image. This tries to match only the low frequencies of the source input image to the low frequencies of the target.

The interface is similar to other pytorch loss functions, but note that the gaze location must be provided in addition to the source and target images.

Source code in odak/learn/perception/blur_loss.py
class BlurLoss():
+    """ 
+
+    `BlurLoss` implements two different blur losses. When `blur_source` is set to `False`, it implements blur_match, trying to match the input image to the blurred target image. This tries to match the source input image to a blurred version of the target.
+
+    When `blur_source` is set to `True`, it implements blur_lowpass, matching the blurred version of the input image to the blurred target image. This tries to match only the low frequencies of the source input image to the low frequencies of the target.
+
+    The interface is similar to other `pytorch` loss functions, but note that the gaze location must be provided in addition to the source and target images.
+    """
+
+
+    def __init__(self, device=torch.device("cpu"),
+                 alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode="quadratic", blur_source=False, equi=False):
+        """
+        Parameters
+        ----------
+
+        alpha                   : float
+                                    parameter controlling foveation - larger values mean bigger pooling regions.
+        real_image_width        : float 
+                                    The real width of the image as displayed to the user.
+                                    Units don't matter as long as they are the same as for real_viewing_distance.
+        real_viewing_distance   : float 
+                                    The real distance of the observer's eyes to the image plane.
+                                    Units don't matter as long as they are the same as for real_image_width.
+        mode                    : str 
+                                    Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
+                                    as you move away from the fovea. We got best results with "quadratic".
+        blur_source             : bool
+                                    If true, blurs the source image as well as the target before computing the loss.
+        equi                    : bool
+                                    If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular
+                                    format 360 image. The settings real_image_width and real_viewing_distance are ignored.
+                                    The gaze argument is instead interpreted as gaze angles, and should be in the range
+                                    [-pi,pi]x[-pi/2,pi]
+        """
+        self.target = None
+        self.device = device
+        self.alpha = alpha
+        self.real_image_width = real_image_width
+        self.real_viewing_distance = real_viewing_distance
+        self.mode = mode
+        self.blur = None
+        self.loss_func = torch.nn.MSELoss()
+        self.blur_source = blur_source
+        self.equi = equi
+
+    def blur_image(self, image, gaze):
+        if self.blur is None:
+            self.blur = RadiallyVaryingBlur()
+        return self.blur.blur(image, self.alpha, self.real_image_width, self.real_viewing_distance, gaze, self.mode, self.equi)
+
+    def __call__(self, image, target, gaze=[0.5, 0.5]):
+        """ 
+        Calculates the Blur Loss.
+
+        Parameters
+        ----------
+        image               : torch.tensor
+                                Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
+        target              : torch.tensor
+                                Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
+        gaze                : list
+                                Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.
+
+        Returns
+        -------
+
+        loss                : torch.tensor
+                                The computed loss.
+        """
+        check_loss_inputs("BlurLoss", image, target)
+        blurred_target = self.blur_image(target, gaze)
+        if self.blur_source:
+            blurred_image = self.blur_image(image, gaze)
+            return self.loss_func(blurred_image, blurred_target)
+        else:
+            return self.loss_func(image, blurred_target)
+
+    def to(self, device):
+        self.device = device
+        return self
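
A minimal usage sketch for the class, assuming BlurLoss can be imported from the odak/learn/perception/blur_loss.py module named above; the image sizes, gaze point and constructor values are illustrative only.

import torch
from odak.learn.perception.blur_loss import BlurLoss

# Compare a candidate image against a foveally blurred target for a gaze
# point in the upper-left quadrant of the image.
loss_fn = BlurLoss(alpha = 0.2, real_image_width = 0.2,
                   real_viewing_distance = 0.7, mode = 'quadratic')
image = torch.rand(1, 3, 128, 128)
target = torch.rand(1, 3, 128, 128)
loss = loss_fn(image, target, gaze = [0.25, 0.25])
print(loss)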

__call__(image, target, gaze=[0.5, 0.5])

Calculates the Blur Loss.

Parameters:

• image – Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions).
• target – Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions).
• gaze – Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.

Returns:

• loss (tensor) – The computed loss.

Source code in odak/learn/perception/blur_loss.py
def __call__(self, image, target, gaze=[0.5, 0.5]):
+    """ 
+    Calculates the Blur Loss.
+
+    Parameters
+    ----------
+    image               : torch.tensor
+                            Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
+    target              : torch.tensor
+                            Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
+    gaze                : list
+                            Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.
+
+    Returns
+    -------
+
+    loss                : torch.tensor
+                            The computed loss.
+    """
+    check_loss_inputs("BlurLoss", image, target)
+    blurred_target = self.blur_image(target, gaze)
+    if self.blur_source:
+        blurred_image = self.blur_image(image, gaze)
+        return self.loss_func(blurred_image, blurred_target)
+    else:
+        return self.loss_func(image, blurred_target)

__init__(device=torch.device('cpu'), alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode='quadratic', blur_source=False, equi=False)

Parameters:

• alpha – Parameter controlling foveation - larger values mean bigger pooling regions.
• real_image_width – The real width of the image as displayed to the user. Units don't matter as long as they are the same as for real_viewing_distance.
• real_viewing_distance – The real distance of the observer's eyes to the image plane. Units don't matter as long as they are the same as for real_image_width.
• mode – Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow as you move away from the fovea. We got best results with "quadratic".
• blur_source – If true, blurs the source image as well as the target before computing the loss.
• equi – If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular format 360 image. The settings real_image_width and real_viewing_distance are ignored. The gaze argument is instead interpreted as gaze angles, and should be in the range [-pi,pi]x[-pi/2,pi].

Source code in odak/learn/perception/blur_loss.py
def __init__(self, device=torch.device("cpu"),
+             alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode="quadratic", blur_source=False, equi=False):
+    """
+    Parameters
+    ----------
+
+    alpha                   : float
+                                parameter controlling foveation - larger values mean bigger pooling regions.
+    real_image_width        : float 
+                                The real width of the image as displayed to the user.
+                                Units don't matter as long as they are the same as for real_viewing_distance.
+    real_viewing_distance   : float 
+                                The real distance of the observer's eyes to the image plane.
+                                Units don't matter as long as they are the same as for real_image_width.
+    mode                    : str 
+                                Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
+                                as you move away from the fovea. We got best results with "quadratic".
+    blur_source             : bool
+                                If true, blurs the source image as well as the target before computing the loss.
+    equi                    : bool
+                                If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular
+                                format 360 image. The settings real_image_width and real_viewing_distance are ignored.
+                                The gaze argument is instead interpreted as gaze angles, and should be in the range
+                                [-pi,pi]x[-pi/2,pi]
+    """
+    self.target = None
+    self.device = device
+    self.alpha = alpha
+    self.real_image_width = real_image_width
+    self.real_viewing_distance = real_viewing_distance
+    self.mode = mode
+    self.blur = None
+    self.loss_func = torch.nn.MSELoss()
+    self.blur_source = blur_source
+    self.equi = equi

display_color_hvs

Source code in odak/learn/perception/color_conversion.py
class display_color_hvs():
+
+    def __init__(
+                 self,
+                 resolution = [1920, 1080],
+                 distance_from_screen = 800,
+                 pixel_pitch = 0.311,
+                 read_spectrum = 'tensor',
+                 primaries_spectrum = torch.rand(3, 301),
+                 device = torch.device('cpu')):
+        '''
+        Parameters
+        ----------
+        resolution                  : list
+                                      Resolution of the display in pixels.
+        distance_from_screen        : int
+                                      Distance from the screen in mm.
+        pixel_pitch                 : float
+                                      Pixel pitch of the display in mm.
+        read_spectrum               : str
+                                      Spectrum of the display. Default is 'default' which is the spectrum of the Dell U2415 display.
+        device                      : torch.device
+                                      Device to run the code on. Default is None which means the code will run on CPU.
+
+        '''
+        self.device = device
+        self.read_spectrum = read_spectrum
+        self.primaries_spectrum = primaries_spectrum.to(self.device)
+        self.resolution = resolution
+        self.distance_from_screen = distance_from_screen
+        self.pixel_pitch = pixel_pitch
+        self.l_normalized, self.m_normalized, self.s_normalized = self.initialize_cones_normalized()
+        self.lms_tensor = self.construct_matrix_lms(
+                                                    self.l_normalized,
+                                                    self.m_normalized,
+                                                    self.s_normalized
+                                                   )   
+        self.primaries_tensor = self.construct_matrix_primaries(
+                                                                self.l_normalized,
+                                                                self.m_normalized,
+                                                                self.s_normalized
+                                                               )   
+        return
+
+
+    def __call__(self, input_image, ground_truth, gaze=None):
+        """
+        Evaluating an input image against a target ground truth image for a given gaze of a viewer.
+        """
+        lms_image_second = self.primaries_to_lms(input_image.to(self.device))
+        lms_ground_truth_second = self.primaries_to_lms(ground_truth.to(self.device))
+        lms_image_third = self.second_to_third_stage(lms_image_second)
+        lms_ground_truth_third = self.second_to_third_stage(lms_ground_truth_second)
+        loss_metamer_color = torch.mean((lms_ground_truth_third - lms_image_third) ** 2)
+        return loss_metamer_color
+
+
+    def initialize_cones_normalized(self):
+        """
+        Internal function to initialize normalized L,M,S cones as normal distribution with given sigma, and mu values. 
+
+        Returns
+        -------
+        l_cone_n                     : torch.tensor
+                                       Normalised L cone distribution.
+        m_cone_n                     : torch.tensor
+                                       Normalised M cone distribution.
+        s_cone_n                     : torch.tensor
+                                       Normalised S cone distribution.
+        """
+        wavelength_range = torch.linspace(400, 700, steps = 301, device = self.device)
+        dist_l = 1 / (32.5 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 567.5) ** 2 / (2 * 32.5 ** 2))
+        dist_m = 1 / (27.5 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 545.0) ** 2 / (2 * 27.5 ** 2))
+        dist_s = 1 / (17.0 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 447.5) ** 2 / (2 * 17.0 ** 2))
+
+        l_cone_n = dist_l / dist_l.max()
+        m_cone_n = dist_m / dist_m.max()
+        s_cone_n = dist_s / dist_s.max()
+        return l_cone_n, m_cone_n, s_cone_n
+
+
+    def initialize_rgb_backlight_spectrum(self):
+        """
+        Internal function to initialize backlight spectrum for color primaries. 
+
+        Returns
+        -------
+        red_spectrum                 : torch.tensor
+                                       Normalised backlight spectrum for red color primary.
+        green_spectrum               : torch.tensor
+                                       Normalised backlight spectrum for green color primary.
+        blue_spectrum                : torch.tensor
+                                       Normalised backlight spectrum for blue color primary.
+        """
+        wavelength_range = torch.linspace(400, 700, steps = 301, device = self.device)
+        red_spectrum = 1 / (14.5 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 650) ** 2 / (2 * 14.5 ** 2))
+        green_spectrum = 1 / (12 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 550) ** 2 / (2 * 12.0 ** 2))
+        blue_spectrum = 1 / (12 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 450) ** 2 / (2 * 12.0 ** 2))
+
+        red_spectrum = red_spectrum / red_spectrum.max()
+        green_spectrum = green_spectrum / green_spectrum.max()
+        blue_spectrum = blue_spectrum / blue_spectrum.max()
+
+        return red_spectrum, green_spectrum, blue_spectrum
+
+
+    def initialize_random_spectrum_normalized(self, dataset):
+        """
+        Initialize normalized light spectrum via combination of 3 gaussian distribution curve fitting [L-BFGS]. 
+
+        Parameters
+        ----------
+        dataset                                : torch.tensor 
+                                                 spectrum value against wavelength 
+        """
+        dataset = torch.swapaxes(dataset, 0, 1)
+        x_spectrum = torch.linspace(400, 700, steps = 301) - 550
+        y_spectrum = torch.from_numpy(np_cpu.interp(x_spectrum, dataset[0].numpy(), dataset[1].numpy()))
+        max_spectrum = torch.max(y_spectrum)
+        y_spectrum /= max_spectrum
+
+        def gaussian(x, A = 1, sigma = 1, centre = 0): return A * \
+            torch.exp(-(x - centre) ** 2 / (2 * sigma ** 2))
+
+        def function(x, weights): 
+            return gaussian(x, *weights[:3]) + gaussian(x, *weights[3:6]) + gaussian(x, *weights[6:9])
+
+        weights = torch.tensor([1.0, 1.0, -0.2, 1.0, 1.0, 0.0, 1.0, 1.0, 0.2], requires_grad = True)
+        optimizer = torch.optim.LBFGS([weights], max_iter = 1000, lr = 0.1, line_search_fn = None)
+
+        def closure():
+            optimizer.zero_grad()
+            output = function(x_spectrum, weights)
+            loss = F.mse_loss(output, y_spectrum)
+            loss.backward()
+            return loss
+        optimizer.step(closure)
+        spectrum = function(x_spectrum, weights)
+        return spectrum.detach().to(self.device)
+
+
+    def display_spectrum_response(wavelength, function):
+        """
+        Internal function to provide light spectrum response at particular wavelength
+
+        Parameters
+        ----------
+        wavelength                          : torch.tensor
+                                              Wavelength in nm [400...700]
+        function                            : torch.tensor
+                                              Display light spectrum distribution function
+
+        Returns
+        -------
+        light_response_dict                  : float
+                                               Display light spectrum response value
+        """
+        wavelength = int(round(wavelength, 0))
+        if wavelength >= 400 and wavelength <= 700:
+            return function[wavelength - 400].item()
+        elif wavelength < 400:
+            return function[0].item()
+        else:
+            return function[300].item()
+
+
+    def cone_response_to_spectrum(self, cone_spectrum, light_spectrum):
+        """
+        Internal function to calculate cone response at particular light spectrum. 
+
+        Parameters
+        ----------
+        cone_spectrum                         : torch.tensor
+                                                Spectrum, Wavelength [2,300] tensor 
+        light_spectrum                        : torch.tensor
+                                                Spectrum, Wavelength [2,300] tensor 
+
+
+        Returns
+        -------
+        response_to_spectrum                  : float
+                                                Response of cone to light spectrum [1x1] 
+        """
+        response_to_spectrum = torch.mul(cone_spectrum, light_spectrum)
+        response_to_spectrum = torch.sum(response_to_spectrum)
+        return response_to_spectrum.item()
+
+
+    def construct_matrix_lms(self, l_response, m_response, s_response):
+        '''
+        Internal function to calculate cone  response at particular light spectrum. 
+
+        Parameters
+        ----------
+        l_response                             : torch.tensor
+                                                 Cone response spectrum tensor (normalized response vs wavelength)
+        m_response                             : torch.tensor
+                                                 Cone response spectrum tensor (normalized response vs wavelength)
+        s_response                             : torch.tensor
+                                                 Cone response spectrum tensor (normalized response vs wavelength)
+
+
+
+        Returns
+        -------
+        lms_image_tensor                      : torch.tensor
+                                                3x3 LMSrgb tensor
+
+        '''
+        if self.read_spectrum == 'tensor':
+            logging.warning('Tensor primary spectrum is used')
+            logging.warning('The number of primaries used is {}'.format(self.primaries_spectrum.shape[0]))
+        else:
+            logging.warning("No Spectrum data is provided")
+
+        self.lms_tensor = torch.zeros(self.primaries_spectrum.shape[0], 3).to(self.device)
+        for i in range(self.primaries_spectrum.shape[0]):
+            self.lms_tensor[i, 0] = self.cone_response_to_spectrum(l_response, self.primaries_spectrum[i])
+            self.lms_tensor[i, 1] = self.cone_response_to_spectrum(m_response, self.primaries_spectrum[i])
+            self.lms_tensor[i, 2] = self.cone_response_to_spectrum(s_response, self.primaries_spectrum[i]) 
+        return self.lms_tensor    
+
+
+    def construct_matrix_primaries(self, l_response, m_response, s_response):
+        '''
+        Internal function to calculate cone  response at particular light spectrum. 
+
+        Parameters
+        ----------
+        l_response                             : torch.tensor
+                                                 Cone response spectrum tensor (normalized response vs wavelength)
+        m_response                             : torch.tensor
+                                                 Cone response spectrum tensor (normalized response vs wavelength)
+        s_response                             : torch.tensor
+                                                 Cone response spectrum tensor (normalized response vs wavelength)
+
+
+
+        Returns
+        -------
+        lms_image_tensor                      : torch.tensor
+                                                3x3 LMSrgb tensor
+
+        '''
+        if self.read_spectrum == 'tensor':
+            logging.warning('Tensor primary spectrum is used')
+            logging.warning('The number of primaries used is {}'.format(self.primaries_spectrum.shape[0]))
+        else:
+            logging.warning("No Spectrum data is provided")
+
+        self.primaries_tensor = torch.zeros(3, self.primaries_spectrum.shape[0]).to(self.device)
+        for i in range(self.primaries_spectrum.shape[0]):
+            self.primaries_tensor[0, i] = self.cone_response_to_spectrum(
+                                                                         l_response,
+                                                                         self.primaries_spectrum[i]
+                                                                        )
+            self.primaries_tensor[1, i] = self.cone_response_to_spectrum(
+                                                                         m_response,
+                                                                         self.primaries_spectrum[i]
+                                                                        )
+            self.primaries_tensor[2, i] = self.cone_response_to_spectrum(
+                                                                         s_response,
+                                                                         self.primaries_spectrum[i]
+                                                                        ) 
+        return self.primaries_tensor    
+
+
+    def primaries_to_lms(self, primaries):
+        """
+        Internal function to convert primaries space to LMS space 
+
+        Parameters
+        ----------
+        primaries                              : torch.tensor
+                                                 Primaries data to be transformed to LMS space [BxPxHxW]
+
+
+        Returns
+        -------
+        lms_color                              : torch.tensor
+                                                 LMS data transformed from Primaries space [BxPxHxW]
+        """                
+        primaries_flatten = primaries.reshape(primaries.shape[0], primaries.shape[1], 1, -1)
+        lms = self.lms_tensor.unsqueeze(0).unsqueeze(-1)
+        lms_color = torch.sum(primaries_flatten * lms, axis = 1).reshape(primaries.shape)
+        return lms_color
+
+
+    def lms_to_primaries(self, lms_color_tensor):
+        """
+        Internal function to convert LMS image to primaries space
+
+        Parameters
+        ----------
+        lms_color_tensor                        : torch.tensor
+                                                  LMS data to be transformed to primaries space [Bx3xHxW]
+
+
+        Returns
+        -------
+        primaries                              : torch.tensor
+                                               : Primaries data transformed from LMS space [BxPxHxW]
+        """
+        lms_color_tensor = lms_color_tensor.permute(0, 2, 3, 1).to(self.device)
+        lms_color_flatten = torch.flatten(lms_color_tensor, start_dim=0, end_dim=1)
+        unflatten = torch.nn.Unflatten(0, (lms_color_tensor.size(0), lms_color_tensor.size(1)))
+        converted_unflatten = torch.matmul(lms_color_flatten.double(), self.lms_tensor.pinverse().double())
+        primaries = unflatten(converted_unflatten)     
+        primaries = primaries.permute(0, 3, 1, 2)   
+        return primaries
+
+
+    def second_to_third_stage(self, lms_image):
+        '''
+        This function turns second stage [L,M,S] values into third stage [(M+S)-L, (L+S)-M, L+M+S], 
+        See table 1 from Schmidt et al. "Neurobiological hypothesis of color appearance and hue perception," Optics Express 2014.
+
+        Parameters
+        ----------
+        lms_image                             : torch.tensor
+                                                 Image data at LMS space (second stage)
+
+        Returns
+        -------
+        third_stage                            : torch.tensor
+                                                 Image data at LMS space (third stage)
+
+        '''
+        third_stage = torch.zeros_like(lms_image)
+        third_stage[:, 0] = (lms_image[:, 1] + lms_image[:, 2]) - lms_image[:, 1]
+        third_stage[:, 1] = (lms_image[:, 0] + lms_image[:, 2]) - lms_image[:, 1]
+        third_stage[:, 2] = lms_image[:, 0] + lms_image[:, 1]  + lms_image[:, 2]
+        return third_stage
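
A minimal usage sketch, assuming the class is imported from the module named above; the random primaries spectrum and image sizes are placeholders, not a calibrated display model.

import torch
from odak.learn.perception.color_conversion import display_color_hvs

# Build the model with a random primary spectrum (three primaries sampled
# over 301 wavelengths, 400-700 nm) and score a candidate image against a
# ground-truth image.
display_model = display_color_hvs(primaries_spectrum = torch.rand(3, 301),
                                  device = torch.device('cpu'))
input_image = torch.rand(1, 3, 64, 64)
ground_truth = torch.rand(1, 3, 64, 64)
loss = display_model(input_image, ground_truth)
print(loss)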

__call__(input_image, ground_truth, gaze=None)

Evaluating an input image against a target ground truth image for a given gaze of a viewer.

Source code in odak/learn/perception/color_conversion.py
def __call__(self, input_image, ground_truth, gaze=None):
+    """
+    Evaluating an input image against a target ground truth image for a given gaze of a viewer.
+    """
+    lms_image_second = self.primaries_to_lms(input_image.to(self.device))
+    lms_ground_truth_second = self.primaries_to_lms(ground_truth.to(self.device))
+    lms_image_third = self.second_to_third_stage(lms_image_second)
+    lms_ground_truth_third = self.second_to_third_stage(lms_ground_truth_second)
+    loss_metamer_color = torch.mean((lms_ground_truth_third - lms_image_third) ** 2)
+    return loss_metamer_color

__init__(resolution=[1920, 1080], distance_from_screen=800, pixel_pitch=0.311, read_spectrum='tensor', primaries_spectrum=torch.rand(3, 301), device=torch.device('cpu'))

Parameters:

• resolution – Resolution of the display in pixels.
• distance_from_screen – Distance from the screen in mm.
• pixel_pitch – Pixel pitch of the display in mm.
• read_spectrum – Spectrum of the display. Default is 'default' which is the spectrum of the Dell U2415 display.
• device – Device to run the code on. Default is None which means the code will run on CPU.

Source code in odak/learn/perception/color_conversion.py
def __init__(
+             self,
+             resolution = [1920, 1080],
+             distance_from_screen = 800,
+             pixel_pitch = 0.311,
+             read_spectrum = 'tensor',
+             primaries_spectrum = torch.rand(3, 301),
+             device = torch.device('cpu')):
+    '''
+    Parameters
+    ----------
+    resolution                  : list
+                                  Resolution of the display in pixels.
+    distance_from_screen        : int
+                                  Distance from the screen in mm.
+    pixel_pitch                 : float
+                                  Pixel pitch of the display in mm.
+    read_spectrum               : str
+                                  Spectrum of the display. Default is 'default' which is the spectrum of the Dell U2415 display.
+    device                      : torch.device
+                                  Device to run the code on. Default is None which means the code will run on CPU.
+
+    '''
+    self.device = device
+    self.read_spectrum = read_spectrum
+    self.primaries_spectrum = primaries_spectrum.to(self.device)
+    self.resolution = resolution
+    self.distance_from_screen = distance_from_screen
+    self.pixel_pitch = pixel_pitch
+    self.l_normalized, self.m_normalized, self.s_normalized = self.initialize_cones_normalized()
+    self.lms_tensor = self.construct_matrix_lms(
+                                                self.l_normalized,
+                                                self.m_normalized,
+                                                self.s_normalized
+                                               )   
+    self.primaries_tensor = self.construct_matrix_primaries(
+                                                            self.l_normalized,
+                                                            self.m_normalized,
+                                                            self.s_normalized
+                                                           )   
+    return

cone_response_to_spectrum(cone_spectrum, light_spectrum)

Internal function to calculate cone response at particular light spectrum.

Parameters:

• cone_spectrum – Spectrum, Wavelength [2,300] tensor.
• light_spectrum – Spectrum, Wavelength [2,300] tensor.

Returns:

• response_to_spectrum (float) – Response of cone to light spectrum [1x1].

Source code in odak/learn/perception/color_conversion.py
def cone_response_to_spectrum(self, cone_spectrum, light_spectrum):
+    """
+    Internal function to calculate cone response at particular light spectrum. 
+
+    Parameters
+    ----------
+    cone_spectrum                         : torch.tensor
+                                            Spectrum, Wavelength [2,300] tensor 
+    light_spectrum                        : torch.tensor
+                                            Spectrum, Wavelength [2,300] tensor 
+
+
+    Returns
+    -------
+    response_to_spectrum                  : float
+                                            Response of cone to light spectrum [1x1] 
+    """
+    response_to_spectrum = torch.mul(cone_spectrum, light_spectrum)
+    response_to_spectrum = torch.sum(response_to_spectrum)
+    return response_to_spectrum.item()

construct_matrix_lms(l_response, m_response, s_response)

Internal function to calculate cone response at particular light spectrum.

Parameters:

• l_response – Cone response spectrum tensor (normalized response vs wavelength).
• m_response – Cone response spectrum tensor (normalized response vs wavelength).
• s_response – Cone response spectrum tensor (normalized response vs wavelength).

Returns:

• lms_image_tensor (tensor) – 3x3 LMSrgb tensor.

Source code in odak/learn/perception/color_conversion.py
def construct_matrix_lms(self, l_response, m_response, s_response):
+    '''
+    Internal function to calculate cone  response at particular light spectrum. 
+
+    Parameters
+    ----------
+    l_response                             : torch.tensor
+                                             Cone response spectrum tensor (normalized response vs wavelength)
+    m_response                             : torch.tensor
+                                             Cone response spectrum tensor (normalized response vs wavelength)
+    s_response                             : torch.tensor
+                                             Cone response spectrum tensor (normalized response vs wavelength)
+
+
+
+    Returns
+    -------
+    lms_image_tensor                      : torch.tensor
+                                            3x3 LMSrgb tensor
+
+    '''
+    if self.read_spectrum == 'tensor':
+        logging.warning('Tensor primary spectrum is used')
+        logging.warning('The number of primaries used is {}'.format(self.primaries_spectrum.shape[0]))
+    else:
+        logging.warning("No Spectrum data is provided")
+
+    self.lms_tensor = torch.zeros(self.primaries_spectrum.shape[0], 3).to(self.device)
+    for i in range(self.primaries_spectrum.shape[0]):
+        self.lms_tensor[i, 0] = self.cone_response_to_spectrum(l_response, self.primaries_spectrum[i])
+        self.lms_tensor[i, 1] = self.cone_response_to_spectrum(m_response, self.primaries_spectrum[i])
+        self.lms_tensor[i, 2] = self.cone_response_to_spectrum(s_response, self.primaries_spectrum[i]) 
+    return self.lms_tensor    

construct_matrix_primaries(l_response, m_response, s_response)

Internal function to calculate cone response at particular light spectrum.

Parameters:

• l_response – Cone response spectrum tensor (normalized response vs wavelength).
• m_response – Cone response spectrum tensor (normalized response vs wavelength).
• s_response – Cone response spectrum tensor (normalized response vs wavelength).

Returns:

• lms_image_tensor (tensor) – 3x3 LMSrgb tensor.

Source code in odak/learn/perception/color_conversion.py
def construct_matrix_primaries(self, l_response, m_response, s_response):
+    '''
+    Internal function to calculate cone  response at particular light spectrum. 
+
+    Parameters
+    ----------
+    l_response                             : torch.tensor
+                                             Cone response spectrum tensor (normalized response vs wavelength)
+    m_response                             : torch.tensor
+                                             Cone response spectrum tensor (normalized response vs wavelength)
+    s_response                             : torch.tensor
+                                             Cone response spectrum tensor (normalized response vs wavelength)
+
+
+
+    Returns
+    -------
+    lms_image_tensor                      : torch.tensor
+                                            3x3 LMSrgb tensor
+
+    '''
+    if self.read_spectrum == 'tensor':
+        logging.warning('Tensor primary spectrum is used')
+        logging.warning('The number of primaries used is {}'.format(self.primaries_spectrum.shape[0]))
+    else:
+        logging.warning("No Spectrum data is provided")
+
+    self.primaries_tensor = torch.zeros(3, self.primaries_spectrum.shape[0]).to(self.device)
+    for i in range(self.primaries_spectrum.shape[0]):
+        self.primaries_tensor[0, i] = self.cone_response_to_spectrum(
+                                                                     l_response,
+                                                                     self.primaries_spectrum[i]
+                                                                    )
+        self.primaries_tensor[1, i] = self.cone_response_to_spectrum(
+                                                                     m_response,
+                                                                     self.primaries_spectrum[i]
+                                                                    )
+        self.primaries_tensor[2, i] = self.cone_response_to_spectrum(
+                                                                     s_response,
+                                                                     self.primaries_spectrum[i]
+                                                                    ) 
+    return self.primaries_tensor    

display_spectrum_response(wavelength, function)

Internal function to provide light spectrum response at particular wavelength.

Parameters:

• wavelength – Wavelength in nm [400...700].
• function – Display light spectrum distribution function.

Returns:

• light_response_dict (float) – Display light spectrum response value.

Source code in odak/learn/perception/color_conversion.py
def display_spectrum_response(wavelength, function):
+    """
+    Internal function to provide light spectrum response at particular wavelength
+
+    Parameters
+    ----------
+    wavelength                          : torch.tensor
+                                          Wavelength in nm [400...700]
+    function                            : torch.tensor
+                                          Display light spectrum distribution function
+
+    Returns
+    -------
+    light_response_dict                  : float
+                                           Display light spectrum response value
+    """
+    wavelength = int(round(wavelength, 0))
+    if wavelength >= 400 and wavelength <= 700:
+        return function[wavelength - 400].item()
+    elif wavelength < 400:
+        return function[0].item()
+    else:
+        return function[300].item()
+
+
+
+ +
+ +
+ + +

+ initialize_cones_normalized() + +

+ + +
+ +

Internal function to initialize normalized L,M,S cones as normal distribution with given sigma, and mu values.

+ + +

Returns:

+
    +
  • +l_cone_n ( tensor +) – +
    +

    Normalised L cone distribution.

    +
    +
  • +
  • +m_cone_n ( tensor +) – +
    +

    Normalised M cone distribution.

    +
    +
  • +
  • +s_cone_n ( tensor +) – +
    +

    Normalised S cone distribution.

    +
    +
  • +
+ +
+ Source code in odak/learn/perception/color_conversion.py +
def initialize_cones_normalized(self):
+    """
+    Internal function to initialize normalized L,M,S cones as normal distribution with given sigma, and mu values. 
+
+    Returns
+    -------
+    l_cone_n                     : torch.tensor
+                                   Normalised L cone distribution.
+    m_cone_n                     : torch.tensor
+                                   Normalised M cone distribution.
+    s_cone_n                     : torch.tensor
+                                   Normalised S cone distribution.
+    """
+    wavelength_range = torch.linspace(400, 700, steps = 301, device = self.device)
+    dist_l = 1 / (32.5 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 567.5) ** 2 / (2 * 32.5 ** 2))
+    dist_m = 1 / (27.5 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 545.0) ** 2 / (2 * 27.5 ** 2))
+    dist_s = 1 / (17.0 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 447.5) ** 2 / (2 * 17.0 ** 2))
+
+    l_cone_n = dist_l / dist_l.max()
+    m_cone_n = dist_m / dist_m.max()
+    s_cone_n = dist_s / dist_s.max()
+    return l_cone_n, m_cone_n, s_cone_n
+
+
+
+ +
+ +
+ + +

+ initialize_random_spectrum_normalized(dataset) + +

+ + +
+ +

Initialize a normalized light spectrum by fitting a sum of three Gaussian distributions to the given dataset with the L-BFGS optimizer.

+ + +

Parameters:

+
    +
  • + dataset + – +
    +
                                     Spectrum values against wavelength
    +
    +
    +
  • +
+ +
+ Source code in odak/learn/perception/color_conversion.py +
def initialize_random_spectrum_normalized(self, dataset):
+    """
+    Initialize a normalized light spectrum by fitting a sum of three Gaussian distributions to the given dataset with the L-BFGS optimizer.
+
+    Parameters
+    ----------
+    dataset                                : torch.tensor
+                                             Spectrum values against wavelength.
+
+    Returns
+    -------
+    spectrum                               : torch.tensor
+                                             Fitted spectrum sampled at 301 wavelengths between 400 nm and 700 nm (approximately normalized).
+    """
+    dataset = torch.swapaxes(dataset, 0, 1)
+    x_spectrum = torch.linspace(400, 700, steps = 301) - 550
+    y_spectrum = torch.from_numpy(np_cpu.interp(x_spectrum, dataset[0].numpy(), dataset[1].numpy()))
+    max_spectrum = torch.max(y_spectrum)
+    y_spectrum /= max_spectrum
+
+    def gaussian(x, A = 1, sigma = 1, centre = 0): return A * \
+        torch.exp(-(x - centre) ** 2 / (2 * sigma ** 2))
+
+    def function(x, weights): 
+        return gaussian(x, *weights[:3]) + gaussian(x, *weights[3:6]) + gaussian(x, *weights[6:9])
+
+    weights = torch.tensor([1.0, 1.0, -0.2, 1.0, 1.0, 0.0, 1.0, 1.0, 0.2], requires_grad = True)
+    optimizer = torch.optim.LBFGS([weights], max_iter = 1000, lr = 0.1, line_search_fn = None)
+
+    def closure():
+        optimizer.zero_grad()
+        output = function(x_spectrum, weights)
+        loss = F.mse_loss(output, y_spectrum)
+        loss.backward()
+        return loss
+    optimizer.step(closure)
+    spectrum = function(x_spectrum, weights)
+    return spectrum.detach().to(self.device)
+
+
+
+ +
+ +
+ + +

+ initialize_rgb_backlight_spectrum() + +

+ + +
+ +

Internal function to initialize backlight spectrum for color primaries.

+ + +

Returns:

+
    +
  • +red_spectrum ( tensor +) – +
    +

    Normalised backlight spectrum for red color primary.

    +
    +
  • +
  • +green_spectrum ( tensor +) – +
    +

    Normalised backlight spectrum for green color primary.

    +
    +
  • +
  • +blue_spectrum ( tensor +) – +
    +

    Normalised backlight spectrum for blue color primary.

    +
    +
  • +
+ +
+ Source code in odak/learn/perception/color_conversion.py +
def initialize_rgb_backlight_spectrum(self):
+    """
+    Internal function to initialize backlight spectrum for color primaries.
+
+    Returns
+    -------
+    red_spectrum                 : torch.tensor
+                                   Normalised backlight spectrum for red color primary.
+    green_spectrum               : torch.tensor
+                                   Normalised backlight spectrum for green color primary.
+    blue_spectrum                : torch.tensor
+                                   Normalised backlight spectrum for blue color primary.
+    """
+    wavelength_range = torch.linspace(400, 700, steps = 301, device = self.device)
+    red_spectrum = 1 / (14.5 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 650) ** 2 / (2 * 14.5 ** 2))
+    green_spectrum = 1 / (12 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 550) ** 2 / (2 * 12.0 ** 2))
+    blue_spectrum = 1 / (12 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 450) ** 2 / (2 * 12.0 ** 2))
+
+    red_spectrum = red_spectrum / red_spectrum.max()
+    green_spectrum = green_spectrum / green_spectrum.max()
+    blue_spectrum = blue_spectrum / blue_spectrum.max()
+
+    return red_spectrum, green_spectrum, blue_spectrum
+
+
+
+ +
+ +
+ + +

+ lms_to_primaries(lms_color_tensor) + +

+ + +
+ +

Internal function to convert LMS image to primaries space

+ + +

Parameters:

+
    +
  • + lms_color_tensor + – +
    +
                                      LMS data to be transformed to primaries space [Bx3xHxW]
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +primaries ( tensor +) – +
    +

    Primaries data transformed from LMS space [BxPxHxW]

    +
    +
  • +
+ +
+ Source code in odak/learn/perception/color_conversion.py +
def lms_to_primaries(self, lms_color_tensor):
+    """
+    Internal function to convert LMS image to primaries space
+
+    Parameters
+    ----------
+    lms_color_tensor                        : torch.tensor
+                                              LMS data to be transformed to primaries space [Bx3xHxW]
+
+
+    Returns
+    -------
+    primaries                              : torch.tensor
+                                             Primaries data transformed from LMS space [BxPxHxW]
+    """
+    lms_color_tensor = lms_color_tensor.permute(0, 2, 3, 1).to(self.device)
+    lms_color_flatten = torch.flatten(lms_color_tensor, start_dim=0, end_dim=1)
+    unflatten = torch.nn.Unflatten(0, (lms_color_tensor.size(0), lms_color_tensor.size(1)))
+    converted_unflatten = torch.matmul(lms_color_flatten.double(), self.lms_tensor.pinverse().double())
+    primaries = unflatten(converted_unflatten)     
+    primaries = primaries.permute(0, 3, 1, 2)   
+    return primaries
+
+
+
+ +
+ +
+ + +

+ primaries_to_lms(primaries) + +

+ + +
+ +

Internal function to convert primaries space to LMS space

+ + +

Parameters:

+
    +
  • + primaries + – +
    +
                                     Primaries data to be transformed to LMS space [BxPxHxW]
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +lms_color ( tensor +) – +
    +

    LMS data transformed from Primaries space [BxPxHxW]

    +
    +
  • +
+ +
+ Source code in odak/learn/perception/color_conversion.py +
def primaries_to_lms(self, primaries):
+    """
+    Internal function to convert primaries space to LMS space 
+
+    Parameters
+    ----------
+    primaries                              : torch.tensor
+                                             Primaries data to be transformed to LMS space [BxPxHxW]
+
+
+    Returns
+    -------
+    lms_color                              : torch.tensor
+                                             LMS data transformed from Primaries space [BxPxHxW]
+    """                
+    primaries_flatten = primaries.reshape(primaries.shape[0], primaries.shape[1], 1, -1)
+    lms = self.lms_tensor.unsqueeze(0).unsqueeze(-1)
+    lms_color = torch.sum(primaries_flatten * lms, axis = 1).reshape(primaries.shape)
+    return lms_color
+
+
+
+ +
+ +
+ + +

+ second_to_third_stage(lms_image) + +

+ + +
+ +

This function turns second stage [L,M,S] values into third stage [(M+S)-L, (L+S)-M, L+M+S], +See table 1 from Schmidt et al. "Neurobiological hypothesis of color appearance and hue perception," Optics Express 2014.

+ + +

Parameters:

+
    +
  • + lms_image + – +
    +
                                     Image data at LMS space (second stage)
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +third_stage ( tensor +) – +
    +

    Image data at LMS space (third stage)

    +
    +
  • +
+ +
+ Source code in odak/learn/perception/color_conversion.py +
def second_to_third_stage(self, lms_image):
+    '''
+    This function turns second stage [L,M,S] values into third stage [(M+S)-L, (L+S)-M, L+M+S], 
+    See table 1 from Schmidt et al. "Neurobiological hypothesis of color appearance and hue perception," Optics Express 2014.
+
+    Parameters
+    ----------
+    lms_image                             : torch.tensor
+                                             Image data at LMS space (second stage)
+
+    Returns
+    -------
+    third_stage                            : torch.tensor
+                                             Image data at LMS space (third stage)
+
+    '''
+    third_stage = torch.zeros_like(lms_image)
+    third_stage[:, 0] = (lms_image[:, 1] + lms_image[:, 2]) - lms_image[:, 0] # (M + S) - L, as stated in the docstring
+    third_stage[:, 1] = (lms_image[:, 0] + lms_image[:, 2]) - lms_image[:, 1]
+    third_stage[:, 2] = lms_image[:, 0] + lms_image[:, 1]  + lms_image[:, 2]
+    return third_stage
+
+
+
+ +
+ + + +
+ +
+ +
+ + +
+ + +

+ color_map(input_image, target_image, model='Lab Stats') + +

+ + +
+ +

Internal function to map the color of an image to another image. +Reference: Color transfer between images, Reinhard et al., 2001.

+ + +

Parameters:

+
    +
  • + input_image + – +
    +
                  Input image in RGB color space [3 x m x n].
    +
    +
    +
  • +
  • + target_image + – +
    +
                  Target image in RGB color space [3 x m x n].
  • +
+ + +

Returns:

+
    +
  • +mapped_image ( Tensor +) – +
    +

    Input image mapped to the color distribution of the target image [3 x m x n].

    +
    +
  • +
+ +
+ Source code in odak/learn/perception/color_conversion.py +
def color_map(input_image, target_image, model = 'Lab Stats'):
+    """
+    Internal function to map the color of an image to another image.
+    Reference: Color transfer between images, Reinhard et al., 2001.
+
+    Parameters
+    ----------
+    input_image         : torch.Tensor
+                          Input image in RGB color space [3 x m x n].
+    target_image        : torch.Tensor
+                          Target image in RGB color space [3 x m x n].
+
+    Returns
+    -------
+    mapped_image           : torch.Tensor
+                             Input image mapped to the color distribution of the target image [3 x m x n].
+    """
+    if model == 'Lab Stats':
+        lab_input = srgb_to_lab(input_image)
+        lab_target = srgb_to_lab(target_image)
+        input_mean_L = torch.mean(lab_input[0, :, :])
+        input_mean_a = torch.mean(lab_input[1, :, :])
+        input_mean_b = torch.mean(lab_input[2, :, :])
+        input_std_L = torch.std(lab_input[0, :, :])
+        input_std_a = torch.std(lab_input[1, :, :])
+        input_std_b = torch.std(lab_input[2, :, :])
+        target_mean_L = torch.mean(lab_target[0, :, :])
+        target_mean_a = torch.mean(lab_target[1, :, :])
+        target_mean_b = torch.mean(lab_target[2, :, :])
+        target_std_L = torch.std(lab_target[0, :, :])
+        target_std_a = torch.std(lab_target[1, :, :])
+        target_std_b = torch.std(lab_target[2, :, :])
+        lab_input[0, :, :] = (lab_input[0, :, :] - input_mean_L) * (target_std_L / input_std_L) + target_mean_L
+        lab_input[1, :, :] = (lab_input[1, :, :] - input_mean_a) * (target_std_a / input_std_a) + target_mean_a
+        lab_input[2, :, :] = (lab_input[2, :, :] - input_mean_b) * (target_std_b / input_std_b) + target_mean_b
+        mapped_image = lab_to_srgb(lab_input.permute(1, 2, 0))
+        return mapped_image
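
A usage sketch for color_map (the random images below are placeholders; the import path is assumed from the documented source file):

import torch
from odak.learn.perception.color_conversion import color_map

input_image = torch.rand(3, 128, 128)  # image whose colors will be remapped, RGB in [0, 1]
target_image = torch.rand(3, 128, 128) # image providing the target color statistics
mapped_image = color_map(input_image, target_image, model = 'Lab Stats')
print(mapped_image.shape) # documented as [3 x m x n]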
+
+
+
+ +
+ +
+ + +

+ hsv_to_rgb(image) + +

+ + +
+ +

Definition to convert HSV space to RGB color space. Mostly inspired from : https://kornia.readthedocs.io/en/latest/_modules/kornia/color/hsv.html

+ + +

Parameters:

+
    +
  • + image + – +
    +
              Input image in HSV color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +image_rgb ( tensor +) – +
    +

    Output image in RGB color space [k x 3 x m x n] or [1 x 3 x m x n].

    +
    +
  • +
+ +
+ Source code in odak/learn/perception/color_conversion.py +
def hsv_to_rgb(image):
+
+    """
+    Definition to convert HSV space to  RGB color space. Mostly inspired from : https://kornia.readthedocs.io/en/latest/_modules/kornia/color/hsv.html
+
+    Parameters
+    ----------
+    image           : torch.tensor
+                      Input image in HSV color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.
+
+    Returns
+    -------
+    image_rgb       : torch.tensor
+                      Output image in  RGB  color space [k x 3 x m x n] or [1 x 3 x m x n].
+    """
+    if len(image.shape) == 3:
+        image = image.unsqueeze(0)
+    h = image[..., 0, :, :] / (2 * math.pi)
+    s = image[..., 1, :, :]
+    v = image[..., 2, :, :]
+    hi = torch.floor(h * 6) % 6
+    f = ((h * 6) % 6) - hi
+    one = torch.tensor(1.0)
+    p = v * (one - s)
+    q = v * (one - f * s)
+    t = v * (one - (one - f) * s)
+    hi = hi.long()
+    indices = torch.stack([hi, hi + 6, hi + 12], dim=-3)
+    image_rgb = torch.stack((v, q, p, p, t, v, t, v, v, q, p, p, p, p, t, v, v, q), dim=-3)
+    image_rgb = torch.gather(image_rgb, -3, indices)
+    return image_rgb
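
A round-trip sketch combining rgb_to_hsv (documented further below) with hsv_to_rgb; note that the hue channel is exchanged in radians, in the range [0, 2π]:

import torch
from odak.learn.perception.color_conversion import rgb_to_hsv, hsv_to_rgb

image_rgb = torch.rand(1, 3, 64, 64)   # RGB image normalized to [0, 1]
image_hsv = rgb_to_hsv(image_rgb)      # hue in radians, saturation and value in [0, 1]
image_back = hsv_to_rgb(image_hsv)
print(torch.allclose(image_back, image_rgb, atol = 1e-4)) # expected to print True up to numerical error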
+
+
+
+ +
+ +
+ + +

+ lab_to_srgb(image) + +

+ + +
+ +

Definition to convert LAB space to SRGB color space.

+ + +

Parameters:

+
    +
  • + image + – +
    +
              Input image in LAB color space [3 x m x n]
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +image_srgb ( tensor +) – +
    +

    Output image in SRGB color space [3 x m x n].

    +
    +
  • +
+ +
+ Source code in odak/learn/perception/color_conversion.py +
def lab_to_srgb(image):
+    """
+    Definition to convert LAB space to SRGB color space. 
+
+    Parameters
+    ----------
+    image           : torch.tensor
+                      Input image in LAB color space [3 x m x n]
+
+    Returns
+    -------
+    image_srgb     : torch.tensor
+                      Output image in SRGB color space [3 x m x n].
+    """
+
+    if image.shape[-1] == 3:
+        input_color = image.permute(2, 0, 1)  # C(H*W)
+    else:
+        input_color = image
+    # lab ---> xyz
+    reference_illuminant = torch.tensor([[[0.950428545]], [[1.000000000]], [[1.088900371]]], dtype=torch.float32)
+    y = (input_color[0:1, :, :] + 16) / 116
+    a =  input_color[1:2, :, :] / 500
+    b =  input_color[2:3, :, :] / 200
+    x = y + a
+    z = y - b
+    xyz = torch.cat((x, y, z), 0)
+    delta = 6 / 29
+    factor = 3 * delta * delta
+    xyz = torch.where(xyz > delta,  xyz ** 3, factor * (xyz - 4 / 29))
+    xyz_color = xyz * reference_illuminant
+    # xyz ---> linear rgb
+    a11 = 3.241003275
+    a12 = -1.537398934
+    a13 = -0.498615861
+    a21 = -0.969224334
+    a22 = 1.875930071
+    a23 = 0.041554224
+    a31 = 0.055639423
+    a32 = -0.204011202
+    a33 = 1.057148933
+    A = torch.tensor([[a11, a12, a13],
+                  [a21, a22, a23],
+                  [a31, a32, a33]], dtype=torch.float32)
+
+    xyz_color = xyz_color.permute(2, 0, 1) # C(H*W)
+    linear_rgb_color = torch.matmul(A, xyz_color)
+    linear_rgb_color = linear_rgb_color.permute(1, 2, 0)
+    # linear rgb ---> srgb
+    limit = 0.0031308
+    image_srgb = torch.where(linear_rgb_color > limit, 1.055 * (linear_rgb_color ** (1.0 / 2.4)) - 0.055, 12.92 * linear_rgb_color)
+    return image_srgb
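
A round-trip sketch with srgb_to_lab (documented further below); the matrices used by the two functions are close to exact inverses, so the input is recovered up to small numerical error:

import torch
from odak.learn.perception.color_conversion import srgb_to_lab, lab_to_srgb

image = torch.rand(3, 64, 64)         # sRGB image, channels first
image_lab = srgb_to_lab(image)        # CIE L*a*b*, [3 x m x n]
image_srgb = lab_to_srgb(image_lab)   # back to sRGB, [3 x m x n]
print(torch.allclose(image_srgb, image, atol = 1e-2))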
+
+
+
+ +
+ +
+ + +

+ linear_rgb_to_rgb(image, threshold=0.0031308) + +

+ + +
+ +

Definition to convert linear RGB images to RGB color space. Mostly inspired from: https://kornia.readthedocs.io/en/latest/_modules/kornia/color/rgb.html

+ + +

Parameters:

+
    +
  • + image + – +
    +
              Input image in linear RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.
    +
    +
    +
  • +
  • + threshold + – +
    +
              Threshold used in calculations.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +image_linear ( tensor +) – +
    +

    Output image in RGB color space [k x 3 x m x n] or [1 x 3 x m x n].

    +
    +
  • +
+ +
+ Source code in odak/learn/perception/color_conversion.py +
def linear_rgb_to_rgb(image, threshold = 0.0031308):
+    """
+    Definition to convert linear RGB images to RGB color space. Mostly inspired from: https://kornia.readthedocs.io/en/latest/_modules/kornia/color/rgb.html
+
+    Parameters
+    ----------
+    image           : torch.tensor
+                      Input image in linear RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.
+    threshold       : float
+                      Threshold used in calculations.
+
+    Returns
+    -------
+    image_linear    : torch.tensor
+                      Output image in RGB color space [k x 3 x m x n] or [1 x 3 x m x n].
+    """
+    if len(image.shape) == 3:
+        image = image.unsqueeze(0)
+    image_linear =  torch.where(image > threshold, 1.055 * torch.pow(image.clamp(min=threshold), 1 / 2.4) - 0.055, 12.92 * image)
+    return image_linear
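
A round-trip sketch with rgb_to_linear_rgb (documented further below), removing and then re-applying the sRGB gamma:

import torch
from odak.learn.perception.color_conversion import rgb_to_linear_rgb, linear_rgb_to_rgb

image = torch.rand(1, 3, 64, 64)             # sRGB-encoded image in [0, 1]
image_linear = rgb_to_linear_rgb(image)      # linearized intensities
image_srgb = linear_rgb_to_rgb(image_linear) # gamma-encoded again
print(torch.allclose(image_srgb, image, atol = 1e-4))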
+
+
+
+ +
+ +
+ + +

+ linear_rgb_to_xyz(image) + +

+ + +
+ +

Definition to convert RGB space to CIE XYZ color space. Mostly inspired from : Rochester IT Color Conversion Algorithms (https://www.cs.rit.edu/~ncs/color/)

+ + +

Parameters:

+
    +
  • + image + – +
    +
              Input image in linear RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +image_xyz ( tensor +) – +
    +

    Output image in XYZ (CIE 1931) color space [k x 3 x m x n] or [1 x 3 x m x n].

    +
    +
  • +
+ +
+ Source code in odak/learn/perception/color_conversion.py +
def linear_rgb_to_xyz(image):
+    """
+    Definition to convert RGB space to CIE XYZ color space. Mostly inspired from : Rochester IT Color Conversion Algorithms (https://www.cs.rit.edu/~ncs/color/)
+
+    Parameters
+    ----------
+    image           : torch.tensor
+                      Input image in linear RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.
+
+    Returns
+    -------
+    image_xyz       : torch.tensor
+                      Output image in XYZ (CIE 1931) color space [k x 3 x m x n] or [1 x 3 x m x n].
+    """
+    if len(image.shape) == 3:
+        image = image.unsqueeze(0)
+    a11 = 0.412453
+    a12 = 0.357580
+    a13 = 0.180423
+    a21 = 0.212671
+    a22 = 0.715160
+    a23 = 0.072169
+    a31 = 0.019334
+    a32 = 0.119193
+    a33 = 0.950227
+    M = torch.tensor([[a11, a12, a13], 
+                      [a21, a22, a23],
+                      [a31, a32, a33]])
+    size = image.size()
+    image = image.reshape(size[0], size[1], size[2]*size[3])  # NC(HW)
+    image_xyz = torch.matmul(M, image)
+    image_xyz = image_xyz.reshape(size[0], size[1], size[2], size[3])
+    return image_xyz
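
A round-trip sketch pairing linear_rgb_to_xyz with xyz_to_linear_rgb (documented further below):

import torch
from odak.learn.perception.color_conversion import linear_rgb_to_xyz, xyz_to_linear_rgb

image = torch.rand(1, 3, 32, 32)          # linear RGB image in [0, 1]
image_xyz = linear_rgb_to_xyz(image)      # CIE 1931 XYZ
image_back = xyz_to_linear_rgb(image_xyz) # back to linear RGB
print(torch.allclose(image_back, image, atol = 1e-4))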
+
+
+
+ +
+ +
+ + +

+ rgb_2_ycrcb(image) + +

+ + +
+ +

Converts an image from RGB colourspace to YCrCb colourspace.

+ + +

Parameters:

+
    +
  • + image + – +
    +
      Input image. Should be an RGB floating-point image with values in the range [0, 1]. Should be in NCHW format [3 x m x n] or [k x 3 x m x n].
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +ycrcb ( tensor +) – +
    +

    Image converted to YCrCb colourspace [k x 3 x m x n] or [1 x 3 x m x n].

    +
    +
  • +
+ +
+ Source code in odak/learn/perception/color_conversion.py +
def rgb_2_ycrcb(image):
+    """
+    Converts an image from RGB colourspace to YCrCb colourspace.
+
+    Parameters
+    ----------
+    image   : torch.tensor
+              Input image. Should be an RGB floating-point image with values in the range [0, 1]. Should be in NCHW format [3 x m x n] or [k x 3 x m x n].
+
+    Returns
+    -------
+
+    ycrcb   : torch.tensor
+              Image converted to YCrCb colourspace [k x 3 x m x n] or [1 x 3 x m x n].
+    """
+    if len(image.shape) == 3:
+       image = image.unsqueeze(0)
+    ycrcb = torch.zeros(image.size()).to(image.device)
+    ycrcb[:, 0, :, :] = 0.299 * image[:, 0, :, :] + 0.587 * \
+        image[:, 1, :, :] + 0.114 * image[:, 2, :, :]
+    ycrcb[:, 1, :, :] = 0.5 + 0.713 * (image[:, 0, :, :] - ycrcb[:, 0, :, :])
+    ycrcb[:, 2, :, :] = 0.5 + 0.564 * (image[:, 2, :, :] - ycrcb[:, 0, :, :])
+    return ycrcb
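
A round-trip sketch pairing rgb_2_ycrcb with ycrcb_2_rgb (documented further below); the coefficients are approximate, so the comparison uses a loose tolerance:

import torch
from odak.learn.perception.color_conversion import rgb_2_ycrcb, ycrcb_2_rgb

image = torch.rand(1, 3, 64, 64)  # RGB image in [0, 1]
image_ycrcb = rgb_2_ycrcb(image)
image_rgb = ycrcb_2_rgb(image_ycrcb)
print(torch.allclose(image_rgb, image, atol = 1e-2))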
+
+
+
+ +
+ +
+ + +

+ rgb_to_hsv(image, eps=1e-08) + +

+ + +
+ +

Definition to convert RGB space to HSV color space. Mostly inspired from : https://kornia.readthedocs.io/en/latest/_modules/kornia/color/hsv.html

+ + +

Parameters:

+
    +
  • + image + – +
    +
              Input image in RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +image_hsv ( tensor +) – +
    +

    Output image in HSV color space [k x 3 x m x n] or [1 x 3 x m x n].

    +
    +
  • +
+ +
+ Source code in odak/learn/perception/color_conversion.py +
def rgb_to_hsv(image, eps: float = 1e-8):
+
+    """
+    Definition to convert RGB space to HSV color space. Mostly inspired from : https://kornia.readthedocs.io/en/latest/_modules/kornia/color/hsv.html
+
+    Parameters
+    ----------
+    image           : torch.tensor
+                      Input image in RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.
+
+    Returns
+    -------
+    image_hsv       : torch.tensor
+                      Output image in HSV color space [k x 3 x m x n] or [1 x 3 x m x n].
+    """
+    if len(image.shape) == 3:
+        image = image.unsqueeze(0)
+    max_rgb, argmax_rgb = image.max(-3)
+    min_rgb, argmin_rgb = image.min(-3)
+    deltac = max_rgb - min_rgb
+    v = max_rgb
+    s = deltac / (max_rgb + eps)
+    deltac = torch.where(deltac == 0, torch.ones_like(deltac), deltac)
+    rc, gc, bc = torch.unbind((max_rgb.unsqueeze(-3) - image), dim=-3)
+    h1 = bc - gc
+    h2 = (rc - bc) + 2.0 * deltac
+    h3 = (gc - rc) + 4.0 * deltac
+    h = torch.stack((h1, h2, h3), dim=-3) / deltac.unsqueeze(-3)
+    h = torch.gather(h, dim=-3, index=argmax_rgb.unsqueeze(-3)).squeeze(-3)
+    h = (h / 6.0) % 1.0
+    h = 2.0 * math.pi * h 
+    image_hsv = torch.stack((h, s, v), dim=-3)
+    return image_hsv
+
+
+
+ +
+ +
+ + +

+ rgb_to_linear_rgb(image, threshold=0.0031308) + +

+ + +
+ +

Definition to convert RGB images to linear RGB color space. Mostly inspired from: https://kornia.readthedocs.io/en/latest/_modules/kornia/color/rgb.html

+ + +

Parameters:

+
    +
  • + image + – +
    +
              Input image in RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.
    +
    +
    +
  • +
  • + threshold + – +
    +
              Threshold used in calculations.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +image_linear ( tensor +) – +
    +

    Output image in linear RGB color space [k x 3 x m x n] or [1 x 3 x m x n].

    +
    +
  • +
+ +
+ Source code in odak/learn/perception/color_conversion.py +
def rgb_to_linear_rgb(image, threshold = 0.0031308):
+    """
+    Definition to convert RGB images to linear RGB color space. Mostly inspired from: https://kornia.readthedocs.io/en/latest/_modules/kornia/color/rgb.html
+
+    Parameters
+    ----------
+    image           : torch.tensor
+                      Input image in RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.
+    threshold       : float
+                      Threshold used in calculations.
+
+    Returns
+    -------
+    image_linear    : torch.tensor
+                      Output image in linear RGB color space [k x 3 x m x n] or [1 x 3 x m x n].
+    """
+    if len(image.shape) == 3:
+        image = image.unsqueeze(0)
+    image_linear = torch.where(image > 0.04045, torch.pow(((image + 0.055) / 1.055), 2.4), image / 12.92)
+    return image_linear
+
+
+
+ +
+ +
+ + +

+ srgb_to_lab(image) + +

+ + +
+ +

Definition to convert SRGB space to LAB color space.

+ + +

Parameters:

+
    +
  • + image + – +
    +
              Input image in SRGB color space [3 x m x n]
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +image_lab ( tensor +) – +
    +

    Output image in LAB color space [3 x m x n].

    +
    +
  • +
+ +
+ Source code in odak/learn/perception/color_conversion.py +
def srgb_to_lab(image):    
+    """
+    Definition to convert SRGB space to LAB color space. 
+
+    Parameters
+    ----------
+    image           : torch.tensor
+                      Input image in SRGB color space [3 x m x n]
+
+    Returns
+    -------
+    image_lab       : torch.tensor
+                      Output image in LAB color space [3 x m x n].
+    """
+    if image.shape[-1] == 3:
+        input_color = image.permute(2, 0, 1)  # C(H*W)
+    else:
+        input_color = image
+    # rgb ---> linear rgb
+    limit = 0.04045        
+    # linear rgb ---> xyz
+    linrgb_color = torch.where(input_color > limit, torch.pow((input_color + 0.055) / 1.055, 2.4), input_color / 12.92)
+
+    a11 = 10135552 / 24577794
+    a12 = 8788810  / 24577794
+    a13 = 4435075  / 24577794
+    a21 = 2613072  / 12288897
+    a22 = 8788810  / 12288897
+    a23 = 887015   / 12288897
+    a31 = 1425312  / 73733382
+    a32 = 8788810  / 73733382
+    a33 = 70074185 / 73733382
+
+    A = torch.tensor([[a11, a12, a13],
+                    [a21, a22, a23],
+                    [a31, a32, a33]], dtype=torch.float32)
+
+    linrgb_color = linrgb_color.permute(2, 0, 1) # C(H*W)
+    xyz_color = torch.matmul(A, linrgb_color)
+    xyz_color = xyz_color.permute(1, 2, 0)
+    # xyz ---> lab
+    inv_reference_illuminant = torch.tensor([[[1.052156925]], [[1.000000000]], [[0.918357670]]], dtype=torch.float32)
+    input_color = xyz_color * inv_reference_illuminant
+    delta = 6 / 29
+    delta_square = delta * delta
+    delta_cube = delta * delta_square
+    factor = 1 / (3 * delta_square)
+
+    input_color = torch.where(input_color > delta_cube, torch.pow(input_color, 1 / 3), (factor * input_color + 4 / 29))
+
+    l = 116 * input_color[1:2, :, :] - 16
+    a = 500 * (input_color[0:1,:, :] - input_color[1:2, :, :])
+    b = 200 * (input_color[1:2, :, :] - input_color[2:3, :, :])
+
+    image_lab = torch.cat((l, a, b), 0)
+    return image_lab    
+
+
+
+ +
+ +
+ + +

+ xyz_to_linear_rgb(image) + +

+ + +
+ +

Definition to convert CIE XYZ space to linear RGB color space. Mostly inspired from : Rochester IT Color Conversion Algorithms (https://www.cs.rit.edu/~ncs/color/)

+ + +

Parameters:

+
    +
  • + image + – +
    +
               Input image in XYZ (CIE 1931) color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +image_linear_rgb ( tensor +) – +
    +

    Output image in linear RGB color space [k x 3 x m x n] or [1 x 3 x m x n].

    +
    +
  • +
+ +
+ Source code in odak/learn/perception/color_conversion.py +
def xyz_to_linear_rgb(image):
+    """
+    Definition to convert CIE XYZ space to linear RGB color space. Mostly inspired from : Rochester IT Color Conversion Algorithms (https://www.cs.rit.edu/~ncs/color/)
+
+    Parameters
+    ----------
+    image            : torch.tensor
+                       Input image in XYZ (CIE 1931) color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.
+
+    Returns
+    -------
+    image_linear_rgb : torch.tensor
+                       Output image in linear RGB  color space [k x 3 x m x n] or [1 x 3 x m x n].
+    """
+    if len(image.shape) == 3:
+        image = image.unsqueeze(0)
+    a11 = 3.240479
+    a12 = -1.537150
+    a13 = -0.498535
+    a21 = -0.969256 
+    a22 = 1.875992 
+    a23 = 0.041556
+    a31 = 0.055648
+    a32 = -0.204043
+    a33 = 1.057311
+    M = torch.tensor([[a11, a12, a13], 
+                      [a21, a22, a23],
+                      [a31, a32, a33]])
+    size = image.size()
+    image = image.reshape(size[0], size[1], size[2]*size[3])
+    image_linear_rgb = torch.matmul(M, image)
+    image_linear_rgb = image_linear_rgb.reshape(size[0], size[1], size[2], size[3])
+    return image_linear_rgb
+
+
+
+ +
+ +
+ + +

+ ycrcb_2_rgb(image) + +

+ + +
+ +

Converts an image from YCrCb colourspace to RGB colourspace.

+ + +

Parameters:

+
    +
  • + image + – +
    +
      Input image. Should be a YCrCb floating-point image with values in the range [0, 1]. Should be in NCHW format [3 x m x n] or [k x 3 x m x n].
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +rgb ( tensor +) – +
    +

    Image converted to RGB colourspace [k x 3 m x n] or [1 x 3 x m x n].

    +
    +
  • +
+ +
+ Source code in odak/learn/perception/color_conversion.py +
def ycrcb_2_rgb(image):
+    """
+    Converts an image from YCrCb colourspace to RGB colourspace.
+
+    Parameters
+    ----------
+    image   : torch.tensor
+              Input image. Should be a YCrCb floating-point image with values in the range [0, 1]. Should be in NCHW format [3 x m x n] or [k x 3 x m x n].
+
+    Returns
+    -------
+    rgb     : torch.tensor
+              Image converted to RGB colourspace [k x 3 m x n] or [1 x 3 x m x n].
+    """
+    if len(image.shape) == 3:
+       image = image.unsqueeze(0)
+    rgb = torch.zeros(image.size(), device=image.device)
+    rgb[:, 0, :, :] = image[:, 0, :, :] + 1.403 * (image[:, 1, :, :] - 0.5)
+    rgb[:, 1, :, :] = image[:, 0, :, :] - 0.714 * \
+        (image[:, 1, :, :] - 0.5) - 0.344 * (image[:, 2, :, :] - 0.5)
+    rgb[:, 2, :, :] = image[:, 0, :, :] + 1.773 * (image[:, 2, :, :] - 0.5)
+    return rgb
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ make_3d_location_map(image_pixel_size, real_image_width=0.3, real_viewing_distance=0.6) + +

+ + +
+ +

Makes a map of the real 3D location that each pixel in an image corresponds to, when displayed to +a user on a flat screen. Assumes the viewpoint is located at the centre of the image, and the screen is +perpendicular to the viewing direction.

+ + +

Parameters:

+
    +
  • + image_pixel_size + – +
    +
                        The size of the image in pixels, as a tuple of form (height, width)
    +
    +
    +
  • +
  • + real_image_width + – +
    +
                        The real width of the image as displayed. Units not important, as long as they
    +                    are the same as those used for real_viewing_distance
    +
    +
    +
  • +
  • + real_viewing_distance + – +
    +
                        The real distance from the user's viewpoint to the screen.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +map ( tensor +) – +
    +

    The computed 3D location map, of size 3xHxW.

    +
    +
  • +
+ +
+ Source code in odak/learn/perception/foveation.py +
def make_3d_location_map(image_pixel_size, real_image_width=0.3, real_viewing_distance=0.6):
+    """ 
+    Makes a map of the real 3D location that each pixel in an image corresponds to, when displayed to
+    a user on a flat screen. Assumes the viewpoint is located at the centre of the image, and the screen is 
+    perpendicular to the viewing direction.
+
+    Parameters
+    ----------
+
+    image_pixel_size        : tuple of ints 
+                                The size of the image in pixels, as a tuple of form (height, width)
+    real_image_width        : float
+                                The real width of the image as displayed. Units not important, as long as they
+                                are the same as those used for real_viewing_distance
+    real_viewing_distance   : float 
+                                The real distance from the user's viewpoint to the screen.
+
+    Returns
+    -------
+
+    map                     : torch.tensor
+                                The computed 3D location map, of size 3xHxW.
+    """
+    real_image_height = (real_image_width /
+                         image_pixel_size[-1]) * image_pixel_size[-2]
+    x_coords = torch.linspace(-0.5, 0.5, image_pixel_size[-1])*real_image_width
+    x_coords = x_coords[None, None, :].repeat(1, image_pixel_size[-2], 1)
+    y_coords = torch.linspace(-0.5, 0.5,
+                              image_pixel_size[-2])*real_image_height
+    y_coords = y_coords[None, :, None].repeat(1, 1, image_pixel_size[-1])
+    z_coords = torch.ones(
+        (1, image_pixel_size[-2], image_pixel_size[-1])) * real_viewing_distance
+
+    return torch.cat([x_coords, y_coords, z_coords], dim=0)
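
A minimal usage sketch (assuming the function can be imported from odak.learn.perception.foveation, matching the documented source path); the values below are example display dimensions:

from odak.learn.perception.foveation import make_3d_location_map

location_map = make_3d_location_map(
                                    image_pixel_size = (1080, 1920),
                                    real_image_width = 0.3,
                                    real_viewing_distance = 0.6
                                   )
print(location_map.shape) # torch.Size([3, 1080, 1920]); x, y, z per pixel in the units of real_image_width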
+
+
+
+ +
+ +
+ + +

+ make_eccentricity_distance_maps(gaze_location, image_pixel_size, real_image_width=0.3, real_viewing_distance=0.6) + +

+ + +
+ +

Makes a map of the eccentricity of each pixel in an image for a given fixation point, when displayed to +a user on a flat screen. Assumes the viewpoint is located at the centre of the image, and the screen is +perpendicular to the viewing direction. Output in radians.

+ + +

Parameters:

+
    +
  • + gaze_location + – +
    +
                        User's gaze (fixation point) in the image. Should be given as a tuple with normalized
    +                    image coordinates (ranging from 0 to 1)
    +
    +
    +
  • +
  • + image_pixel_size + – +
    +
                        The size of the image in pixels, as a tuple of form (height, width)
    +
    +
    +
  • +
  • + real_image_width + – +
    +
                        The real width of the image as displayed. Units not important, as long as they
    +                    are the same as those used for real_viewing_distance
    +
    +
    +
  • +
  • + real_viewing_distance + – +
    +
                        The real distance from the user's viewpoint to the screen.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +eccentricity_map ( tensor +) – +
    +

    The computed eccentricity map, of size HxW.

    +
    +
  • +
  • +distance_map ( tensor +) – +
    +

    The computed distance map, of size HxW.

    +
    +
  • +
+ +
+ Source code in odak/learn/perception/foveation.py +
def make_eccentricity_distance_maps(gaze_location, image_pixel_size, real_image_width=0.3, real_viewing_distance=0.6):
+    """ 
+    Makes a map of the eccentricity of each pixel in an image for a given fixation point, when displayed to
+    a user on a flat screen. Assumes the viewpoint is located at the centre of the image, and the screen is 
+    perpendicular to the viewing direction. Output in radians.
+
+    Parameters
+    ----------
+
+    gaze_location           : tuple of floats
+                                User's gaze (fixation point) in the image. Should be given as a tuple with normalized
+                                image coordinates (ranging from 0 to 1)
+    image_pixel_size        : tuple of ints
+                                The size of the image in pixels, as a tuple of form (height, width)
+    real_image_width        : float
+                                The real width of the image as displayed. Units not important, as long as they
+                                are the same as those used for real_viewing_distance
+    real_viewing_distance   : float
+                                The real distance from the user's viewpoint to the screen.
+
+    Returns
+    -------
+
+    eccentricity_map        : torch.tensor
+                                The computed eccentricity map, of size HxW.
+    distance_map            : torch.tensor
+                                The computed distance map, of size HxW.
+    """
+    real_image_height = (real_image_width /
+                         image_pixel_size[-1]) * image_pixel_size[-2]
+    location_map = make_3d_location_map(
+        image_pixel_size, real_image_width, real_viewing_distance)
+    distance_map = torch.sqrt(torch.sum(location_map*location_map, dim=0))
+    direction_map = location_map / distance_map
+
+    gaze_location_3d = torch.tensor([
+        (gaze_location[0]*2 - 1)*real_image_width*0.5,
+        (gaze_location[1]*2 - 1)*real_image_height*0.5,
+        real_viewing_distance])
+    gaze_dir = gaze_location_3d / \
+        torch.sqrt(torch.sum(gaze_location_3d * gaze_location_3d))
+    gaze_dir = gaze_dir[:, None, None]
+
+    dot_prod_map = torch.sum(gaze_dir * direction_map, dim=0)
+    dot_prod_map = torch.clamp(dot_prod_map, min=-1.0, max=1.0)
+    eccentricity_map = torch.acos(dot_prod_map)
+
+    return eccentricity_map, distance_map
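
A usage sketch with the gaze at the image centre (example values; same assumed import path as above):

from odak.learn.perception.foveation import make_eccentricity_distance_maps

eccentricity_map, distance_map = make_eccentricity_distance_maps(
                                                                 gaze_location = [0.5, 0.5],
                                                                 image_pixel_size = (1080, 1920),
                                                                 real_image_width = 0.3,
                                                                 real_viewing_distance = 0.6
                                                                )
print(eccentricity_map.shape) # torch.Size([1080, 1920]); eccentricity in radians, smallest near the gaze point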
+
+
+
+ +
+ +
+ + +

+ make_equi_pooling_size_map_lod(gaze_angles, image_pixel_size, alpha=0.3, mode='quadratic') + +

+ + +
+ +

This function is similar to make_equi_pooling_size_map_pixels, but instead returns a map of LOD levels to sample from +to achieve the correct pooling region areas.

+ + +

Parameters:

+
    +
  • + gaze_angles + – +
    +
                    Gaze direction expressed as angles, in radians.
    +
    +
    +
  • +
  • + image_pixel_size + – +
    +
                    Dimensions of the image in pixels, as a tuple of (height, width)
    +
    +
    +
  • +
  • + alpha + – +
    +
                    Parameter controlling extent of foveation
    +
    +
    +
  • +
  • + mode + – +
    +
                    Foveation mode (how pooling size varies with eccentricity). Should be "quadratic" or "linear"
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +pooling_size_map ( tensor +) – +
    +

    The computed pooling size map, of size HxW.

    +
    +
  • +
+ +
+ Source code in odak/learn/perception/foveation.py +
def make_equi_pooling_size_map_lod(gaze_angles, image_pixel_size, alpha=0.3, mode="quadratic"):
+    """ 
+    This function is similar to make_equi_pooling_size_map_pixels, but instead returns a map of LOD levels to sample from
+    to achieve the correct pooling region areas.
+
+    Parameters
+    ----------
+
+    gaze_angles         : tuple of 2 floats
+                            Gaze direction expressed as angles, in radians.
+    image_pixel_size    : tuple of 2 ints
+                            Dimensions of the image in pixels, as a tuple of (height, width)
+    alpha               : float
+                            Parameter controlling extent of foveation
+    mode                : str
+                            Foveation mode (how pooling size varies with eccentricity). Should be "quadratic" or "linear"
+
+    Returns
+    -------
+
+    pooling_size_map        : torch.tensor
+                                The computed pooling size map, of size HxW.
+    """
+    pooling_pixel = make_equi_pooling_size_map_pixels(gaze_angles, image_pixel_size, alpha, mode)
+    pooling_lod = torch.log2(1e-6+pooling_pixel)
+    pooling_lod[pooling_lod < 0] = 0
+    return pooling_lod
+
+
+
+ +
+ +
+ + +

+ make_equi_pooling_size_map_pixels(gaze_angles, image_pixel_size, alpha=0.3, mode='quadratic') + +

+ + +
+ +

This function makes a map of pooling sizes in pixels, similarly to make_pooling_size_map_pixels, but works on 360 equirectangular images. +Input images are assumed to be in equirectangular form - i.e. if you consider a 3D viewing setup where y is the vertical axis, +the x location in the image corresponds to rotation around the y axis (yaw), ranging from -pi to pi. The y location in the image +corresponds to pitch, ranging from -pi/2 to pi/2.

+

In this setup real_image_width and real_viewing_distance have no effect.

+

Note that rather than a 2D image gaze location in [0,1]^2, the gaze should be specified as gaze angles in [-pi,pi]x[-pi/2,pi/2] (yaw, then pitch).

+ + +

Parameters:

+
    +
  • + gaze_angles + – +
    +
                    Gaze direction expressed as angles, in radians.
    +
    +
    +
  • +
  • + image_pixel_size + – +
    +
                    Dimensions of the image in pixels, as a tuple of (height, width)
    +
    +
    +
  • +
  • + alpha + – +
    +
                    Parameter controlling extent of foveation
    +
    +
    +
  • +
  • + mode + – +
    +
                    Foveation mode (how pooling size varies with eccentricity). Should be "quadratic" or "linear"
    +
    +
    +
  • +
+ +
+ Source code in odak/learn/perception/foveation.py +
def make_equi_pooling_size_map_pixels(gaze_angles, image_pixel_size, alpha=0.3, mode="quadratic"):
+    """
+    This function makes a map of pooling sizes in pixels, similarly to make_pooling_size_map_pixels, but works on 360 equirectangular images.
+    Input images are assumed to be in equirectangular form - i.e. if you consider a 3D viewing setup where y is the vertical axis, 
+    the x location in the image corresponds to rotation around the y axis (yaw), ranging from -pi to pi. The y location in the image
+    corresponds to pitch, ranging from -pi/2 to pi/2.
+
+    In this setup real_image_width and real_viewing_distance have no effect.
+
+    Note that rather than a 2D image gaze location in [0,1]^2, the gaze should be specified as gaze angles in [-pi,pi]x[-pi/2,pi/2] (yaw, then pitch).
+
+    Parameters
+    ----------
+
+    gaze_angles         : tuple of 2 floats
+                            Gaze direction expressed as angles, in radians.
+    image_pixel_size    : tuple of 2 ints
+                            Dimensions of the image in pixels, as a tuple of (height, width)
+    alpha               : float
+                            Parameter controlling extent of foveation
+    mode                : str
+                            Foveation mode (how pooling size varies with eccentricity). Should be "quadratic" or "linear"
+    """
+    view_direction = torch.tensor([math.sin(gaze_angles[0])*math.cos(gaze_angles[1]), math.sin(gaze_angles[1]), math.cos(gaze_angles[0])*math.cos(gaze_angles[1])])
+
+    yaw_angle_map = torch.linspace(-torch.pi, torch.pi, image_pixel_size[1])
+    yaw_angle_map = yaw_angle_map[None,:].repeat(image_pixel_size[0], 1)[None,...]
+    pitch_angle_map = torch.linspace(-torch.pi*0.5, torch.pi*0.5, image_pixel_size[0])
+    pitch_angle_map = pitch_angle_map[:,None].repeat(1, image_pixel_size[1])[None,...]
+
+    dir_map = torch.cat([torch.sin(yaw_angle_map)*torch.cos(pitch_angle_map), torch.sin(pitch_angle_map), torch.cos(yaw_angle_map)*torch.cos(pitch_angle_map)])
+
+    # Work out the pooling region diameter in radians
+    view_dot_dir = torch.sum(view_direction[:,None,None] * dir_map, dim=0)
+    eccentricity = torch.acos(view_dot_dir)
+    pooling_rad = alpha * eccentricity
+    if mode == "quadratic":
+        pooling_rad *= eccentricity
+
+    # The actual pooling region will be an ellipse in the equirectangular image - the length of the major & minor axes
+    # depend on the x & y resolution of the image. We find these two axis lengths (in pixels) and then the area of the ellipse
+    pixels_per_rad_x = image_pixel_size[1] / (2*torch.pi)
+    pixels_per_rad_y = image_pixel_size[0] / (torch.pi)
+    pooling_axis_x = pooling_rad * pixels_per_rad_x
+    pooling_axis_y = pooling_rad * pixels_per_rad_y
+    area = torch.pi * pooling_axis_x * pooling_axis_y * 0.25
+
+    # Now finally find the length of the side of a square of the same area.
+    size = torch.sqrt(torch.abs(area))
+    return size
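
A usage sketch for an equirectangular image with the gaze straight ahead (example values):

from odak.learn.perception.foveation import make_equi_pooling_size_map_pixels

pooling_map = make_equi_pooling_size_map_pixels(
                                                gaze_angles = (0.0, 0.0), # (yaw, pitch) in radians
                                                image_pixel_size = (512, 1024),
                                                alpha = 0.3,
                                                mode = 'quadratic'
                                               )
print(pooling_map.shape) # torch.Size([512, 1024]); pooling region size in pixels per location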
+
+
+
+ +
+ +
+ + +

+ make_pooling_size_map_lod(gaze_location, image_pixel_size, alpha=0.3, real_image_width=0.3, real_viewing_distance=0.6, mode='quadratic') + +

+ + +
+ +

This function is similar to make_pooling_size_map_pixels, but instead returns a map of LOD levels to sample from +to achieve the correct pooling region areas.

+ + +

Parameters:

+
    +
  • + gaze_location + – +
    +
                        User's gaze (fixation point) in the image. Should be given as a tuple with normalized
    +                    image coordinates (ranging from 0 to 1)
    +
    +
    +
  • +
  • + image_pixel_size + – +
    +
                        The size of the image in pixels, as a tuple of form (height, width)
    +
    +
    +
  • +
  • + real_image_width + – +
    +
                        The real width of the image as displayed. Units not important, as long as they
    +                    are the same as those used for real_viewing_distance
    +
    +
    +
  • +
  • + real_viewing_distance + – +
    +
                        The real distance from the user's viewpoint to the screen.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +pooling_size_map ( tensor +) – +
    +

    The computed pooling size map, of size HxW.

    +
    +
  • +
+ +
+ Source code in odak/learn/perception/foveation.py +
def make_pooling_size_map_lod(gaze_location, image_pixel_size, alpha=0.3, real_image_width=0.3, real_viewing_distance=0.6, mode="quadratic"):
+    """ 
+    This function is similar to make_pooling_size_map_pixels, but instead returns a map of LOD levels to sample from
+    to achieve the correct pooling region areas.
+
+    Parameters
+    ----------
+
+    gaze_location           : tuple of floats
+                                User's gaze (fixation point) in the image. Should be given as a tuple with normalized
+                                image coordinates (ranging from 0 to 1)
+    image_pixel_size        : tuple of ints
+                                The size of the image in pixels, as a tuple of form (height, width)
+    real_image_width        : float
+                                The real width of the image as displayed. Units not important, as long as they
+                                are the same as those used for real_viewing_distance
+    real_viewing_distance   : float
+                                The real distance from the user's viewpoint to the screen.
+
+    Returns
+    -------
+
+    pooling_size_map        : torch.tensor
+                                The computed pooling size map, of size HxW.
+    """
+    pooling_pixel = make_pooling_size_map_pixels(
+        gaze_location, image_pixel_size, alpha, real_image_width, real_viewing_distance, mode)
+    pooling_lod = torch.log2(1e-6+pooling_pixel)
+    pooling_lod[pooling_lod < 0] = 0
+    return pooling_lod
+
+
+
+ +
+ +
+ + +

+ make_pooling_size_map_pixels(gaze_location, image_pixel_size, alpha=0.3, real_image_width=0.3, real_viewing_distance=0.6, mode='quadratic') + +

+ + +
+ +

Makes a map of the pooling size associated with each pixel in an image for a given fixation point, when displayed to +a user on a flat screen. Follows the idea that pooling size (in radians) should be directly proportional to eccentricity +(also in radians).

+

Assumes the viewpoint is located at the centre of the image, and the screen is +perpendicular to the viewing direction. Output is the width of the pooling region in pixels.

+ + +

Parameters:

+
    +
  • + gaze_location + – +
    +
                        User's gaze (fixation point) in the image. Should be given as a tuple with normalized
    +                    image coordinates (ranging from 0 to 1)
    +
    +
    +
  • +
  • + image_pixel_size + – +
    +
                        The size of the image in pixels, as a tuple of form (height, width)
    +
    +
    +
  • +
  • + real_image_width + – +
    +
                        The real width of the image as displayed. Units not important, as long as they
    +                    are the same as those used for real_viewing_distance
    +
    +
    +
  • +
  • + real_viewing_distance + – +
    +
                        The real distance from the user's viewpoint to the screen.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +pooling_size_map ( tensor +) – +
    +

    The computed pooling size map, of size HxW.

    +
    +
  • +
+ +
+ Source code in odak/learn/perception/foveation.py +
def make_pooling_size_map_pixels(gaze_location, image_pixel_size, alpha=0.3, real_image_width=0.3, real_viewing_distance=0.6, mode="quadratic"):
+    """ 
+    Makes a map of the pooling size associated with each pixel in an image for a given fixation point, when displayed to
+    a user on a flat screen. Follows the idea that pooling size (in radians) should be directly proportional to eccentricity
+    (also in radians). 
+
+    Assumes the viewpoint is located at the centre of the image, and the screen is 
+    perpendicular to the viewing direction. Output is the width of the pooling region in pixels.
+
+    Parameters
+    ----------
+
+    gaze_location           : tuple of floats
+                                User's gaze (fixation point) in the image. Should be given as a tuple with normalized
+                                image coordinates (ranging from 0 to 1)
+    image_pixel_size        : tuple of ints
+                                The size of the image in pixels, as a tuple of form (height, width)
+    real_image_width        : float
+                                The real width of the image as displayed. Units not important, as long as they
+                                are the same as those used for real_viewing_distance
+    real_viewing_distance   : float
+                                The real distance from the user's viewpoint to the screen.
+
+    Returns
+    -------
+
+    pooling_size_map        : torch.tensor
+                                The computed pooling size map, of size HxW.
+    """
+    eccentricity, distance_to_pixel = make_eccentricity_distance_maps(
+        gaze_location, image_pixel_size, real_image_width, real_viewing_distance)
+    eccentricity_centre, _ = make_eccentricity_distance_maps(
+        [0.5, 0.5], image_pixel_size, real_image_width, real_viewing_distance)
+    pooling_rad = alpha * eccentricity
+    if mode == "quadratic":
+        pooling_rad *= eccentricity
+    angle_min = eccentricity_centre - pooling_rad*0.5
+    angle_max = eccentricity_centre + pooling_rad*0.5
+    major_axis = (torch.tan(angle_max) - torch.tan(angle_min)) * \
+        real_viewing_distance
+    minor_axis = 2 * distance_to_pixel * torch.tan(pooling_rad*0.5)
+    area = math.pi * major_axis * minor_axis * 0.25
+    # Should be +ve anyway, but check to ensure we don't take sqrt of negative number
+    area = torch.abs(area)
+    pooling_real = torch.sqrt(area)
+    pooling_pixel = (pooling_real / real_image_width) * image_pixel_size[1]
+    return pooling_pixel
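
A usage sketch with an off-centre gaze (example values):

from odak.learn.perception.foveation import make_pooling_size_map_pixels

pooling_map = make_pooling_size_map_pixels(
                                           gaze_location = [0.3, 0.5], # normalized image coordinates
                                           image_pixel_size = (1080, 1920),
                                           alpha = 0.3,
                                           real_image_width = 0.3,
                                           real_viewing_distance = 0.6,
                                           mode = 'quadratic'
                                          )
print(pooling_map.shape) # torch.Size([1080, 1920])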
+
+
+
+ +
+ +
+ + +

+ make_radial_map(size, gaze) + +

+ + +
+ +

Makes a simple radial map where each pixel contains distance in pixels from the chosen gaze location.

+ + +

Parameters:

+
    +
  • + size + – +
    +
        Dimensions of the image
    +
    +
    +
  • +
  • + gaze + – +
    +
        User's gaze (fixation point) in the image. Should be given as a tuple with normalized
    +    image coordinates (ranging from 0 to 1)
    +
    +
    +
  • +
+ +
+ Source code in odak/learn/perception/foveation.py +
def make_radial_map(size, gaze):
+    """ 
+    Makes a simple radial map where each pixel contains distance in pixels from the chosen gaze location.
+
+    Parameters
+    ----------
+
+    size    : tuple of ints
+                Dimensions of the image
+    gaze    : tuple of floats
+                User's gaze (fixation point) in the image. Should be given as a tuple with normalized
+                image coordinates (ranging from 0 to 1)
+    """
+    pix_gaze = [gaze[0]*size[0], gaze[1]*size[1]]
+    rows = torch.linspace(0, size[0], size[0])
+    rows = rows[:, None].repeat(1, size[1])
+    cols = torch.linspace(0, size[1], size[1])
+    cols = cols[None, :].repeat(size[0], 1)
+    dist_sq = torch.pow(rows - pix_gaze[0], 2) + \
+        torch.pow(cols - pix_gaze[1], 2)
+    radii = torch.sqrt(dist_sq)
+    return radii/torch.max(radii)
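
A usage sketch (example values):

from odak.learn.perception.foveation import make_radial_map

radial_map = make_radial_map(size = (512, 512), gaze = (0.5, 0.5))
print(radial_map.shape, radial_map.max()) # distances normalized so the farthest pixel maps to 1.0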
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ MSSSIM + + +

+ + +
+

+ Bases: Module

+ + +

A class to calculate multi-scale structural similarity index of an image with respect to a ground truth image.

+ + + + + + +
+ Source code in odak/learn/perception/image_quality_losses.py +
class MSSSIM(nn.Module):
+    '''
+    A class to calculate multi-scale structural similarity index of an image with respect to a ground truth image.
+    '''
+
+    def __init__(self):
+        super(MSSSIM, self).__init__()
+
+    def forward(self, predictions, targets):
+        """
+        Parameters
+        ----------
+        predictions : torch.tensor
+                      The predicted images.
+        targets     : torch.tensor
+                      The ground truth images.
+
+        Returns
+        -------
+        result      : torch.tensor 
+                      The computed MS-SSIM value if successful, otherwise 0.0.
+        """
+        try:
+            from torchmetrics.functional.image import multiscale_structural_similarity_index_measure
+            if len(predictions.shape) == 3:
+                predictions = predictions.unsqueeze(0)
+                targets = targets.unsqueeze(0)
+            l_MSSSIM = multiscale_structural_similarity_index_measure(predictions, targets, data_range = 1.0)
+            return l_MSSSIM  
+        except Exception as e:
+            logging.warning('MS-SSIM failed to compute.')
+            logging.warning(e)
+            return torch.tensor(0.0)
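
A usage sketch (assuming the class can be imported from odak.learn.perception.image_quality_losses, matching the documented source path; requires torchmetrics, otherwise the loss falls back to 0.0):

import torch
from odak.learn.perception.image_quality_losses import MSSSIM

loss_function = MSSSIM()
estimate = torch.rand(1, 3, 256, 256)
ground_truth = torch.rand(1, 3, 256, 256)
msssim_value = loss_function(estimate, ground_truth)
print(msssim_value)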
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ forward(predictions, targets) + +

+ + +
+ + + +

Parameters:

+
    +
  • + predictions + (tensor) + – +
    +
          The predicted images.
    +
    +
    +
  • +
  • + targets + – +
    +
          The ground truth images.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( tensor +) – +
    +

    The computed MS-SSIM value if successful, otherwise 0.0.

    +
    +
  • +
+ +
+ Source code in odak/learn/perception/image_quality_losses.py +
def forward(self, predictions, targets):
+    """
+    Parameters
+    ----------
+    predictions : torch.tensor
+                  The predicted images.
+    targets     : torch.tensor
+                  The ground truth images.
+
+    Returns
+    -------
+    result      : torch.tensor 
+                  The computed MS-SSIM value if successful, otherwise 0.0.
+    """
+    try:
+        from torchmetrics.functional.image import multiscale_structural_similarity_index_measure
+        if len(predictions.shape) == 3:
+            predictions = predictions.unsqueeze(0)
+            targets = targets.unsqueeze(0)
+        l_MSSSIM = multiscale_structural_similarity_index_measure(predictions, targets, data_range = 1.0)
+        return l_MSSSIM  
+    except Exception as e:
+        logging.warning('MS-SSIM failed to compute.')
+        logging.warning(e)
+        return torch.tensor(0.0)
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ PSNR + + +

+ + +
+

+ Bases: Module

+ + +

A class to calculate peak-signal-to-noise ratio of an image with respect to a ground truth image.

+ + + + + + +
+ Source code in odak/learn/perception/image_quality_losses.py +
class PSNR(nn.Module):
+    '''
+    A class to calculate peak-signal-to-noise ratio of an image with respect to a ground truth image.
+    '''
+
+    def __init__(self):
+        super(PSNR, self).__init__()
+
+    def forward(self, predictions, targets, peak_value = 1.0):
+        """
+        A function to calculate peak-signal-to-noise ratio of an image with respect to a ground truth image.
+
+        Parameters
+        ----------
+        predictions   : torch.tensor
+                        Image to be tested.
+        targets       : torch.tensor
+                        Ground truth image.
+        peak_value    : float
+                        Peak value that given tensors could have.
+
+        Returns
+        -------
+        result        : torch.tensor
+                        Peak-signal-to-noise ratio.
+        """
+        mse = torch.mean((targets - predictions) ** 2)
+        result = 20 * torch.log10(peak_value / torch.sqrt(mse))
+        return result
+
forward(predictions, targets, peak_value=1.0)

A function to calculate peak-signal-to-noise ratio of an image with respect to a ground truth image.

Parameters:

  • predictions (tensor) – Image to be tested.
  • targets (tensor) – Ground truth image.
  • peak_value (float) – Peak value that given tensors could have.

Returns:

  • result (tensor) – Peak-signal-to-noise ratio.
+ Source code in odak/learn/perception/image_quality_losses.py +
def forward(self, predictions, targets, peak_value = 1.0):
+    """
+    A function to calculate peak-signal-to-noise ratio of an image with respect to a ground truth image.
+
+    Parameters
+    ----------
+    predictions   : torch.tensor
+                    Image to be tested.
+    targets       : torch.tensor
+                    Ground truth image.
+    peak_value    : float
+                    Peak value that given tensors could have.
+
+    Returns
+    -------
+    result        : torch.tensor
+                    Peak-signal-to-noise ratio.
+    """
+    mse = torch.mean((targets - predictions) ** 2)
+    result = 20 * torch.log10(peak_value / torch.sqrt(mse))
+    return result
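A short usage sketch for PSNR - assuming the class is exported from odak.learn.perception:

import torch
from odak.learn.perception import PSNR  # assumed export path

metric = PSNR()
estimate = torch.rand(1, 3, 128, 128)      # image to be tested
ground_truth = torch.rand(1, 3, 128, 128)  # ground truth image
result = metric(estimate, ground_truth, peak_value = 1.0)  # 20 * log10(peak / sqrt(MSE)), in decibels
print(result.item())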
SSIM

Bases: Module

A class to calculate structural similarity index of an image with respect to a ground truth image.
+ Source code in odak/learn/perception/image_quality_losses.py +
class SSIM(nn.Module):
+    '''
+    A class to calculate structural similarity index of an image with respect to a ground truth image.
+    '''
+
+    def __init__(self):
+        super(SSIM, self).__init__()
+
+    def forward(self, predictions, targets):
+        """
+        Parameters
+        ----------
+        predictions : torch.tensor
+                      The predicted images.
+        targets     : torch.tensor
+                      The ground truth images.
+
+        Returns
+        -------
+        result      : torch.tensor 
+                      The computed SSIM value if successful, otherwise 0.0.
+        """
+        try:
+            from torchmetrics.functional.image import structural_similarity_index_measure
+            if len(predictions.shape) == 3:
+                predictions = predictions.unsqueeze(0)
+                targets = targets.unsqueeze(0)
+            l_SSIM = structural_similarity_index_measure(predictions, targets)
+            return l_SSIM
+        except Exception as e:
+            logging.warning('SSIM failed to compute.')
+            logging.warning(e)
+            return torch.tensor(0.0)
forward(predictions, targets)

Parameters:

  • predictions (tensor) – The predicted images.
  • targets (tensor) – The ground truth images.

Returns:

  • result (tensor) – The computed SSIM value if successful, otherwise 0.0.
+ Source code in odak/learn/perception/image_quality_losses.py +
def forward(self, predictions, targets):
+    """
+    Parameters
+    ----------
+    predictions : torch.tensor
+                  The predicted images.
+    targets     : torch.tensor
+                  The ground truth images.
+
+    Returns
+    -------
+    result      : torch.tensor 
+                  The computed SSIM value if successful, otherwise 0.0.
+    """
+    try:
+        from torchmetrics.functional.image import structural_similarity_index_measure
+        if len(predictions.shape) == 3:
+            predictions = predictions.unsqueeze(0)
+            targets = targets.unsqueeze(0)
+        l_SSIM = structural_similarity_index_measure(predictions, targets)
+        return l_SSIM
+    except Exception as e:
+        logging.warning('SSIM failed to compute.')
+        logging.warning(e)
+        return torch.tensor(0.0)
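A minimal usage sketch, assuming SSIM is exported from odak.learn.perception and torchmetrics is installed; three-dimensional CHW inputs are also accepted since the forward pass unsqueezes them to NCHW:

import torch
from odak.learn.perception import SSIM  # assumed export path

metric = SSIM()
estimate = torch.rand(3, 256, 256)       # CHW input, unsqueezed to NCHW internally
ground_truth = torch.rand(3, 256, 256)
print(metric(estimate, ground_truth).item())  # SSIM value, or 0.0 if torchmetrics is unavailable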
CVVDP

Bases: Module
+ Source code in odak/learn/perception/learned_perceptual_losses.py +
class CVVDP(nn.Module):
+    def __init__(self, device = torch.device('cpu')):
+        """
+        Initializes the CVVDP model with a specified device.
+
+        Parameters
+        ----------
+        device   : torch.device
+                    The device (CPU/GPU) on which the computations will be performed. Defaults to CPU.
+        """
+        super(CVVDP, self).__init__()
+        try:
+            import pycvvdp
+            self.cvvdp = pycvvdp.cvvdp(display_name = 'standard_4k', device = device)
+        except Exception as e:
+            logging.warning('ColorVideoVDP is missing, consider installing by running "pip install -U git+https://github.com/gfxdisp/ColorVideoVDP"')
+            logging.warning(e)
+
+
+    def forward(self, predictions, targets, dim_order = 'CHW'):
+        """
+        Parameters
+        ----------
+        predictions   : torch.tensor
+                        The predicted images.
+        targets       : torch.tensor
+                        The ground truth images.
+        dim_order     : str
+                        The dimension order of the input images. Defaults to 'CHW' (channels, height, width).
+
+        Returns
+        -------
+        result        : torch.tensor
+                        The computed loss if successful, otherwise 0.0.
+        """
+        try:
+            l_ColorVideoVDP = self.cvvdp.loss(predictions, targets, dim_order = dim_order)
+            return l_ColorVideoVDP
+        except Exception as e:
+            logging.warning('ColorVideoVDP failed to compute.')
+            logging.warning(e)
+            return torch.tensor(0.0)
__init__(device=torch.device('cpu'))

Initializes the CVVDP model with a specified device.

Parameters:

  • device (torch.device) – The device (CPU/GPU) on which the computations will be performed. Defaults to CPU.
+ Source code in odak/learn/perception/learned_perceptual_losses.py +
def __init__(self, device = torch.device('cpu')):
+    """
+    Initializes the CVVDP model with a specified device.
+
+    Parameters
+    ----------
+    device   : torch.device
+                The device (CPU/GPU) on which the computations will be performed. Defaults to CPU.
+    """
+    super(CVVDP, self).__init__()
+    try:
+        import pycvvdp
+        self.cvvdp = pycvvdp.cvvdp(display_name = 'standard_4k', device = device)
+    except Exception as e:
+        logging.warning('ColorVideoVDP is missing, consider installing by running "pip install -U git+https://github.com/gfxdisp/ColorVideoVDP"')
+        logging.warning(e)
forward(predictions, targets, dim_order='CHW')

Parameters:

  • predictions (tensor) – The predicted images.
  • targets (tensor) – The ground truth images.
  • dim_order (str) – The dimension order of the input images. Defaults to 'CHW' (channels, height, width).

Returns:

  • result (tensor) – The computed loss if successful, otherwise 0.0.
+ Source code in odak/learn/perception/learned_perceptual_losses.py +
def forward(self, predictions, targets, dim_order = 'CHW'):
+    """
+    Parameters
+    ----------
+    predictions   : torch.tensor
+                    The predicted images.
+    targets       : torch.tensor
+                    The ground truth images.
+    dim_order     : str
+                    The dimension order of the input images. Defaults to 'CHW' (channels, height, width).
+
+    Returns
+    -------
+    result        : torch.tensor
+                    The computed loss if successful, otherwise 0.0.
+    """
+    try:
+        l_ColorVideoVDP = self.cvvdp.loss(predictions, targets, dim_order = dim_order)
+        return l_ColorVideoVDP
+    except Exception as e:
+        logging.warning('ColorVideoVDP failed to compute.')
+        logging.warning(e)
+        return torch.tensor(0.0)
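A hedged usage sketch for the wrapper above - assuming CVVDP is exported from odak.learn.perception and that pycvvdp (ColorVideoVDP) is installed; if it is not, the constructor and forward pass warn and the result falls back to 0.0:

import torch
from odak.learn.perception import CVVDP  # assumed export path

loss_function = CVVDP(device = torch.device('cpu'))
estimate = torch.rand(3, 256, 256)       # CHW layout, matching the default dim_order
ground_truth = torch.rand(3, 256, 256)
result = loss_function(estimate, ground_truth, dim_order = 'CHW')
print(result)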
FVVDP

Bases: Module
+ Source code in odak/learn/perception/learned_perceptual_losses.py +
class FVVDP(nn.Module):
+    def __init__(self, device = torch.device('cpu')):
+        """
+        Initializes the FVVDP model with a specified device.
+
+        Parameters
+        ----------
+        device   : torch.device
+                    The device (CPU/GPU) on which the computations will be performed. Defaults to CPU.
+        """
+        super(FVVDP, self).__init__()
+        try:
+            import pyfvvdp
+            self.fvvdp = pyfvvdp.fvvdp(display_name = 'standard_4k', heatmap = 'none', device = device)
+        except Exception as e:
+            logging.warning('FovVideoVDP is missing, consider installing by running "pip install pyfvvdp"')
+            logging.warning(e)
+
+
+    def forward(self, predictions, targets, dim_order = 'CHW'):
+        """
+        Parameters
+        ----------
+        predictions   : torch.tensor
+                        The predicted images.
+        targets       : torch.tensor
+                        The ground truth images.
+        dim_order     : str
+                        The dimension order of the input images. Defaults to 'CHW' (channels, height, width).
+
+        Returns
+        -------
+        result        : torch.tensor
+                        The computed loss if successful, otherwise 0.0.
+        """
+        try:
+            l_FovVideoVDP = self.fvvdp.predict(predictions, targets, dim_order = dim_order)[0]
+            return l_FovVideoVDP
+        except Exception as e:
+            logging.warning('FovVideoVDP failed to compute.')
+            logging.warning(e)
+            return torch.tensor(0.0)
__init__(device=torch.device('cpu'))

Initializes the FVVDP model with a specified device.

Parameters:

  • device (torch.device) – The device (CPU/GPU) on which the computations will be performed. Defaults to CPU.
+ Source code in odak/learn/perception/learned_perceptual_losses.py +
def __init__(self, device = torch.device('cpu')):
+    """
+    Initializes the FVVDP model with a specified device.
+
+    Parameters
+    ----------
+    device   : torch.device
+                The device (CPU/GPU) on which the computations will be performed. Defaults to CPU.
+    """
+    super(FVVDP, self).__init__()
+    try:
+        import pyfvvdp
+        self.fvvdp = pyfvvdp.fvvdp(display_name = 'standard_4k', heatmap = 'none', device = device)
+    except Exception as e:
+        logging.warning('FovVideoVDP is missing, consider installing by running "pip install pyfvvdp"')
+        logging.warning(e)
forward(predictions, targets, dim_order='CHW')

Parameters:

  • predictions (tensor) – The predicted images.
  • targets (tensor) – The ground truth images.
  • dim_order (str) – The dimension order of the input images. Defaults to 'CHW' (channels, height, width).

Returns:

  • result (tensor) – The computed loss if successful, otherwise 0.0.
+ Source code in odak/learn/perception/learned_perceptual_losses.py +
def forward(self, predictions, targets, dim_order = 'CHW'):
+    """
+    Parameters
+    ----------
+    predictions   : torch.tensor
+                    The predicted images.
+    targets       : torch.tensor
+                    The ground truth images.
+    dim_order     : str
+                    The dimension order of the input images. Defaults to 'CHW' (channels, height, width).
+
+    Returns
+    -------
+    result        : torch.tensor
+                    The computed loss if successful, otherwise 0.0.
+    """
+    try:
+        l_FovVideoVDP = self.fvvdp.predict(predictions, targets, dim_order = dim_order)[0]
+        return l_FovVideoVDP
+    except Exception as e:
+        logging.warning('FovVideoVDP failed to compute.')
+        logging.warning(e)
+        return torch.tensor(0.0)
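A hedged usage sketch - assuming FVVDP is exported from odak.learn.perception and pyfvvdp is installed. Note that FovVideoVDP's predict() returns a quality score in JOD units (higher means more similar), so this wrapper behaves as a metric rather than a loss to minimise directly:

import torch
from odak.learn.perception import FVVDP  # assumed export path

metric = FVVDP(device = torch.device('cpu'))
estimate = torch.rand(3, 256, 256)       # CHW layout, matching the default dim_order
ground_truth = torch.rand(3, 256, 256)
result = metric(estimate, ground_truth, dim_order = 'CHW')  # JOD score, or 0.0 if pyfvvdp is unavailable
print(result)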
LPIPS

Bases: Module
+ Source code in odak/learn/perception/learned_perceptual_losses.py +
class LPIPS(nn.Module):
+
+    def __init__(self):
+        """
+        Initializes the LPIPS (Learned Perceptual Image Patch Similarity) model.
+
+        """
+        super(LPIPS, self).__init__()
+        try:
+            import torchmetrics
+            self.lpips = torchmetrics.image.lpip.LearnedPerceptualImagePatchSimilarity(net_type = 'squeeze')
+        except Exception as e:
+            logging.warning('torchmetrics is missing, consider installing by running "pip install torchmetrics"')
+            logging.warning(e)
+
+
+    def forward(self, predictions, targets):
+        """
+        Parameters
+        ----------
+        predictions   : torch.tensor
+                        The predicted images.
+        targets       : torch.tensor
+                        The ground truth images.
+
+        Returns
+        -------
+        result        : torch.tensor
+                        The computed loss if successful, otherwise 0.0.
+        """
+        try:
+            lpips_image = predictions
+            lpips_target = targets
+            if len(lpips_image.shape) == 3:
+                lpips_image = lpips_image.unsqueeze(0)
+                lpips_target = lpips_target.unsqueeze(0)
+            if lpips_image.shape[1] == 1:
+                lpips_image = lpips_image.repeat(1, 3, 1, 1)
+                lpips_target = lpips_target.repeat(1, 3, 1, 1)
+            lpips_image = (lpips_image * 2 - 1).clamp(-1, 1)
+            lpips_target = (lpips_target * 2 - 1).clamp(-1, 1)
+            l_LPIPS = self.lpips(lpips_image, lpips_target)
+            return l_LPIPS
+        except Exception as e:
+            logging.warning('LPIPS failed to compute.')
+            logging.warning(e)
+            return torch.tensor(0.0)
__init__()

Initializes the LPIPS (Learned Perceptual Image Patch Similarity) model.
+ Source code in odak/learn/perception/learned_perceptual_losses.py +
def __init__(self):
+    """
+    Initializes the LPIPS (Learned Perceptual Image Patch Similarity) model.
+
+    """
+    super(LPIPS, self).__init__()
+    try:
+        import torchmetrics
+        self.lpips = torchmetrics.image.lpip.LearnedPerceptualImagePatchSimilarity(net_type = 'squeeze')
+    except Exception as e:
+        logging.warning('torchmetrics is missing, consider installing by running "pip install torchmetrics"')
+        logging.warning(e)
forward(predictions, targets)

Parameters:

  • predictions (tensor) – The predicted images.
  • targets (tensor) – The ground truth images.

Returns:

  • result (tensor) – The computed loss if successful, otherwise 0.0.
+ Source code in odak/learn/perception/learned_perceptual_losses.py +
def forward(self, predictions, targets):
+    """
+    Parameters
+    ----------
+    predictions   : torch.tensor
+                    The predicted images.
+    targets       : torch.tensor
+                    The ground truth images.
+
+    Returns
+    -------
+    result        : torch.tensor
+                    The computed loss if successful, otherwise 0.0.
+    """
+    try:
+        lpips_image = predictions
+        lpips_target = targets
+        if len(lpips_image.shape) == 3:
+            lpips_image = lpips_image.unsqueeze(0)
+            lpips_target = lpips_target.unsqueeze(0)
+        if lpips_image.shape[1] == 1:
+            lpips_image = lpips_image.repeat(1, 3, 1, 1)
+            lpips_target = lpips_target.repeat(1, 3, 1, 1)
+        lpips_image = (lpips_image * 2 - 1).clamp(-1, 1)
+        lpips_target = (lpips_target * 2 - 1).clamp(-1, 1)
+        l_LPIPS = self.lpips(lpips_image, lpips_target)
+        return l_LPIPS
+    except Exception as e:
+        logging.warning('LPIPS failed to compute.')
+        logging.warning(e)
+        return torch.tensor(0.0)
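A minimal usage sketch, assuming LPIPS is exported from odak.learn.perception and torchmetrics (with its image extras) is installed; inputs are expected in [0, 1] and are rescaled to [-1, 1] internally before the SqueezeNet-based metric is evaluated:

import torch
from odak.learn.perception import LPIPS  # assumed export path

loss_function = LPIPS()
estimate = torch.rand(1, 3, 128, 128)      # inputs in [0, 1]
ground_truth = torch.rand(1, 3, 128, 128)
result = loss_function(estimate, ground_truth)  # lower means perceptually closer; 0.0 on failure
print(result.item())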
MetamericLoss

The MetamericLoss class provides a perceptual loss function.

Rather than exactly match the source image to the target, it tries to ensure the source is a metamer to the target image.

Its interface is similar to other pytorch loss functions, but note that the gaze location must be provided in addition to the source and target images.
+ Source code in odak/learn/perception/metameric_loss.py +
class MetamericLoss():
+    """
+    The `MetamericLoss` class provides a perceptual loss function.
+
+    Rather than exactly match the source image to the target, it tries to ensure the source is a *metamer* to the target image.
+
+    Its interface is similar to other `pytorch` loss functions, but note that the gaze location must be provided in addition to the source and target images.
+    """
+
+
+    def __init__(self, device=torch.device('cpu'), alpha=0.2, real_image_width=0.2,
+                 real_viewing_distance=0.7, n_pyramid_levels=5, mode="quadratic",
+                 n_orientations=2, use_l2_foveal_loss=True, fovea_weight=20.0, use_radial_weight=False,
+                 use_fullres_l0=False, equi=False):
+        """
+        Parameters
+        ----------
+
+        alpha                   : float
+                                    parameter controlling foveation - larger values mean bigger pooling regions.
+        real_image_width        : float 
+                                    The real width of the image as displayed to the user.
+                                    Units don't matter as long as they are the same as for real_viewing_distance.
+        real_viewing_distance   : float 
+                                    The real distance of the observer's eyes to the image plane.
+                                    Units don't matter as long as they are the same as for real_image_width.
+        n_pyramid_levels        : int 
+                                    Number of levels of the steerable pyramid. Note that the image is padded
+                                    so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value
+                                    too high will slow down the calculation a lot.
+        mode                    : str 
+                                    Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
+                                    as you move away from the fovea. We got best results with "quadratic".
+        n_orientations          : int 
+                                    Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.
+                                    Increasing this will increase runtime.
+        use_l2_foveal_loss      : bool 
+                                    If true, all pixels whose pooling size is 1 pixel at the largest scale
+                                    use a direct L2 loss against the target rather than pooling over pyramid levels.
+                                    In practice this gives better results when the loss is used for holography.
+        fovea_weight            : float 
+                                    A weight to apply to the foveal region if use_l2_foveal_loss is set to True.
+        use_radial_weight       : bool 
+                                    If True, will apply a radial weighting when calculating the difference between
+                                    the source and target stats maps. This weights stats closer to the fovea more than those
+                                    further away.
+        use_fullres_l0          : bool 
+                                    If true, stats for the lowpass residual are replaced with blurred versions
+                                    of the full-resolution source and target images.
+        equi                    : bool
+                                    If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular
+                                    format 360 image. The settings real_image_width and real_viewing_distance are ignored.
+                                    The gaze argument is instead interpreted as gaze angles, and should be in the range
+                                    [-pi,pi]x[-pi/2,pi/2].
+        """
+        self.target = None
+        self.device = device
+        self.pyramid_maker = None
+        self.alpha = alpha
+        self.real_image_width = real_image_width
+        self.real_viewing_distance = real_viewing_distance
+        self.blurs = None
+        self.n_pyramid_levels = n_pyramid_levels
+        self.n_orientations = n_orientations
+        self.mode = mode
+        self.use_l2_foveal_loss = use_l2_foveal_loss
+        self.fovea_weight = fovea_weight
+        self.use_radial_weight = use_radial_weight
+        self.use_fullres_l0 = use_fullres_l0
+        self.equi = equi
+        if self.use_fullres_l0 and self.use_l2_foveal_loss:
+            raise Exception(
+                "Can't use use_fullres_l0 and use_l2_foveal_loss options together in MetamericLoss!")
+
+    def calc_statsmaps(self, image, gaze=None, alpha=0.01, real_image_width=0.3,
+                       real_viewing_distance=0.6, mode="quadratic", equi=False):
+
+        if self.pyramid_maker is None or \
+                self.pyramid_maker.device != self.device or \
+                len(self.pyramid_maker.band_filters) != self.n_orientations or\
+                self.pyramid_maker.filt_h0.size(0) != image.size(1):
+            self.pyramid_maker = SpatialSteerablePyramid(
+                use_bilinear_downup=False, n_channels=image.size(1),
+                device=self.device, n_orientations=self.n_orientations, filter_type="cropped", filter_size=5)
+
+        if self.blurs is None or len(self.blurs) != self.n_pyramid_levels:
+            self.blurs = [RadiallyVaryingBlur()
+                          for i in range(self.n_pyramid_levels)]
+
+        def find_stats(image_pyr_level, blur):
+            image_means = blur.blur(
+                image_pyr_level, alpha, real_image_width, real_viewing_distance, centre=gaze, mode=mode, equi=self.equi)
+            image_meansq = blur.blur(image_pyr_level*image_pyr_level, alpha,
+                                     real_image_width, real_viewing_distance, centre=gaze, mode=mode, equi=self.equi)
+
+            image_vars = image_meansq - (image_means*image_means)
+            image_vars[image_vars < 1e-7] = 1e-7
+            image_std = torch.sqrt(image_vars)
+            if torch.any(torch.isnan(image_means)):
+                print(image_means)
+                raise Exception("NaN in image means!")
+            if torch.any(torch.isnan(image_std)):
+                print(image_std)
+                raise Exception("NaN in image stdevs!")
+            if self.use_fullres_l0:
+                mask = blur.lod_map > 1e-6
+                mask = mask[None, None, ...]
+                if image_means.size(1) > 1:
+                    mask = mask.repeat(1, image_means.size(1), 1, 1)
+                matte = torch.zeros_like(image_means)
+                matte[mask] = 1.0
+                return image_means * matte, image_std * matte
+            return image_means, image_std
+        output_stats = []
+        image_pyramid = self.pyramid_maker.construct_pyramid(
+            image, self.n_pyramid_levels)
+        means, variances = find_stats(image_pyramid[0]['h'], self.blurs[0])
+        if self.use_l2_foveal_loss:
+            self.fovea_mask = torch.zeros(image.size(), device=image.device)
+            for i in range(self.fovea_mask.size(1)):
+                self.fovea_mask[0, i, ...] = 1.0 - \
+                    (self.blurs[0].lod_map / torch.max(self.blurs[0].lod_map))
+                self.fovea_mask[0, i, self.blurs[0].lod_map < 1e-6] = 1.0
+            self.fovea_mask = torch.pow(self.fovea_mask, 10.0)
+            #self.fovea_mask     = torch.nn.functional.interpolate(self.fovea_mask, scale_factor=0.125, mode="area")
+            #self.fovea_mask     = torch.nn.functional.interpolate(self.fovea_mask, size=(image.size(-2), image.size(-1)), mode="bilinear")
+            periphery_mask = 1.0 - self.fovea_mask
+            self.periphery_mask = periphery_mask.clone()
+            output_stats.append(means * periphery_mask)
+            output_stats.append(variances * periphery_mask)
+        else:
+            output_stats.append(means)
+            output_stats.append(variances)
+
+        for l in range(0, len(image_pyramid)-1):
+            for o in range(len(image_pyramid[l]['b'])):
+                means, variances = find_stats(
+                    image_pyramid[l]['b'][o], self.blurs[l])
+                if self.use_l2_foveal_loss:
+                    output_stats.append(means * periphery_mask)
+                    output_stats.append(variances * periphery_mask)
+                else:
+                    output_stats.append(means)
+                    output_stats.append(variances)
+            if self.use_l2_foveal_loss:
+                periphery_mask = torch.nn.functional.interpolate(
+                    periphery_mask, scale_factor=0.5, mode="area", recompute_scale_factor=False)
+
+        if self.use_l2_foveal_loss:
+            output_stats.append(image_pyramid[-1]["l"] * periphery_mask)
+        elif self.use_fullres_l0:
+            output_stats.append(self.blurs[0].blur(
+                image, alpha, real_image_width, real_viewing_distance, gaze, mode))
+        else:
+            output_stats.append(image_pyramid[-1]["l"])
+        return output_stats
+
+    def metameric_loss_stats(self, statsmap_a, statsmap_b, gaze):
+        loss = 0.0
+        for a, b in zip(statsmap_a, statsmap_b):
+            if self.use_radial_weight:
+                radii = make_radial_map(
+                    [a.size(-2), a.size(-1)], gaze).to(a.device)
+                weights = 1.1 - (radii * radii * radii * radii)
+                weights = weights[None, None, ...].repeat(1, a.size(1), 1, 1)
+                loss += torch.nn.MSELoss()(weights*a, weights*b)
+            else:
+                loss += torch.nn.MSELoss()(a, b)
+        loss /= len(statsmap_a)
+        return loss
+
+    def visualise_loss_map(self, image_stats):
+        loss_map = torch.zeros(image_stats[0].size()[-2:])
+        for i in range(len(image_stats)):
+            stats = image_stats[i]
+            target_stats = self.target_stats[i]
+            stat_mse_map = torch.sqrt(torch.pow(stats - target_stats, 2))
+            stat_mse_map = torch.nn.functional.interpolate(stat_mse_map, size=loss_map.size(
+            ), mode="bilinear", align_corners=False, recompute_scale_factor=False)
+            loss_map += stat_mse_map[0, 0, ...]
+        self.loss_map = loss_map
+
+    def __call__(self, image, target, gaze=[0.5, 0.5], image_colorspace="RGB", visualise_loss=False):
+        """ 
+        Calculates the Metameric Loss.
+
+        Parameters
+        ----------
+        image               : torch.tensor
+                                Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
+        target              : torch.tensor
+                                Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
+        image_colorspace    : str
+                                The current colorspace of your image and target. Ignored if input does not have 3 channels.
+                                accepted values: RGB, YCrCb.
+        gaze                : list
+                                Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.
+        visualise_loss      : bool
+                                Shows a heatmap indicating which parts of the image contributed most to the loss. 
+
+        Returns
+        -------
+
+        loss                : torch.tensor
+                                The computed loss.
+        """
+        check_loss_inputs("MetamericLoss", image, target)
+        # Pad image and target if necessary
+        image = pad_image_for_pyramid(image, self.n_pyramid_levels)
+        target = pad_image_for_pyramid(target, self.n_pyramid_levels)
+        # If input is RGB, convert to YCrCb.
+        if image.size(1) == 3 and image_colorspace == "RGB":
+            image = rgb_2_ycrcb(image)
+            target = rgb_2_ycrcb(target)
+        if self.target is None:
+            self.target = torch.zeros(target.shape).to(target.device)
+        if type(target) == type(self.target):
+            if not torch.all(torch.eq(target, self.target)):
+                self.target = target.detach().clone()
+                self.target_stats = self.calc_statsmaps(
+                    self.target,
+                    gaze=gaze,
+                    alpha=self.alpha,
+                    real_image_width=self.real_image_width,
+                    real_viewing_distance=self.real_viewing_distance,
+                    mode=self.mode
+                )
+                self.target = target.detach().clone()
+            image_stats = self.calc_statsmaps(
+                image,
+                gaze=gaze,
+                alpha=self.alpha,
+                real_image_width=self.real_image_width,
+                real_viewing_distance=self.real_viewing_distance,
+                mode=self.mode
+            )
+            if visualise_loss:
+                self.visualise_loss_map(image_stats)
+            if self.use_l2_foveal_loss:
+                peripheral_loss = self.metameric_loss_stats(
+                    image_stats, self.target_stats, gaze)
+                foveal_loss = torch.nn.MSELoss()(self.fovea_mask*image, self.fovea_mask*target)
+                # New weighting - evenly weight fovea and periphery.
+                loss = peripheral_loss + self.fovea_weight * foveal_loss
+            else:
+                loss = self.metameric_loss_stats(
+                    image_stats, self.target_stats, gaze)
+            return loss
+        else:
+            raise Exception("Target of incorrect type")
+
+    def to(self, device):
+        self.device = device
+        return self
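A hedged usage sketch for the class above - assuming MetamericLoss is exported from odak.learn.perception. Inputs should be NCHW RGB tensors in [0, 1]; the image is padded internally so that height and width become multiples of 2^(n_pyramid_levels), and the gaze is given in normalized image coordinates:

import torch
from odak.learn.perception import MetamericLoss  # assumed export path

loss_function = MetamericLoss(alpha = 0.2, real_image_width = 0.2, real_viewing_distance = 0.7)
estimate = torch.rand(1, 3, 512, 512, requires_grad = True)  # NCHW RGB in [0, 1]
ground_truth = torch.rand(1, 3, 512, 512)
loss = loss_function(estimate, ground_truth, gaze = [0.5, 0.5])  # gaze at the image centre
loss.backward()  # differentiable, so it can drive an optimisation loop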
__call__(image, target, gaze=[0.5, 0.5], image_colorspace='RGB', visualise_loss=False)

Calculates the Metameric Loss.

Parameters:

  • image (tensor) – Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions).
  • target (tensor) – Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions).
  • image_colorspace (str) – The current colorspace of your image and target. Ignored if input does not have 3 channels. Accepted values: RGB, YCrCb.
  • gaze (list) – Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.
  • visualise_loss (bool) – Shows a heatmap indicating which parts of the image contributed most to the loss.

Returns:

  • loss (tensor) – The computed loss.
+ Source code in odak/learn/perception/metameric_loss.py +
def __call__(self, image, target, gaze=[0.5, 0.5], image_colorspace="RGB", visualise_loss=False):
+    """ 
+    Calculates the Metameric Loss.
+
+    Parameters
+    ----------
+    image               : torch.tensor
+                            Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
+    target              : torch.tensor
+                            Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
+    image_colorspace    : str
+                            The current colorspace of your image and target. Ignored if input does not have 3 channels.
+                            accepted values: RGB, YCrCb.
+    gaze                : list
+                            Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.
+    visualise_loss      : bool
+                            Shows a heatmap indicating which parts of the image contributed most to the loss. 
+
+    Returns
+    -------
+
+    loss                : torch.tensor
+                            The computed loss.
+    """
+    check_loss_inputs("MetamericLoss", image, target)
+    # Pad image and target if necessary
+    image = pad_image_for_pyramid(image, self.n_pyramid_levels)
+    target = pad_image_for_pyramid(target, self.n_pyramid_levels)
+    # If input is RGB, convert to YCrCb.
+    if image.size(1) == 3 and image_colorspace == "RGB":
+        image = rgb_2_ycrcb(image)
+        target = rgb_2_ycrcb(target)
+    if self.target is None:
+        self.target = torch.zeros(target.shape).to(target.device)
+    if type(target) == type(self.target):
+        if not torch.all(torch.eq(target, self.target)):
+            self.target = target.detach().clone()
+            self.target_stats = self.calc_statsmaps(
+                self.target,
+                gaze=gaze,
+                alpha=self.alpha,
+                real_image_width=self.real_image_width,
+                real_viewing_distance=self.real_viewing_distance,
+                mode=self.mode
+            )
+            self.target = target.detach().clone()
+        image_stats = self.calc_statsmaps(
+            image,
+            gaze=gaze,
+            alpha=self.alpha,
+            real_image_width=self.real_image_width,
+            real_viewing_distance=self.real_viewing_distance,
+            mode=self.mode
+        )
+        if visualise_loss:
+            self.visualise_loss_map(image_stats)
+        if self.use_l2_foveal_loss:
+            peripheral_loss = self.metameric_loss_stats(
+                image_stats, self.target_stats, gaze)
+            foveal_loss = torch.nn.MSELoss()(self.fovea_mask*image, self.fovea_mask*target)
+            # New weighting - evenly weight fovea and periphery.
+            loss = peripheral_loss + self.fovea_weight * foveal_loss
+        else:
+            loss = self.metameric_loss_stats(
+                image_stats, self.target_stats, gaze)
+        return loss
+    else:
+        raise Exception("Target of incorrect type")
__init__(device=torch.device('cpu'), alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, n_pyramid_levels=5, mode='quadratic', n_orientations=2, use_l2_foveal_loss=True, fovea_weight=20.0, use_radial_weight=False, use_fullres_l0=False, equi=False)

Parameters:

  • alpha (float) – Parameter controlling foveation - larger values mean bigger pooling regions.
  • real_image_width (float) – The real width of the image as displayed to the user. Units don't matter as long as they are the same as for real_viewing_distance.
  • real_viewing_distance (float) – The real distance of the observer's eyes to the image plane. Units don't matter as long as they are the same as for real_image_width.
  • n_pyramid_levels (int) – Number of levels of the steerable pyramid. Note that the image is padded so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value too high will slow down the calculation a lot.
  • mode (str) – Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow as you move away from the fovea. We got best results with "quadratic".
  • n_orientations (int) – Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6. Increasing this will increase runtime.
  • use_l2_foveal_loss (bool) – If true, all pixels whose pooling size is 1 pixel at the largest scale use a direct L2 loss against the target rather than pooling over pyramid levels. In practice this gives better results when the loss is used for holography.
  • fovea_weight (float) – A weight to apply to the foveal region if use_l2_foveal_loss is set to True.
  • use_radial_weight (bool) – If True, applies a radial weighting when calculating the difference between the source and target stats maps. This weights stats closer to the fovea more than those further away.
  • use_fullres_l0 (bool) – If true, stats for the lowpass residual are replaced with blurred versions of the full-resolution source and target images.
  • equi (bool) – If true, runs the loss in equirectangular mode. The input is assumed to be an equirectangular format 360 image. The settings real_image_width and real_viewing_distance are ignored. The gaze argument is instead interpreted as gaze angles and should be in the range [-pi,pi]x[-pi/2,pi/2].
+ Source code in odak/learn/perception/metameric_loss.py +
def __init__(self, device=torch.device('cpu'), alpha=0.2, real_image_width=0.2,
+             real_viewing_distance=0.7, n_pyramid_levels=5, mode="quadratic",
+             n_orientations=2, use_l2_foveal_loss=True, fovea_weight=20.0, use_radial_weight=False,
+             use_fullres_l0=False, equi=False):
+    """
+    Parameters
+    ----------
+
+    alpha                   : float
+                                parameter controlling foveation - larger values mean bigger pooling regions.
+    real_image_width        : float 
+                                The real width of the image as displayed to the user.
+                                Units don't matter as long as they are the same as for real_viewing_distance.
+    real_viewing_distance   : float 
+                                The real distance of the observer's eyes to the image plane.
+                                Units don't matter as long as they are the same as for real_image_width.
+    n_pyramid_levels        : int 
+                                Number of levels of the steerable pyramid. Note that the image is padded
+                                so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value
+                                too high will slow down the calculation a lot.
+    mode                    : str 
+                                Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
+                                as you move away from the fovea. We got best results with "quadratic".
+    n_orientations          : int 
+                                Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.
+                                Increasing this will increase runtime.
+    use_l2_foveal_loss      : bool 
+                                If true, all pixels whose pooling size is 1 pixel at the largest scale
+                                use a direct L2 loss against the target rather than pooling over pyramid levels.
+                                In practice this gives better results when the loss is used for holography.
+    fovea_weight            : float 
+                                A weight to apply to the foveal region if use_l2_foveal_loss is set to True.
+    use_radial_weight       : bool 
+                                If True, will apply a radial weighting when calculating the difference between
+                                the source and target stats maps. This weights stats closer to the fovea more than those
+                                further away.
+    use_fullres_l0          : bool 
+                                If true, stats for the lowpass residual are replaced with blurred versions
+                                of the full-resolution source and target images.
+    equi                    : bool
+                                If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular
+                                format 360 image. The settings real_image_width and real_viewing_distance are ignored.
+                                The gaze argument is instead interpreted as gaze angles, and should be in the range
+                                [-pi,pi]x[-pi/2,pi/2].
+    """
+    self.target = None
+    self.device = device
+    self.pyramid_maker = None
+    self.alpha = alpha
+    self.real_image_width = real_image_width
+    self.real_viewing_distance = real_viewing_distance
+    self.blurs = None
+    self.n_pyramid_levels = n_pyramid_levels
+    self.n_orientations = n_orientations
+    self.mode = mode
+    self.use_l2_foveal_loss = use_l2_foveal_loss
+    self.fovea_weight = fovea_weight
+    self.use_radial_weight = use_radial_weight
+    self.use_fullres_l0 = use_fullres_l0
+    self.equi = equi
+    if self.use_fullres_l0 and self.use_l2_foveal_loss:
+        raise Exception(
+            "Can't use use_fullres_l0 and use_l2_foveal_loss options together in MetamericLoss!")
MetamericLossUniform

Measures metameric loss between a given image and a metamer of the given target image. This variant of the metameric loss is not foveated - it applies uniform pooling sizes to the whole input image.
+ Source code in odak/learn/perception/metameric_loss_uniform.py +
class MetamericLossUniform():
+    """
+    Measures metameric loss between a given image and a metamer of the given target image.
+    This variant of the metameric loss is not foveated - it applies uniform pooling sizes to the whole input image.
+    """
+
+    def __init__(self, device=torch.device('cpu'), pooling_size=32, n_pyramid_levels=5, n_orientations=2):
+        """
+
+        Parameters
+        ----------
+        pooling_size            : int
+                                  Pooling size, in pixels. For example 32 will pool over 32x32 blocks of the image.
+        n_pyramid_levels        : int 
+                                  Number of levels of the steerable pyramid. Note that the image is padded
+                                  so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value
+                                  too high will slow down the calculation a lot.
+        n_orientations          : int 
+                                  Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.
+                                  Increasing this will increase runtime.
+
+        """
+        self.target = None
+        self.device = device
+        self.pyramid_maker = None
+        self.pooling_size = pooling_size
+        self.n_pyramid_levels = n_pyramid_levels
+        self.n_orientations = n_orientations
+
+    def calc_statsmaps(self, image, pooling_size):
+
+        if self.pyramid_maker is None or \
+                self.pyramid_maker.device != self.device or \
+                len(self.pyramid_maker.band_filters) != self.n_orientations or\
+                self.pyramid_maker.filt_h0.size(0) != image.size(1):
+            self.pyramid_maker = SpatialSteerablePyramid(
+                use_bilinear_downup=False, n_channels=image.size(1),
+                device=self.device, n_orientations=self.n_orientations, filter_type="cropped", filter_size=5)
+
+
+        def find_stats(image_pyr_level, pooling_size):
+            image_means = uniform_blur(image_pyr_level, pooling_size)
+            image_meansq = uniform_blur(image_pyr_level*image_pyr_level, pooling_size)
+            image_vars = image_meansq - (image_means*image_means)
+            image_vars[image_vars < 1e-7] = 1e-7
+            image_std = torch.sqrt(image_vars)
+            if torch.any(torch.isnan(image_means)):
+                print(image_means)
+                raise Exception("NaN in image means!")
+            if torch.any(torch.isnan(image_std)):
+                print(image_std)
+                raise Exception("NaN in image stdevs!")
+            return image_means, image_std
+
+        output_stats = []
+        image_pyramid = self.pyramid_maker.construct_pyramid(
+            image, self.n_pyramid_levels)
+        curr_pooling_size = pooling_size
+        means, variances = find_stats(image_pyramid[0]['h'], curr_pooling_size)
+        output_stats.append(means)
+        output_stats.append(variances)
+
+        for l in range(0, len(image_pyramid)-1):
+            for o in range(len(image_pyramid[l]['b'])):
+                means, variances = find_stats(
+                    image_pyramid[l]['b'][o], curr_pooling_size)
+                output_stats.append(means)
+                output_stats.append(variances)
+            curr_pooling_size /= 2
+
+        output_stats.append(image_pyramid[-1]["l"])
+        return output_stats
+
+    def metameric_loss_stats(self, statsmap_a, statsmap_b):
+        loss = 0.0
+        for a, b in zip(statsmap_a, statsmap_b):
+            loss += torch.nn.MSELoss()(a, b)
+        loss /= len(statsmap_a)
+        return loss
+
+    def visualise_loss_map(self, image_stats):
+        loss_map = torch.zeros(image_stats[0].size()[-2:])
+        for i in range(len(image_stats)):
+            stats = image_stats[i]
+            target_stats = self.target_stats[i]
+            stat_mse_map = torch.sqrt(torch.pow(stats - target_stats, 2))
+            stat_mse_map = torch.nn.functional.interpolate(stat_mse_map, size=loss_map.size(
+            ), mode="bilinear", align_corners=False, recompute_scale_factor=False)
+            loss_map += stat_mse_map[0, 0, ...]
+        self.loss_map = loss_map
+
+    def __call__(self, image, target, image_colorspace="RGB", visualise_loss=False):
+        """ 
+        Calculates the Metameric Loss.
+
+        Parameters
+        ----------
+        image               : torch.tensor
+                                Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
+        target              : torch.tensor
+                                Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
+        image_colorspace    : str
+                                The current colorspace of your image and target. Ignored if input does not have 3 channels.
+                                accepted values: RGB, YCrCb.
+        visualise_loss      : bool
+                                Shows a heatmap indicating which parts of the image contributed most to the loss. 
+
+        Returns
+        -------
+
+        loss                : torch.tensor
+                                The computed loss.
+        """
+        check_loss_inputs("MetamericLossUniform", image, target)
+        # Pad image and target if necessary
+        image = pad_image_for_pyramid(image, self.n_pyramid_levels)
+        target = pad_image_for_pyramid(target, self.n_pyramid_levels)
+        # If input is RGB, convert to YCrCb.
+        if image.size(1) == 3 and image_colorspace == "RGB":
+            image = rgb_2_ycrcb(image)
+            target = rgb_2_ycrcb(target)
+        if self.target is None:
+            self.target = torch.zeros(target.shape).to(target.device)
+        if type(target) == type(self.target):
+            if not torch.all(torch.eq(target, self.target)):
+                self.target = target.detach().clone()
+                self.target_stats = self.calc_statsmaps(self.target, self.pooling_size)
+                self.target = target.detach().clone()
+            image_stats = self.calc_statsmaps(image, self.pooling_size)
+
+            if visualise_loss:
+                self.visualise_loss_map(image_stats)
+            loss = self.metameric_loss_stats(
+                image_stats, self.target_stats)
+            return loss
+        else:
+            raise Exception("Target of incorrect type")
+
+    def gen_metamer(self, image):
+        """ 
+        Generates a metamer for an image, following the method in [this paper](https://dl.acm.org/doi/abs/10.1145/3450626.3459943)
+        This function can be used on its own to generate a metamer for a desired image.
+
+        Parameters
+        ----------
+        image   : torch.tensor
+                  Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions)
+
+        Returns
+        -------
+        metamer : torch.tensor
+                  The generated metamer image
+        """
+        image = rgb_2_ycrcb(image)
+        image_size = image.size()
+        image = pad_image_for_pyramid(image, self.n_pyramid_levels)
+
+        target_stats = self.calc_statsmaps(
+            image, self.pooling_size)
+        target_means = target_stats[::2]
+        target_stdevs = target_stats[1::2]
+        torch.manual_seed(0)
+        noise_image = torch.rand_like(image)
+        noise_pyramid = self.pyramid_maker.construct_pyramid(
+            noise_image, self.n_pyramid_levels)
+        input_pyramid = self.pyramid_maker.construct_pyramid(
+            image, self.n_pyramid_levels)
+
+        def match_level(input_level, target_mean, target_std):
+            level = input_level.clone()
+            level -= torch.mean(level)
+            input_std = torch.sqrt(torch.mean(level * level))
+            eps = 1e-6
+            # Safeguard against divide by zero
+            input_std[input_std < eps] = eps
+            level /= input_std
+            level *= target_std
+            level += target_mean
+            return level
+
+        nbands = len(noise_pyramid[0]["b"])
+        noise_pyramid[0]["h"] = match_level(
+            noise_pyramid[0]["h"], target_means[0], target_stdevs[0])
+        for l in range(len(noise_pyramid)-1):
+            for b in range(nbands):
+                noise_pyramid[l]["b"][b] = match_level(
+                    noise_pyramid[l]["b"][b], target_means[1 + l * nbands + b], target_stdevs[1 + l * nbands + b])
+        noise_pyramid[-1]["l"] = input_pyramid[-1]["l"]
+
+        metamer = self.pyramid_maker.reconstruct_from_pyramid(
+            noise_pyramid)
+        metamer = ycrcb_2_rgb(metamer)
+        # Crop to remove any padding
+        metamer = metamer[:image_size[0], :image_size[1], :image_size[2], :image_size[3]]
+        return metamer
+
+    def to(self, device):
+        self.device = device
+        return self
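A hedged usage sketch, assuming MetamericLossUniform is exported from odak.learn.perception; no gaze is needed since pooling is uniform, and gen_metamer can be called on its own to inspect what the loss considers equivalent to a target:

import torch
from odak.learn.perception import MetamericLossUniform  # assumed export path

loss_function = MetamericLossUniform(pooling_size = 32, n_pyramid_levels = 5, n_orientations = 2)
estimate = torch.rand(1, 3, 512, 512)
ground_truth = torch.rand(1, 3, 512, 512)
loss = loss_function(estimate, ground_truth)       # uniform pooling over 32x32 regions
metamer = loss_function.gen_metamer(ground_truth)  # stand-alone metamer generation
print(loss.item(), metamer.shape)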
__call__(image, target, image_colorspace='RGB', visualise_loss=False)

Calculates the Metameric Loss.

Parameters:

  • image (tensor) – Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions).
  • target (tensor) – Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions).
  • image_colorspace (str) – The current colorspace of your image and target. Ignored if input does not have 3 channels. Accepted values: RGB, YCrCb.
  • visualise_loss (bool) – Shows a heatmap indicating which parts of the image contributed most to the loss.

Returns:

  • loss (tensor) – The computed loss.
+ Source code in odak/learn/perception/metameric_loss_uniform.py +
def __call__(self, image, target, image_colorspace="RGB", visualise_loss=False):
+    """ 
+    Calculates the Metameric Loss.
+
+    Parameters
+    ----------
+    image               : torch.tensor
+                            Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
+    target              : torch.tensor
+                            Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
+    image_colorspace    : str
+                            The current colorspace of your image and target. Ignored if input does not have 3 channels.
+                            accepted values: RGB, YCrCb.
+    visualise_loss      : bool
+                            Shows a heatmap indicating which parts of the image contributed most to the loss. 
+
+    Returns
+    -------
+
+    loss                : torch.tensor
+                            The computed loss.
+    """
+    check_loss_inputs("MetamericLossUniform", image, target)
+    # Pad image and target if necessary
+    image = pad_image_for_pyramid(image, self.n_pyramid_levels)
+    target = pad_image_for_pyramid(target, self.n_pyramid_levels)
+    # If input is RGB, convert to YCrCb.
+    if image.size(1) == 3 and image_colorspace == "RGB":
+        image = rgb_2_ycrcb(image)
+        target = rgb_2_ycrcb(target)
+    if self.target is None:
+        self.target = torch.zeros(target.shape).to(target.device)
+    if type(target) == type(self.target):
+        if not torch.all(torch.eq(target, self.target)):
+            self.target = target.detach().clone()
+            self.target_stats = self.calc_statsmaps(self.target, self.pooling_size)
+            self.target = target.detach().clone()
+        image_stats = self.calc_statsmaps(image, self.pooling_size)
+
+        if visualise_loss:
+            self.visualise_loss_map(image_stats)
+        loss = self.metameric_loss_stats(
+            image_stats, self.target_stats)
+        return loss
+    else:
+        raise Exception("Target of incorrect type")
__init__(device=torch.device('cpu'), pooling_size=32, n_pyramid_levels=5, n_orientations=2)

Parameters:

  • pooling_size – Pooling size, in pixels. For example 32 will pool over 32x32 blocks of the image.
  • n_pyramid_levels – Number of levels of the steerable pyramid. Note that the image is padded so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value too high will slow down the calculation a lot.
  • n_orientations – Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6. Increasing this will increase runtime.
+ Source code in odak/learn/perception/metameric_loss_uniform.py +
def __init__(self, device=torch.device('cpu'), pooling_size=32, n_pyramid_levels=5, n_orientations=2):
+    """
+
+    Parameters
+    ----------
+    pooling_size            : int
+                              Pooling size, in pixels. For example 32 will pool over 32x32 blocks of the image.
+    n_pyramid_levels        : int 
+                              Number of levels of the steerable pyramid. Note that the image is padded
+                              so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value
+                              too high will slow down the calculation a lot.
+    n_orientations          : int 
+                              Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.
+                              Increasing this will increase runtime.
+
+    """
+    self.target = None
+    self.device = device
+    self.pyramid_maker = None
+    self.pooling_size = pooling_size
+    self.n_pyramid_levels = n_pyramid_levels
+    self.n_orientations = n_orientations
+
+
+
gen_metamer(image)

Generates a metamer for an image, following the method in this paper (https://dl.acm.org/doi/abs/10.1145/3450626.3459943). This function can be used on its own to generate a metamer for a desired image.

Parameters:

  • image – Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions).

Returns:

  • metamer (tensor) – The generated metamer image.
+ Source code in odak/learn/perception/metameric_loss_uniform.py +
def gen_metamer(self, image):
+    """ 
+    Generates a metamer for an image, following the method in [this paper](https://dl.acm.org/doi/abs/10.1145/3450626.3459943)
+    This function can be used on its own to generate a metamer for a desired image.
+
+    Parameters
+    ----------
+    image   : torch.tensor
+              Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions)
+
+    Returns
+    -------
+    metamer : torch.tensor
+              The generated metamer image
+    """
+    image = rgb_2_ycrcb(image)
+    image_size = image.size()
+    image = pad_image_for_pyramid(image, self.n_pyramid_levels)
+
+    target_stats = self.calc_statsmaps(
+        image, self.pooling_size)
+    target_means = target_stats[::2]
+    target_stdevs = target_stats[1::2]
+    torch.manual_seed(0)
+    noise_image = torch.rand_like(image)
+    noise_pyramid = self.pyramid_maker.construct_pyramid(
+        noise_image, self.n_pyramid_levels)
+    input_pyramid = self.pyramid_maker.construct_pyramid(
+        image, self.n_pyramid_levels)
+
+    def match_level(input_level, target_mean, target_std):
+        level = input_level.clone()
+        level -= torch.mean(level)
+        input_std = torch.sqrt(torch.mean(level * level))
+        eps = 1e-6
+        # Safeguard against divide by zero
+        input_std[input_std < eps] = eps
+        level /= input_std
+        level *= target_std
+        level += target_mean
+        return level
+
+    nbands = len(noise_pyramid[0]["b"])
+    noise_pyramid[0]["h"] = match_level(
+        noise_pyramid[0]["h"], target_means[0], target_stdevs[0])
+    for l in range(len(noise_pyramid)-1):
+        for b in range(nbands):
+            noise_pyramid[l]["b"][b] = match_level(
+                noise_pyramid[l]["b"][b], target_means[1 + l * nbands + b], target_stdevs[1 + l * nbands + b])
+    noise_pyramid[-1]["l"] = input_pyramid[-1]["l"]
+
+    metamer = self.pyramid_maker.reconstruct_from_pyramid(
+        noise_pyramid)
+    metamer = ycrcb_2_rgb(metamer)
+    # Crop to remove any padding
+    metamer = metamer[:image_size[0], :image_size[1], :image_size[2], :image_size[3]]
+    return metamer
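A short sketch of calling gen_metamer on its own, as the docstring above suggests; the class is assumed to be importable from odak.learn.perception and the input size is an illustrative choice.

import torch
from odak.learn.perception import MetamericLossUniform

loss_function = MetamericLossUniform(pooling_size=32, n_pyramid_levels=5, n_orientations=2)
image = torch.rand(1, 3, 512, 512)  # illustrative RGB input in NCHW format

# The input is padded internally and the result is cropped back to the input size.
metamer = loss_function.gen_metamer(image)
print(metamer.shape)  # expected to match the input shape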
MetamerMSELoss

The MetamerMSELoss class provides a perceptual loss function. This generates a metamer for the target image, and then optimises the source image to be the same as this target image metamer.

Please note this is different to MetamericLoss, which optimises the source image to be any metamer of the target image.

Its interface is similar to other pytorch loss functions, but note that the gaze location must be provided in addition to the source and target images.
+ Source code in odak/learn/perception/metamer_mse_loss.py +
class MetamerMSELoss():
+    """ 
+    The `MetamerMSELoss` class provides a perceptual loss function. This generates a metamer for the target image, and then optimises the source image to be the same as this target image metamer.
+
+    Please note this is different to `MetamericLoss` which optimises the source image to be any metamer of the target image.
+
+    Its interface is similar to other `pytorch` loss functions, but note that the gaze location must be provided in addition to the source and target images.
+    """
+
+
+    def __init__(self, device=torch.device("cpu"),
+                 alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode="quadratic",
+                 n_pyramid_levels=5, n_orientations=2, equi=False):
+        """
+        Parameters
+        ----------
+        alpha                   : float
+                                    parameter controlling foveation - larger values mean bigger pooling regions.
+        real_image_width        : float 
+                                    The real width of the image as displayed to the user.
+                                    Units don't matter as long as they are the same as for real_viewing_distance.
+        real_viewing_distance   : float 
+                                    The real distance of the observer's eyes to the image plane.
+                                    Units don't matter as long as they are the same as for real_image_width.
+        n_pyramid_levels        : int 
+                                    Number of levels of the steerable pyramid. Note that the image is padded
+                                    so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value
+                                    too high will slow down the calculation a lot.
+        mode                    : str 
+                                    Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
+                                    as you move away from the fovea. We got best results with "quadratic".
+        n_orientations          : int 
+                                    Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.
+                                    Increasing this will increase runtime.
+        equi                    : bool
+                                    If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular
+                                    format 360 image. The settings real_image_width and real_viewing_distance are ignored.
+                                    The gaze argument is instead interpreted as gaze angles, and should be in the range
+                                    [-pi,pi]x[-pi/2,pi/2]
+        """
+        self.target = None
+        self.target_metamer = None
+        self.metameric_loss = MetamericLoss(device=device, alpha=alpha, real_image_width=real_image_width,
+                                            real_viewing_distance=real_viewing_distance,
+                                            n_pyramid_levels=n_pyramid_levels, n_orientations=n_orientations, use_l2_foveal_loss=False, equi=equi)
+        self.loss_func = torch.nn.MSELoss()
+        self.noise = None
+
+    def gen_metamer(self, image, gaze):
+        """ 
+        Generates a metamer for an image, following the method in [this paper](https://dl.acm.org/doi/abs/10.1145/3450626.3459943)
+        This function can be used on its own to generate a metamer for a desired image.
+
+        Parameters
+        ----------
+        image   : torch.tensor
+                Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions)
+        gaze    : list
+                Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.
+
+        Returns
+        -------
+
+        metamer : torch.tensor
+                The generated metamer image
+        """
+        image = rgb_2_ycrcb(image)
+        image_size = image.size()
+        image = pad_image_for_pyramid(image, self.metameric_loss.n_pyramid_levels)
+
+        target_stats = self.metameric_loss.calc_statsmaps(
+            image, gaze=gaze, alpha=self.metameric_loss.alpha)
+        target_means = target_stats[::2]
+        target_stdevs = target_stats[1::2]
+        if self.noise is None or self.noise.size() != image.size():
+            torch.manual_seed(0)
+            noise_image = torch.rand_like(image)
+        noise_pyramid = self.metameric_loss.pyramid_maker.construct_pyramid(
+            noise_image, self.metameric_loss.n_pyramid_levels)
+        input_pyramid = self.metameric_loss.pyramid_maker.construct_pyramid(
+            image, self.metameric_loss.n_pyramid_levels)
+
+        def match_level(input_level, target_mean, target_std):
+            level = input_level.clone()
+            level -= torch.mean(level)
+            input_std = torch.sqrt(torch.mean(level * level))
+            eps = 1e-6
+            # Safeguard against divide by zero
+            input_std[input_std < eps] = eps
+            level /= input_std
+            level *= target_std
+            level += target_mean
+            return level
+
+        nbands = len(noise_pyramid[0]["b"])
+        noise_pyramid[0]["h"] = match_level(
+            noise_pyramid[0]["h"], target_means[0], target_stdevs[0])
+        for l in range(len(noise_pyramid)-1):
+            for b in range(nbands):
+                noise_pyramid[l]["b"][b] = match_level(
+                    noise_pyramid[l]["b"][b], target_means[1 + l * nbands + b], target_stdevs[1 + l * nbands + b])
+        noise_pyramid[-1]["l"] = input_pyramid[-1]["l"]
+
+        metamer = self.metameric_loss.pyramid_maker.reconstruct_from_pyramid(
+            noise_pyramid)
+        metamer = ycrcb_2_rgb(metamer)
+        # Crop to remove any padding
+        metamer = metamer[:image_size[0], :image_size[1], :image_size[2], :image_size[3]]
+        return metamer
+
+    def __call__(self, image, target, gaze=[0.5, 0.5]):
+        """ 
+        Calculates the Metamer MSE Loss.
+
+        Parameters
+        ----------
+        image   : torch.tensor
+                Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
+        target  : torch.tensor
+                Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
+        gaze    : list
+                Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.
+
+        Returns
+        -------
+
+        loss                : torch.tensor
+                                The computed loss.
+        """
+        check_loss_inputs("MetamerMSELoss", image, target)
+        # Pad image and target if necessary
+        image = pad_image_for_pyramid(image, self.metameric_loss.n_pyramid_levels)
+        target = pad_image_for_pyramid(target, self.metameric_loss.n_pyramid_levels)
+
+        if target is not self.target or self.target is None:
+            self.target_metamer = self.gen_metamer(target, gaze)
+            self.target = target
+
+        return self.loss_func(image, self.target_metamer)
+
+    def to(self, device):
+        self.metameric_loss = self.metameric_loss.to(device)
+        return self
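A minimal usage sketch for MetamerMSELoss, assuming it is importable from odak.learn.perception; the tensor shapes, optimiser settings and gaze value are illustrative.

import torch
from odak.learn.perception import MetamerMSELoss

target = torch.rand(1, 3, 256, 256)
image = torch.rand(1, 3, 256, 256, requires_grad=True)

loss_function = MetamerMSELoss(alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7)
optimizer = torch.optim.Adam([image], lr=1e-2)

for step in range(100):
    optimizer.zero_grad()
    # Gaze is given in normalised image coordinates; [0.5, 0.5] is the image centre.
    loss = loss_function(image, target, gaze=[0.5, 0.5])
    loss.backward()
    optimizer.step()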
__call__(image, target, gaze=[0.5, 0.5])

Calculates the Metamer MSE Loss.

Parameters:

  • image – Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions).
  • target – Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions).
  • gaze – Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.

Returns:

  • loss (tensor) – The computed loss.
+ Source code in odak/learn/perception/metamer_mse_loss.py +
def __call__(self, image, target, gaze=[0.5, 0.5]):
+    """ 
+    Calculates the Metamer MSE Loss.
+
+    Parameters
+    ----------
+    image   : torch.tensor
+            Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
+    target  : torch.tensor
+            Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)
+    gaze    : list
+            Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.
+
+    Returns
+    -------
+
+    loss                : torch.tensor
+                            The computed loss.
+    """
+    check_loss_inputs("MetamerMSELoss", image, target)
+    # Pad image and target if necessary
+    image = pad_image_for_pyramid(image, self.metameric_loss.n_pyramid_levels)
+    target = pad_image_for_pyramid(target, self.metameric_loss.n_pyramid_levels)
+
+    if target is not self.target or self.target is None:
+        self.target_metamer = self.gen_metamer(target, gaze)
+        self.target = target
+
+    return self.loss_func(image, self.target_metamer)
+
__init__(device=torch.device('cpu'), alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode='quadratic', n_pyramid_levels=5, n_orientations=2, equi=False)

Parameters:

  • alpha – Parameter controlling foveation - larger values mean bigger pooling regions.
  • real_image_width – The real width of the image as displayed to the user. Units don't matter as long as they are the same as for real_viewing_distance.
  • real_viewing_distance – The real distance of the observer's eyes to the image plane. Units don't matter as long as they are the same as for real_image_width.
  • n_pyramid_levels – Number of levels of the steerable pyramid. Note that the image is padded so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value too high will slow down the calculation a lot.
  • mode – Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow as you move away from the fovea. We got best results with "quadratic".
  • n_orientations – Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6. Increasing this will increase runtime.
  • equi – If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular format 360 image. The settings real_image_width and real_viewing_distance are ignored. The gaze argument is instead interpreted as gaze angles, and should be in the range [-pi,pi]x[-pi/2,pi/2].
+ Source code in odak/learn/perception/metamer_mse_loss.py +
def __init__(self, device=torch.device("cpu"),
+             alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode="quadratic",
+             n_pyramid_levels=5, n_orientations=2, equi=False):
+    """
+    Parameters
+    ----------
+    alpha                   : float
+                                parameter controlling foveation - larger values mean bigger pooling regions.
+    real_image_width        : float 
+                                The real width of the image as displayed to the user.
+                                Units don't matter as long as they are the same as for real_viewing_distance.
+    real_viewing_distance   : float 
+                                The real distance of the observer's eyes to the image plane.
+                                Units don't matter as long as they are the same as for real_image_width.
+    n_pyramid_levels        : int 
+                                Number of levels of the steerable pyramid. Note that the image is padded
+                                so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value
+                                too high will slow down the calculation a lot.
+    mode                    : str 
+                                Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
+                                as you move away from the fovea. We got best results with "quadratic".
+    n_orientations          : int 
+                                Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.
+                                Increasing this will increase runtime.
+    equi                    : bool
+                                If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular
+                                format 360 image. The settings real_image_width and real_viewing_distance are ignored.
+                                The gaze argument is instead interpreted as gaze angles, and should be in the range
+                                [-pi,pi]x[-pi/2,pi/2]
+    """
+    self.target = None
+    self.target_metamer = None
+    self.metameric_loss = MetamericLoss(device=device, alpha=alpha, real_image_width=real_image_width,
+                                        real_viewing_distance=real_viewing_distance,
+                                        n_pyramid_levels=n_pyramid_levels, n_orientations=n_orientations, use_l2_foveal_loss=False, equi=equi)
+    self.loss_func = torch.nn.MSELoss()
+    self.noise = None
+
gen_metamer(image, gaze)

Generates a metamer for an image, following the method in this paper (https://dl.acm.org/doi/abs/10.1145/3450626.3459943). This function can be used on its own to generate a metamer for a desired image.

Parameters:

  • image – Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions).
  • gaze – Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.

Returns:

  • metamer (tensor) – The generated metamer image.
+ Source code in odak/learn/perception/metamer_mse_loss.py +
def gen_metamer(self, image, gaze):
+    """ 
+    Generates a metamer for an image, following the method in [this paper](https://dl.acm.org/doi/abs/10.1145/3450626.3459943)
+    This function can be used on its own to generate a metamer for a desired image.
+
+    Parameters
+    ----------
+    image   : torch.tensor
+            Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions)
+    gaze    : list
+            Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.
+
+    Returns
+    -------
+
+    metamer : torch.tensor
+            The generated metamer image
+    """
+    image = rgb_2_ycrcb(image)
+    image_size = image.size()
+    image = pad_image_for_pyramid(image, self.metameric_loss.n_pyramid_levels)
+
+    target_stats = self.metameric_loss.calc_statsmaps(
+        image, gaze=gaze, alpha=self.metameric_loss.alpha)
+    target_means = target_stats[::2]
+    target_stdevs = target_stats[1::2]
+    if self.noise is None or self.noise.size() != image.size():
+        torch.manual_seed(0)
+        noise_image = torch.rand_like(image)
+    noise_pyramid = self.metameric_loss.pyramid_maker.construct_pyramid(
+        noise_image, self.metameric_loss.n_pyramid_levels)
+    input_pyramid = self.metameric_loss.pyramid_maker.construct_pyramid(
+        image, self.metameric_loss.n_pyramid_levels)
+
+    def match_level(input_level, target_mean, target_std):
+        level = input_level.clone()
+        level -= torch.mean(level)
+        input_std = torch.sqrt(torch.mean(level * level))
+        eps = 1e-6
+        # Safeguard against divide by zero
+        input_std[input_std < eps] = eps
+        level /= input_std
+        level *= target_std
+        level += target_mean
+        return level
+
+    nbands = len(noise_pyramid[0]["b"])
+    noise_pyramid[0]["h"] = match_level(
+        noise_pyramid[0]["h"], target_means[0], target_stdevs[0])
+    for l in range(len(noise_pyramid)-1):
+        for b in range(nbands):
+            noise_pyramid[l]["b"][b] = match_level(
+                noise_pyramid[l]["b"][b], target_means[1 + l * nbands + b], target_stdevs[1 + l * nbands + b])
+    noise_pyramid[-1]["l"] = input_pyramid[-1]["l"]
+
+    metamer = self.metameric_loss.pyramid_maker.reconstruct_from_pyramid(
+        noise_pyramid)
+    metamer = ycrcb_2_rgb(metamer)
+    # Crop to remove any padding
+    metamer = metamer[:image_size[0], :image_size[1], :image_size[2], :image_size[3]]
+    return metamer
RadiallyVaryingBlur

The RadiallyVaryingBlur class provides a way to apply a radially varying blur to an image. Given a gaze location and information about the image and foveation, it applies a blur that will achieve the proper pooling size. The pooling size is chosen to appear the same at a range of display sizes and viewing distances, for a given alpha parameter value. For more information on how the pooling sizes are computed, please see link coming soon.

The blur is accelerated by generating and sampling from MIP maps of the input image.

This class caches the foveation information. This means that if it is run repeatedly with the same foveation parameters, gaze location and image size (e.g. in an optimisation loop), it won't recalculate the pooling maps.

If you are repeatedly applying blur to images of different sizes (e.g. a pyramid), for best performance use one instance of this class per image size.
+ Source code in odak/learn/perception/radially_varying_blur.py +
class RadiallyVaryingBlur():
+    """ 
+
+    The `RadiallyVaryingBlur` class provides a way to apply a radially varying blur to an image. Given a gaze location and information about the image and foveation, it applies a blur that will achieve the proper pooling size. The pooling size is chosen to appear the same at a range of display sizes and viewing distances, for a given `alpha` parameter value. For more information on how the pooling sizes are computed, please see [link coming soon]().
+
+    The blur is accelerated by generating and sampling from MIP maps of the input image.
+
+    This class caches the foveation information. This means that if it is run repeatedly with the same foveation parameters, gaze location and image size (e.g. in an optimisation loop) it won't recalculate the pooling maps.
+
+    If you are repeatedly applying blur to images of different sizes (e.g. a pyramid) for best performance use one instance of this class per image size.
+
+    """
+
+    def __init__(self):
+        self.lod_map = None
+        self.equi = None
+
+    def blur(self, image, alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, centre=None, mode="quadratic", equi=False):
+        """
+        Apply the radially varying blur to an image.
+
+        Parameters
+        ----------
+
+        image                   : torch.tensor
+                                    The image to blur, in NCHW format.
+        alpha                   : float
+                                    parameter controlling foveation - larger values mean bigger pooling regions.
+        real_image_width        : float 
+                                    The real width of the image as displayed to the user.
+                                    Units don't matter as long as they are the same as for real_viewing_distance.
+                                    Ignored in equirectangular mode (equi==True)
+        real_viewing_distance   : float 
+                                    The real distance of the observer's eyes to the image plane.
+                                    Units don't matter as long as they are the same as for real_image_width.
+                                    Ignored in equirectangular mode (equi==True)
+        centre                  : tuple of floats
+                                    The centre of the radially varying blur (the gaze location).
+                                    Should be a tuple of floats containing normalised image coordinates in range [0,1]
+                                    In equirectangular mode this should be yaw & pitch angles in [-pi,pi]x[-pi/2,pi/2]
+        mode                    : str 
+                                    Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
+                                    as you move away from the fovea. We got best results with "quadratic".
+        equi                    : bool
+                                    If true, run the blur function in equirectangular mode. The input is assumed to be an equirectangular
+                                    format 360 image. The settings real_image_width and real_viewing_distance are ignored.
+                                    The centre argument is instead interpreted as gaze angles, and should be in the range
+                                    [-pi,pi]x[-pi/2,pi/2]
+
+        Returns
+        -------
+
+        output                  : torch.tensor
+                                    The blurred image
+        """
+        size = (image.size(-2), image.size(-1))
+
+        # LOD map caching
+        if self.lod_map is None or\
+                self.size != size or\
+                self.n_channels != image.size(1) or\
+                self.alpha != alpha or\
+                self.real_image_width != real_image_width or\
+                self.real_viewing_distance != real_viewing_distance or\
+                self.centre != centre or\
+                self.mode != mode or\
+                self.equi != equi:
+            if not equi:
+                self.lod_map = make_pooling_size_map_lod(
+                    centre, (image.size(-2), image.size(-1)), alpha, real_image_width, real_viewing_distance, mode)
+            else:
+                self.lod_map = make_equi_pooling_size_map_lod(
+                    centre, (image.size(-2), image.size(-1)), alpha, mode)
+            self.size = size
+            self.n_channels = image.size(1)
+            self.alpha = alpha
+            self.real_image_width = real_image_width
+            self.real_viewing_distance = real_viewing_distance
+            self.centre = centre
+            self.lod_map = self.lod_map.to(image.device)
+            self.lod_fraction = torch.fmod(self.lod_map, 1.0)
+            self.lod_fraction = self.lod_fraction[None, None, ...].repeat(
+                1, image.size(1), 1, 1)
+            self.mode = mode
+            self.equi = equi
+
+        if self.lod_map.device != image.device:
+            self.lod_map = self.lod_map.to(image.device)
+        if self.lod_fraction.device != image.device:
+            self.lod_fraction = self.lod_fraction.to(image.device)
+
+        mipmap = [image]
+        while mipmap[-1].size(-1) > 1 and mipmap[-1].size(-2) > 1:
+            mipmap.append(torch.nn.functional.interpolate(
+                mipmap[-1], scale_factor=0.5, mode="area", recompute_scale_factor=False))
+        if mipmap[-1].size(-1) == 2:
+            final_mip = torch.mean(mipmap[-1], axis=-1)[..., None]
+            mipmap.append(final_mip)
+        if mipmap[-1].size(-2) == 2:
+            final_mip = torch.mean(mipmap[-2], axis=-2)[..., None, :]
+            mipmap.append(final_mip)
+
+        for l in range(len(mipmap)):
+            if l == len(mipmap)-1:
+                mipmap[l] = mipmap[l] * \
+                    torch.ones(image.size(), device=image.device)
+            else:
+                for l2 in range(l-1, -1, -1):
+                    mipmap[l] = torch.nn.functional.interpolate(mipmap[l], size=(
+                        image.size(-2), image.size(-1)), mode="bilinear", align_corners=False, recompute_scale_factor=False)
+
+        output = torch.zeros(image.size(), device=image.device)
+        for l in range(len(mipmap)):
+            if l == 0:
+                mask = self.lod_map < (l+1)
+            elif l == len(mipmap)-1:
+                mask = self.lod_map >= l
+            else:
+                mask = torch.logical_and(
+                    self.lod_map >= l, self.lod_map < (l+1))
+
+            if l == len(mipmap)-1:
+                blended_levels = mipmap[l]
+            else:
+                blended_levels = (1 - self.lod_fraction) * \
+                    mipmap[l] + self.lod_fraction*mipmap[l+1]
+            mask = mask[None, None, ...]
+            mask = mask.repeat(1, image.size(1), 1, 1)
+            output[mask] = blended_levels[mask]
+
+        return output
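A minimal sketch of applying the blur, assuming RadiallyVaryingBlur is importable from odak.learn.perception; the image size and parameter values are illustrative.

import torch
from odak.learn.perception import RadiallyVaryingBlur

blur = RadiallyVaryingBlur()
image = torch.rand(1, 3, 512, 512)  # illustrative NCHW input

# Gaze at the image centre, given in normalised image coordinates.
blurred = blur.blur(image,
                    alpha=0.2,
                    real_image_width=0.2,
                    real_viewing_distance=0.7,
                    centre=(0.5, 0.5),
                    mode="quadratic")
print(blurred.shape)  # same spatial size as the input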
blur(image, alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, centre=None, mode='quadratic', equi=False)

Apply the radially varying blur to an image.

Parameters:

  • image – The image to blur, in NCHW format.
  • alpha – Parameter controlling foveation - larger values mean bigger pooling regions.
  • real_image_width – The real width of the image as displayed to the user. Units don't matter as long as they are the same as for real_viewing_distance. Ignored in equirectangular mode (equi==True).
  • real_viewing_distance – The real distance of the observer's eyes to the image plane. Units don't matter as long as they are the same as for real_image_width. Ignored in equirectangular mode (equi==True).
  • centre – The centre of the radially varying blur (the gaze location). Should be a tuple of floats containing normalised image coordinates in range [0,1]. In equirectangular mode this should be yaw & pitch angles in [-pi,pi]x[-pi/2,pi/2].
  • mode – Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow as you move away from the fovea. We got best results with "quadratic".
  • equi – If true, run the blur function in equirectangular mode. The input is assumed to be an equirectangular format 360 image. The settings real_image_width and real_viewing_distance are ignored. The centre argument is instead interpreted as gaze angles, and should be in the range [-pi,pi]x[-pi/2,pi/2].

Returns:

  • output (tensor) – The blurred image.
+ Source code in odak/learn/perception/radially_varying_blur.py +
def blur(self, image, alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, centre=None, mode="quadratic", equi=False):
+    """
+    Apply the radially varying blur to an image.
+
+    Parameters
+    ----------
+
+    image                   : torch.tensor
+                                The image to blur, in NCHW format.
+    alpha                   : float
+                                parameter controlling foveation - larger values mean bigger pooling regions.
+    real_image_width        : float 
+                                The real width of the image as displayed to the user.
+                                Units don't matter as long as they are the same as for real_viewing_distance.
+                                Ignored in equirectangular mode (equi==True)
+    real_viewing_distance   : float 
+                                The real distance of the observer's eyes to the image plane.
+                                Units don't matter as long as they are the same as for real_image_width.
+                                Ignored in equirectangular mode (equi==True)
+    centre                  : tuple of floats
+                                The centre of the radially varying blur (the gaze location).
+                                Should be a tuple of floats containing normalised image coordinates in range [0,1]
+                                In equirectangular mode this should be yaw & pitch angles in [-pi,pi]x[-pi/2,pi/2]
+    mode                    : str 
+                                Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow
+                                as you move away from the fovea. We got best results with "quadratic".
+    equi                    : bool
+                                If true, run the blur function in equirectangular mode. The input is assumed to be an equirectangular
+                                format 360 image. The settings real_image_width and real_viewing_distance are ignored.
+                                The centre argument is instead interpreted as gaze angles, and should be in the range
+                                [-pi,pi]x[-pi/2,pi/2]
+
+    Returns
+    -------
+
+    output                  : torch.tensor
+                                The blurred image
+    """
+    size = (image.size(-2), image.size(-1))
+
+    # LOD map caching
+    if self.lod_map is None or\
+            self.size != size or\
+            self.n_channels != image.size(1) or\
+            self.alpha != alpha or\
+            self.real_image_width != real_image_width or\
+            self.real_viewing_distance != real_viewing_distance or\
+            self.centre != centre or\
+            self.mode != mode or\
+            self.equi != equi:
+        if not equi:
+            self.lod_map = make_pooling_size_map_lod(
+                centre, (image.size(-2), image.size(-1)), alpha, real_image_width, real_viewing_distance, mode)
+        else:
+            self.lod_map = make_equi_pooling_size_map_lod(
+                centre, (image.size(-2), image.size(-1)), alpha, mode)
+        self.size = size
+        self.n_channels = image.size(1)
+        self.alpha = alpha
+        self.real_image_width = real_image_width
+        self.real_viewing_distance = real_viewing_distance
+        self.centre = centre
+        self.lod_map = self.lod_map.to(image.device)
+        self.lod_fraction = torch.fmod(self.lod_map, 1.0)
+        self.lod_fraction = self.lod_fraction[None, None, ...].repeat(
+            1, image.size(1), 1, 1)
+        self.mode = mode
+        self.equi = equi
+
+    if self.lod_map.device != image.device:
+        self.lod_map = self.lod_map.to(image.device)
+    if self.lod_fraction.device != image.device:
+        self.lod_fraction = self.lod_fraction.to(image.device)
+
+    mipmap = [image]
+    while mipmap[-1].size(-1) > 1 and mipmap[-1].size(-2) > 1:
+        mipmap.append(torch.nn.functional.interpolate(
+            mipmap[-1], scale_factor=0.5, mode="area", recompute_scale_factor=False))
+    if mipmap[-1].size(-1) == 2:
+        final_mip = torch.mean(mipmap[-1], axis=-1)[..., None]
+        mipmap.append(final_mip)
+    if mipmap[-1].size(-2) == 2:
+        final_mip = torch.mean(mipmap[-2], axis=-2)[..., None, :]
+        mipmap.append(final_mip)
+
+    for l in range(len(mipmap)):
+        if l == len(mipmap)-1:
+            mipmap[l] = mipmap[l] * \
+                torch.ones(image.size(), device=image.device)
+        else:
+            for l2 in range(l-1, -1, -1):
+                mipmap[l] = torch.nn.functional.interpolate(mipmap[l], size=(
+                    image.size(-2), image.size(-1)), mode="bilinear", align_corners=False, recompute_scale_factor=False)
+
+    output = torch.zeros(image.size(), device=image.device)
+    for l in range(len(mipmap)):
+        if l == 0:
+            mask = self.lod_map < (l+1)
+        elif l == len(mipmap)-1:
+            mask = self.lod_map >= l
+        else:
+            mask = torch.logical_and(
+                self.lod_map >= l, self.lod_map < (l+1))
+
+        if l == len(mipmap)-1:
+            blended_levels = mipmap[l]
+        else:
+            blended_levels = (1 - self.lod_fraction) * \
+                mipmap[l] + self.lod_fraction*mipmap[l+1]
+        mask = mask[None, None, ...]
+        mask = mask.repeat(1, image.size(1), 1, 1)
+        output[mask] = blended_levels[mask]
+
+    return output
SpatialSteerablePyramid

This implements a real-valued steerable pyramid where the filtering is carried out spatially (using convolution) as opposed to multiplication in the Fourier domain. This has a number of optimisations over previous implementations that increase efficiency, but introduce some reconstruction error.
+ Source code in odak/learn/perception/spatial_steerable_pyramid.py +
class SpatialSteerablePyramid():
+    """
+    This implements a real-valued steerable pyramid where the filtering is carried out spatially (using convolution)
+    as opposed to multiplication in the Fourier domain.
+    This has a number of optimisations over previous implementations that increase efficiency, but introduce some
+    reconstruction error.
+    """
+
+
+    def __init__(self, use_bilinear_downup=True, n_channels=1,
+                 filter_size=9, n_orientations=6, filter_type="full",
+                 device=torch.device('cpu')):
+        """
+        Parameters
+        ----------
+
+        use_bilinear_downup     : bool
+                                    This uses bilinear filtering when upsampling/downsampling, rather than the original approach
+                                    of applying a large lowpass kernel and sampling even rows/columns
+        n_channels              : int
+                                    Number of channels in the input images (e.g. 3 for RGB input)
+        filter_size             : int
+                                    Desired size of filters (e.g. 3 will use 3x3 filters).
+        n_orientations          : int
+                                    Number of oriented bands in each level of the pyramid.
+        filter_type             : str
+                                    This can be used to select smaller filters than the original ones if desired.
+                                    full: Original filter sizes
+                                    cropped: Some filters are cut back in size by extracting the centre and scaling as appropriate.
+                                    trained: Same as reduced, but the oriented kernels are replaced by learned 5x5 kernels.
+        device                  : torch.device
+                                    torch device the input images will be supplied from.
+        """
+        self.use_bilinear_downup = use_bilinear_downup
+        self.device = device
+
+        filters = get_steerable_pyramid_filters(
+            filter_size, n_orientations, filter_type)
+
+        def make_pad(filter):
+            filter_size = filter.size(-1)
+            pad_amt = (filter_size-1) // 2
+            return torch.nn.ReflectionPad2d((pad_amt, pad_amt, pad_amt, pad_amt))
+
+        if not self.use_bilinear_downup:
+            self.filt_l = filters["l"].to(device)
+            self.pad_l = make_pad(self.filt_l)
+        self.filt_l0 = filters["l0"].to(device)
+        self.pad_l0 = make_pad(self.filt_l0)
+        self.filt_h0 = filters["h0"].to(device)
+        self.pad_h0 = make_pad(self.filt_h0)
+        for b in range(len(filters["b"])):
+            filters["b"][b] = filters["b"][b].to(device)
+        self.band_filters = filters["b"]
+        self.pad_b = make_pad(self.band_filters[0])
+
+        if n_channels != 1:
+            def add_channels_to_filter(filter):
+                padded = torch.zeros(n_channels, n_channels, filter.size()[
+                                     2], filter.size()[3]).to(device)
+                for channel in range(n_channels):
+                    padded[channel, channel, :, :] = filter
+                return padded
+            self.filt_h0 = add_channels_to_filter(self.filt_h0)
+            for b in range(len(self.band_filters)):
+                self.band_filters[b] = add_channels_to_filter(
+                    self.band_filters[b])
+            self.filt_l0 = add_channels_to_filter(self.filt_l0)
+            if not self.use_bilinear_downup:
+                self.filt_l = add_channels_to_filter(self.filt_l)
+
+    def construct_pyramid(self, image, n_levels, multiple_highpass=False):
+        """
+        Constructs and returns a steerable pyramid for the provided image.
+
+        Parameters
+        ----------
+
+        image               : torch.tensor
+                                The input image, in NCHW format. The number of channels C should match num_channels
+                                when the pyramid maker was created.
+        n_levels            : int
+                                Number of levels in the constructed steerable pyramid.
+        multiple_highpass   : bool
+                                If true, computes a highpass for each level of the pyramid.
+                                These extra levels are redundant (not used for reconstruction).
+
+        Returns
+        -------
+
+        pyramid             : list of dicts of torch.tensor
+                                The computed steerable pyramid.
+                                Each level is an entry in a list. The pyramid is ordered from largest levels to smallest levels.
+                                Each level is stored as a dict, with the following keys:
+                                "h" Highpass residual
+                                "l" Lowpass residual
+                                "b" Oriented bands (a list of torch.tensor)
+        """
+        pyramid = []
+
+        # Make level 0, containing highpass, lowpass and the bands
+        level0 = {}
+        level0['h'] = torch.nn.functional.conv2d(
+            self.pad_h0(image), self.filt_h0)
+        lowpass = torch.nn.functional.conv2d(self.pad_l0(image), self.filt_l0)
+        level0['l'] = lowpass.clone()
+        bands = []
+        for filt_b in self.band_filters:
+            bands.append(torch.nn.functional.conv2d(
+                self.pad_b(lowpass), filt_b))
+        level0['b'] = bands
+        pyramid.append(level0)
+
+        # Make intermediate levels
+        for l in range(n_levels-2):
+            level = {}
+            if self.use_bilinear_downup:
+                lowpass = torch.nn.functional.interpolate(
+                    lowpass, scale_factor=0.5, mode="area", recompute_scale_factor=False)
+            else:
+                lowpass = torch.nn.functional.conv2d(
+                    self.pad_l(lowpass), self.filt_l)
+                lowpass = lowpass[:, :, ::2, ::2]
+            level['l'] = lowpass.clone()
+            bands = []
+            for filt_b in self.band_filters:
+                bands.append(torch.nn.functional.conv2d(
+                    self.pad_b(lowpass), filt_b))
+            level['b'] = bands
+            if multiple_highpass:
+                level['h'] = torch.nn.functional.conv2d(
+                    self.pad_h0(lowpass), self.filt_h0)
+            pyramid.append(level)
+
+        # Make final level (lowpass residual)
+        level = {}
+        if self.use_bilinear_downup:
+            lowpass = torch.nn.functional.interpolate(
+                lowpass, scale_factor=0.5, mode="area", recompute_scale_factor=False)
+        else:
+            lowpass = torch.nn.functional.conv2d(
+                self.pad_l(lowpass), self.filt_l)
+            lowpass = lowpass[:, :, ::2, ::2]
+        level['l'] = lowpass
+        pyramid.append(level)
+
+        return pyramid
+
+    def reconstruct_from_pyramid(self, pyramid):
+        """
+        Reconstructs an input image from a steerable pyramid.
+
+        Parameters
+        ----------
+
+        pyramid : list of dicts of torch.tensor
+                    The steerable pyramid.
+                    Should be in the same format as output by construct_pyramid().
+                    The number of channels should match num_channels when the pyramid maker was created.
+
+        Returns
+        -------
+
+        image   : torch.tensor
+                    The reconstructed image, in NCHW format.         
+        """
+        def upsample(image, size):
+            if self.use_bilinear_downup:
+                return torch.nn.functional.interpolate(image, size=size, mode="bilinear", align_corners=False, recompute_scale_factor=False)
+            else:
+                zeros = torch.zeros((image.size()[0], image.size()[1], image.size()[
+                                    2]*2, image.size()[3]*2)).to(self.device)
+                zeros[:, :, ::2, ::2] = image
+                zeros = torch.nn.functional.conv2d(
+                    self.pad_l(zeros), self.filt_l)
+                return zeros
+
+        image = pyramid[-1]['l']
+        for level in reversed(pyramid[:-1]):
+            image = upsample(image, level['b'][0].size()[2:])
+            for b in range(len(level['b'])):
+                b_filtered = torch.nn.functional.conv2d(
+                    self.pad_b(level['b'][b]), -self.band_filters[b])
+                image += b_filtered
+
+        image = torch.nn.functional.conv2d(self.pad_l0(image), self.filt_l0)
+        image += torch.nn.functional.conv2d(
+            self.pad_h0(pyramid[0]['h']), self.filt_h0)
+
+        return image
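A minimal round-trip sketch, assuming the class can be imported from odak.learn.perception.spatial_steerable_pyramid as the source path above indicates; the image size is illustrative.

import torch
from odak.learn.perception.spatial_steerable_pyramid import SpatialSteerablePyramid

pyramid_maker = SpatialSteerablePyramid(use_bilinear_downup=True, n_channels=3, n_orientations=6)
image = torch.rand(1, 3, 256, 256)  # illustrative RGB input in NCHW format

pyramid = pyramid_maker.construct_pyramid(image, n_levels=5)
reconstruction = pyramid_maker.reconstruct_from_pyramid(pyramid)

# The reconstruction approximates the input; some error is expected by design.
print((reconstruction - image).abs().mean())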
__init__(use_bilinear_downup=True, n_channels=1, filter_size=9, n_orientations=6, filter_type='full', device=torch.device('cpu'))

Parameters:

  • use_bilinear_downup – This uses bilinear filtering when upsampling/downsampling, rather than the original approach of applying a large lowpass kernel and sampling even rows/columns.
  • n_channels – Number of channels in the input images (e.g. 3 for RGB input).
  • filter_size – Desired size of filters (e.g. 3 will use 3x3 filters).
  • n_orientations – Number of oriented bands in each level of the pyramid.
  • filter_type – This can be used to select smaller filters than the original ones if desired. full: Original filter sizes. cropped: Some filters are cut back in size by extracting the centre and scaling as appropriate. trained: Same as reduced, but the oriented kernels are replaced by learned 5x5 kernels.
  • device – torch device the input images will be supplied from.
+ Source code in odak/learn/perception/spatial_steerable_pyramid.py +
def __init__(self, use_bilinear_downup=True, n_channels=1,
+             filter_size=9, n_orientations=6, filter_type="full",
+             device=torch.device('cpu')):
+    """
+    Parameters
+    ----------
+
+    use_bilinear_downup     : bool
+                                This uses bilinear filtering when upsampling/downsampling, rather than the original approach
+                                of applying a large lowpass kernel and sampling even rows/columns
+    n_channels              : int
+                                Number of channels in the input images (e.g. 3 for RGB input)
+    filter_size             : int
+                                Desired size of filters (e.g. 3 will use 3x3 filters).
+    n_orientations          : int
+                                Number of oriented bands in each level of the pyramid.
+    filter_type             : str
+                                This can be used to select smaller filters than the original ones if desired.
+                                full: Original filter sizes
+                                cropped: Some filters are cut back in size by extracting the centre and scaling as appropriate.
+                                trained: Same as reduced, but the oriented kernels are replaced by learned 5x5 kernels.
+    device                  : torch.device
+                                torch device the input images will be supplied from.
+    """
+    self.use_bilinear_downup = use_bilinear_downup
+    self.device = device
+
+    filters = get_steerable_pyramid_filters(
+        filter_size, n_orientations, filter_type)
+
+    def make_pad(filter):
+        filter_size = filter.size(-1)
+        pad_amt = (filter_size-1) // 2
+        return torch.nn.ReflectionPad2d((pad_amt, pad_amt, pad_amt, pad_amt))
+
+    if not self.use_bilinear_downup:
+        self.filt_l = filters["l"].to(device)
+        self.pad_l = make_pad(self.filt_l)
+    self.filt_l0 = filters["l0"].to(device)
+    self.pad_l0 = make_pad(self.filt_l0)
+    self.filt_h0 = filters["h0"].to(device)
+    self.pad_h0 = make_pad(self.filt_h0)
+    for b in range(len(filters["b"])):
+        filters["b"][b] = filters["b"][b].to(device)
+    self.band_filters = filters["b"]
+    self.pad_b = make_pad(self.band_filters[0])
+
+    if n_channels != 1:
+        def add_channels_to_filter(filter):
+            padded = torch.zeros(n_channels, n_channels, filter.size()[
+                                 2], filter.size()[3]).to(device)
+            for channel in range(n_channels):
+                padded[channel, channel, :, :] = filter
+            return padded
+        self.filt_h0 = add_channels_to_filter(self.filt_h0)
+        for b in range(len(self.band_filters)):
+            self.band_filters[b] = add_channels_to_filter(
+                self.band_filters[b])
+        self.filt_l0 = add_channels_to_filter(self.filt_l0)
+        if not self.use_bilinear_downup:
+            self.filt_l = add_channels_to_filter(self.filt_l)
+
construct_pyramid(image, n_levels, multiple_highpass=False)

Constructs and returns a steerable pyramid for the provided image.

Parameters:

  • image – The input image, in NCHW format. The number of channels C should match num_channels when the pyramid maker was created.
  • n_levels – Number of levels in the constructed steerable pyramid.
  • multiple_highpass – If true, computes a highpass for each level of the pyramid. These extra levels are redundant (not used for reconstruction).

Returns:

  • pyramid (list of dicts of torch.tensor) – The computed steerable pyramid. Each level is an entry in a list. The pyramid is ordered from largest levels to smallest levels. Each level is stored as a dict, with the following keys: "h" Highpass residual, "l" Lowpass residual, "b" Oriented bands (a list of torch.tensor).
+ Source code in odak/learn/perception/spatial_steerable_pyramid.py +
def construct_pyramid(self, image, n_levels, multiple_highpass=False):
+    """
+    Constructs and returns a steerable pyramid for the provided image.
+
+    Parameters
+    ----------
+
+    image               : torch.tensor
+                            The input image, in NCHW format. The number of channels C should match num_channels
+                            when the pyramid maker was created.
+    n_levels            : int
+                            Number of levels in the constructed steerable pyramid.
+    multiple_highpass   : bool
+                            If true, computes a highpass for each level of the pyramid.
+                            These extra levels are redundant (not used for reconstruction).
+
+    Returns
+    -------
+
+    pyramid             : list of dicts of torch.tensor
+                            The computed steerable pyramid.
+                            Each level is an entry in a list. The pyramid is ordered from largest levels to smallest levels.
+                            Each level is stored as a dict, with the following keys:
+                            "h" Highpass residual
+                            "l" Lowpass residual
+                            "b" Oriented bands (a list of torch.tensor)
+    """
+    pyramid = []
+
+    # Make level 0, containing highpass, lowpass and the bands
+    level0 = {}
+    level0['h'] = torch.nn.functional.conv2d(
+        self.pad_h0(image), self.filt_h0)
+    lowpass = torch.nn.functional.conv2d(self.pad_l0(image), self.filt_l0)
+    level0['l'] = lowpass.clone()
+    bands = []
+    for filt_b in self.band_filters:
+        bands.append(torch.nn.functional.conv2d(
+            self.pad_b(lowpass), filt_b))
+    level0['b'] = bands
+    pyramid.append(level0)
+
+    # Make intermediate levels
+    for l in range(n_levels-2):
+        level = {}
+        if self.use_bilinear_downup:
+            lowpass = torch.nn.functional.interpolate(
+                lowpass, scale_factor=0.5, mode="area", recompute_scale_factor=False)
+        else:
+            lowpass = torch.nn.functional.conv2d(
+                self.pad_l(lowpass), self.filt_l)
+            lowpass = lowpass[:, :, ::2, ::2]
+        level['l'] = lowpass.clone()
+        bands = []
+        for filt_b in self.band_filters:
+            bands.append(torch.nn.functional.conv2d(
+                self.pad_b(lowpass), filt_b))
+        level['b'] = bands
+        if multiple_highpass:
+            level['h'] = torch.nn.functional.conv2d(
+                self.pad_h0(lowpass), self.filt_h0)
+        pyramid.append(level)
+
+    # Make final level (lowpass residual)
+    level = {}
+    if self.use_bilinear_downup:
+        lowpass = torch.nn.functional.interpolate(
+            lowpass, scale_factor=0.5, mode="area", recompute_scale_factor=False)
+    else:
+        lowpass = torch.nn.functional.conv2d(
+            self.pad_l(lowpass), self.filt_l)
+        lowpass = lowpass[:, :, ::2, ::2]
+    level['l'] = lowpass
+    pyramid.append(level)
+
+    return pyramid
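
A minimal usage sketch (not part of the library docs): the class name SpatialSteerablePyramid and its default constructor below are assumptions, so adjust them to the actual pyramid-maker class exported by this module.

import torch
from odak.learn.perception.spatial_steerable_pyramid import SpatialSteerablePyramid  # assumed class name

image = torch.rand(1, 1, 256, 256)  # NCHW; C must match num_channels of the pyramid maker
pyramid_maker = SpatialSteerablePyramid()  # assumed default constructor arguments
pyramid = pyramid_maker.construct_pyramid(image, n_levels=5)
print(len(pyramid))  # 5 levels, ordered from largest to smallest
print(pyramid[0]['h'].shape, len(pyramid[0]['b']))  # level-0 highpass residual and oriented bands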

+ reconstruct_from_pyramid(pyramid)


Reconstructs an input image from a steerable pyramid.


Parameters:

  • pyramid (list of dicts of torch.tensor) – The steerable pyramid. Should be in the same format as output by construct_steerable_pyramid(). The number of channels should match num_channels when the pyramid maker was created.

Returns:

  • image (torch.tensor) – The reconstructed image, in NCHW format.
+ Source code in odak/learn/perception/spatial_steerable_pyramid.py
def reconstruct_from_pyramid(self, pyramid):
+    """
+    Reconstructs an input image from a steerable pyramid.
+
+    Parameters
+    ----------
+
+    pyramid : list of dicts of torch.tensor
+                The steerable pyramid.
+                Should be in the same format as output by construct_steerable_pyramid().
+                The number of channels should match num_channels when the pyramid maker was created.
+
+    Returns
+    -------
+
+    image   : torch.tensor
+                The reconstructed image, in NCHW format.         
+    """
+    def upsample(image, size):
+        if self.use_bilinear_downup:
+            return torch.nn.functional.interpolate(image, size=size, mode="bilinear", align_corners=False, recompute_scale_factor=False)
+        else:
+            zeros = torch.zeros((image.size()[0], image.size()[1], image.size()[
+                                2]*2, image.size()[3]*2)).to(self.device)
+            zeros[:, :, ::2, ::2] = image
+            zeros = torch.nn.functional.conv2d(
+                self.pad_l(zeros), self.filt_l)
+            return zeros
+
+    image = pyramid[-1]['l']
+    for level in reversed(pyramid[:-1]):
+        image = upsample(image, level['b'][0].size()[2:])
+        for b in range(len(level['b'])):
+            b_filtered = torch.nn.functional.conv2d(
+                self.pad_b(level['b'][b]), -self.band_filters[b])
+            image += b_filtered
+
+    image = torch.nn.functional.conv2d(self.pad_l0(image), self.filt_l0)
+    image += torch.nn.functional.conv2d(
+        self.pad_h0(pyramid[0]['h']), self.filt_h0)
+
+    return image
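
A round-trip sketch continuing the construct_pyramid example above (same assumed class name): reconstructing from the pyramid should closely recover the input.

reconstructed = pyramid_maker.reconstruct_from_pyramid(pyramid)
error = torch.mean((reconstructed - image) ** 2)
print(reconstructed.shape, float(error))  # NCHW output; the mean squared error should be small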

+ pad_image_for_pyramid(image, n_pyramid_levels)


Pads an image to the extent necessary to compute a steerable pyramid of the input image. This involves padding so both height and width are divisible by 2**n_pyramid_levels. Uses reflection padding.


Parameters:

  • image – Image to pad, in NCHW format.
  • n_pyramid_levels – Number of levels in the pyramid you plan to construct.
+ Source code in odak/learn/perception/spatial_steerable_pyramid.py
def pad_image_for_pyramid(image, n_pyramid_levels):
+    """
+    Pads an image to the extent necessary to compute a steerable pyramid of the input image.
+    This involves padding so both height and width are divisible by 2**n_pyramid_levels.
+    Uses reflection padding.
+
+    Parameters
+    ----------
+
+    image: torch.tensor
+        Image to pad, in NCHW format
+    n_pyramid_levels: int
+        Number of levels in the pyramid you plan to construct.
+    """
+    min_divisor = 2 ** n_pyramid_levels
+    height = image.size(2)
+    width = image.size(3)
+    required_height = math.ceil(height / min_divisor) * min_divisor
+    required_width = math.ceil(width / min_divisor) * min_divisor
+    if required_height > height or required_width > width:
+        # We need to pad!
+        pad = torch.nn.ReflectionPad2d(
+            (0, 0, required_height-height, required_width-width))
+        return pad(image)
+    return image
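
A small sketch of padding before building a pyramid; the tensor sizes are arbitrary examples.

import torch
from odak.learn.perception.spatial_steerable_pyramid import pad_image_for_pyramid

image = torch.rand(1, 1, 100, 80)  # height 100 is not divisible by 2 ** 4
padded = pad_image_for_pyramid(image, n_pyramid_levels=4)
print(padded.shape)  # torch.Size([1, 1, 112, 80]); 112 is the next multiple of 16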

+ crop_steerable_pyramid_filters(filters, size)


Given the original 9x9 NYU filters, this crops them to the desired size. The size must be an odd number >= 3. Note this only crops the h0, l0 and band filters (not the l downsampling filter).


Parameters:

  • filters – Filters to crop (should be in the format used by get_steerable_pyramid_filters).
  • size – Size to crop to. For example, an input of 3 will crop the filters to a size of 3x3.

Returns:

  • filters (dict of torch.tensor) – The cropped filters.
+ Source code in odak/learn/perception/steerable_pyramid_filters.py
def crop_steerable_pyramid_filters(filters, size):
+    """
+    Given original 9x9 NYU filters, this crops them to the desired size.
+    The size must be an odd number >= 3
+    Note this only crops the h0, l0 and band filters (not the l downsampling filter)
+
+    Parameters
+    ----------
+    filters     : dict of torch.tensor
+                    Filters to crop (should be in the format used by get_steerable_pyramid_filters).
+    size        : int
+                    Size to crop to. For example, an input of 3 will crop the filters to a size of 3x3.
+
+    Returns
+    -------
+    filters     : dict of torch.tensor
+                    The cropped filters.
+    """
+    assert(size >= 3)
+    assert(size % 2 == 1)
+    r = (size-1) // 2
+
+    def crop_filter(filter, r, normalise=True):
+        r2 = (filter.size(-1)-1)//2
+        filter = filter[:, :, r2-r:r2+r+1, r2-r:r2+r+1]
+        if normalise:
+            filter -= torch.sum(filter)
+        return filter
+
+    filters["h0"] = crop_filter(filters["h0"], r, normalise=False)
+    sum_l = torch.sum(filters["l"])
+    filters["l"] = crop_filter(filters["l"], 6, normalise=False)
+    filters["l"] *= sum_l / torch.sum(filters["l"])
+    sum_l0 = torch.sum(filters["l0"])
+    filters["l0"] = crop_filter(filters["l0"], 2, normalise=False)
+    filters["l0"] *= sum_l0 / torch.sum(filters["l0"])
+    for b in range(len(filters["b"])):
+        filters["b"][b] = crop_filter(filters["b"][b], r, normalise=True)
+    return filters
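
For illustration, a short cropping sketch (a sketch only; both functions are defined in the source file referenced above):

from odak.learn.perception.steerable_pyramid_filters import (
    crop_steerable_pyramid_filters, get_steerable_pyramid_filters)

filters = get_steerable_pyramid_filters(size=9, n_orientations=2, filter_type='full')
cropped = crop_steerable_pyramid_filters(filters, 5)
print(cropped['h0'].shape, cropped['b'][0].shape)  # both cropped to 1 x 1 x 5 x 5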

+ get_steerable_pyramid_filters(size, n_orientations, filter_type)


This returns filters for a real-valued steerable pyramid.


Parameters:

  • size – Width of the filters (e.g. 3 will return 3x3 filters)
  • n_orientations – Number of oriented band filters
  • filter_type – This can be used to select between the original NYU filters and cropped or trained alternatives.
    full: Original NYU filters from https://github.com/LabForComputationalVision/pyrtools/blob/master/pyrtools/pyramids/filters.py
    cropped: Some filters are cut back in size by extracting the centre and scaling as appropriate.
    trained: Same as cropped, but the oriented kernels are replaced by learned 5x5 kernels.

Returns:

  • filters (dict of torch.tensor) – The steerable pyramid filters. Returned as a dict with the following keys:
    "l" The lowpass downsampling filter
    "l0" The lowpass residual filter
    "h0" The highpass residual filter
    "b" The band filters (a list of torch.tensor filters, one for each orientation).
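
For illustration, a brief sketch of fetching a filter bank and inspecting the returned dict (the shapes in the comments follow the n_orientations == 2 branch of the source below):

from odak.learn.perception.steerable_pyramid_filters import get_steerable_pyramid_filters

filters = get_steerable_pyramid_filters(size=9, n_orientations=2, filter_type='full')
print(filters['l'].shape)   # lowpass downsampling filter, 1 x 1 x 17 x 17
print(filters['l0'].shape)  # lowpass residual filter, 1 x 1 x 9 x 9
print(filters['h0'].shape)  # highpass residual filter, 1 x 1 x 9 x 9
print(len(filters['b']))    # 2 oriented band filters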
+ Source code in odak/learn/perception/steerable_pyramid_filters.py
def get_steerable_pyramid_filters(size, n_orientations, filter_type):
+    """
+    This returns filters for a real-valued steerable pyramid.
+
+    Parameters
+    ----------
+
+    size            : int
+                        Width of the filters (e.g. 3 will return 3x3 filters)
+    n_orientations  : int
+                        Number of oriented band filters
+    filter_type     :  str
+                        This can be used to select between the original NYU filters and cropped or trained alternatives.
+                        full: Original NYU filters from https://github.com/LabForComputationalVision/pyrtools/blob/master/pyrtools/pyramids/filters.py
+                        cropped: Some filters are cut back in size by extracting the centre and scaling as appropriate.
+                        trained: Same as cropped, but the oriented kernels are replaced by learned 5x5 kernels.
+
+    Returns
+    -------
+    filters         : dict of torch.tensor
+                        The steerable pyramid filters. Returned as a dict with the following keys:
+                        "l" The lowpass downsampling filter
+                        "l0" The lowpass residual filter
+                        "h0" The highpass residual filter
+                        "b" The band filters (a list of torch.tensor filters, one for each orientation).
+    """
+
+    if filter_type != "full" and filter_type != "cropped" and filter_type != "trained":
+        raise Exception(
+            "Unknown filter type %s! Valid filter types are full, cropped and trained." % filter_type)
+
+    filters = {}
+    if n_orientations == 1:
+        filters["l"] = torch.tensor([
+            [-2.257000e-04, -8.064400e-04, -5.686000e-05, 8.741400e-04, -1.862800e-04, -1.031640e-03, -
+                1.871920e-03, -1.031640e-03, -1.862800e-04, 8.741400e-04, -5.686000e-05, -8.064400e-04, -2.257000e-04],
+            [-8.064400e-04, 1.417620e-03, -1.903800e-04, -2.449060e-03, -4.596420e-03, -7.006740e-03, -
+                6.948900e-03, -7.006740e-03, -4.596420e-03, -2.449060e-03, -1.903800e-04, 1.417620e-03, -8.064400e-04],
+            [-5.686000e-05, -1.903800e-04, -3.059760e-03, -6.401000e-03, -6.720800e-03, -5.236180e-03, -
+                3.781600e-03, -5.236180e-03, -6.720800e-03, -6.401000e-03, -3.059760e-03, -1.903800e-04, -5.686000e-05],
+            [8.741400e-04, -2.449060e-03, -6.401000e-03, -5.260020e-03, 3.938620e-03, 1.722078e-02, 2.449600e-02,
+                1.722078e-02, 3.938620e-03, -5.260020e-03, -6.401000e-03, -2.449060e-03, 8.741400e-04],
+            [-1.862800e-04, -4.596420e-03, -6.720800e-03, 3.938620e-03, 3.220744e-02, 6.306262e-02, 7.624674e-02,
+                6.306262e-02, 3.220744e-02, 3.938620e-03, -6.720800e-03, -4.596420e-03, -1.862800e-04],
+            [-1.031640e-03, -7.006740e-03, -5.236180e-03, 1.722078e-02, 6.306262e-02, 1.116388e-01, 1.348999e-01,
+                1.116388e-01, 6.306262e-02, 1.722078e-02, -5.236180e-03, -7.006740e-03, -1.031640e-03],
+            [-1.871920e-03, -6.948900e-03, -3.781600e-03, 2.449600e-02, 7.624674e-02, 1.348999e-01, 1.576508e-01,
+                1.348999e-01, 7.624674e-02, 2.449600e-02, -3.781600e-03, -6.948900e-03, -1.871920e-03],
+            [-1.031640e-03, -7.006740e-03, -5.236180e-03, 1.722078e-02, 6.306262e-02, 1.116388e-01, 1.348999e-01,
+                1.116388e-01, 6.306262e-02, 1.722078e-02, -5.236180e-03, -7.006740e-03, -1.031640e-03],
+            [-1.862800e-04, -4.596420e-03, -6.720800e-03, 3.938620e-03, 3.220744e-02, 6.306262e-02, 7.624674e-02,
+                6.306262e-02, 3.220744e-02, 3.938620e-03, -6.720800e-03, -4.596420e-03, -1.862800e-04],
+            [8.741400e-04, -2.449060e-03, -6.401000e-03, -5.260020e-03, 3.938620e-03, 1.722078e-02, 2.449600e-02,
+                1.722078e-02, 3.938620e-03, -5.260020e-03, -6.401000e-03, -2.449060e-03, 8.741400e-04],
+            [-5.686000e-05, -1.903800e-04, -3.059760e-03, -6.401000e-03, -6.720800e-03, -5.236180e-03, -
+                3.781600e-03, -5.236180e-03, -6.720800e-03, -6.401000e-03, -3.059760e-03, -1.903800e-04, -5.686000e-05],
+            [-8.064400e-04, 1.417620e-03, -1.903800e-04, -2.449060e-03, -4.596420e-03, -7.006740e-03, -
+                6.948900e-03, -7.006740e-03, -4.596420e-03, -2.449060e-03, -1.903800e-04, 1.417620e-03, -8.064400e-04],
+            [-2.257000e-04, -8.064400e-04, -5.686000e-05, 8.741400e-04, -1.862800e-04, -1.031640e-03, -1.871920e-03, -1.031640e-03, -1.862800e-04, 8.741400e-04, -5.686000e-05, -8.064400e-04, -2.257000e-04]]
+        ).reshape(1, 1, 13, 13)
+        filters["l0"] = torch.tensor([
+            [-4.514000e-04, -1.137100e-04, -3.725800e-04, -
+                3.743860e-03, -3.725800e-04, -1.137100e-04, -4.514000e-04],
+            [-1.137100e-04, -6.119520e-03, -1.344160e-02, -
+                7.563200e-03, -1.344160e-02, -6.119520e-03, -1.137100e-04],
+            [-3.725800e-04, -1.344160e-02, 6.441488e-02, 1.524935e-01,
+                6.441488e-02, -1.344160e-02, -3.725800e-04],
+            [-3.743860e-03, -7.563200e-03, 1.524935e-01, 3.153017e-01,
+                1.524935e-01, -7.563200e-03, -3.743860e-03],
+            [-3.725800e-04, -1.344160e-02, 6.441488e-02, 1.524935e-01,
+                6.441488e-02, -1.344160e-02, -3.725800e-04],
+            [-1.137100e-04, -6.119520e-03, -1.344160e-02, -
+                7.563200e-03, -1.344160e-02, -6.119520e-03, -1.137100e-04],
+            [-4.514000e-04, -1.137100e-04, -3.725800e-04, -3.743860e-03, -3.725800e-04, -1.137100e-04, -4.514000e-04]]
+        ).reshape(1, 1, 7, 7)
+        filters["h0"] = torch.tensor([
+            [5.997200e-04, -6.068000e-05, -3.324900e-04, -3.325600e-04, -
+                2.406600e-04, -3.325600e-04, -3.324900e-04, -6.068000e-05, 5.997200e-04],
+            [-6.068000e-05, 1.263100e-04, 4.927100e-04, 1.459700e-04, -
+                3.732100e-04, 1.459700e-04, 4.927100e-04, 1.263100e-04, -6.068000e-05],
+            [-3.324900e-04, 4.927100e-04, -1.616650e-03, -1.437358e-02, -
+                2.420138e-02, -1.437358e-02, -1.616650e-03, 4.927100e-04, -3.324900e-04],
+            [-3.325600e-04, 1.459700e-04, -1.437358e-02, -6.300923e-02, -
+                9.623594e-02, -6.300923e-02, -1.437358e-02, 1.459700e-04, -3.325600e-04],
+            [-2.406600e-04, -3.732100e-04, -2.420138e-02, -9.623594e-02,
+                8.554893e-01, -9.623594e-02, -2.420138e-02, -3.732100e-04, -2.406600e-04],
+            [-3.325600e-04, 1.459700e-04, -1.437358e-02, -6.300923e-02, -
+                9.623594e-02, -6.300923e-02, -1.437358e-02, 1.459700e-04, -3.325600e-04],
+            [-3.324900e-04, 4.927100e-04, -1.616650e-03, -1.437358e-02, -
+                2.420138e-02, -1.437358e-02, -1.616650e-03, 4.927100e-04, -3.324900e-04],
+            [-6.068000e-05, 1.263100e-04, 4.927100e-04, 1.459700e-04, -
+                3.732100e-04, 1.459700e-04, 4.927100e-04, 1.263100e-04, -6.068000e-05],
+            [5.997200e-04, -6.068000e-05, -3.324900e-04, -3.325600e-04, -2.406600e-04, -3.325600e-04, -3.324900e-04, -6.068000e-05, 5.997200e-04]]
+        ).reshape(1, 1, 9, 9)
+        filters["b"] = []
+        filters["b"].append(torch.tensor([
+            -9.066000e-05, -1.738640e-03, -4.942500e-03, -7.889390e-03, -
+            1.009473e-02, -7.889390e-03, -4.942500e-03, -1.738640e-03, -9.066000e-05,
+            -1.738640e-03, -4.625150e-03, -7.272540e-03, -7.623410e-03, -
+            9.091950e-03, -7.623410e-03, -7.272540e-03, -4.625150e-03, -1.738640e-03,
+            -4.942500e-03, -7.272540e-03, -2.129540e-02, -2.435662e-02, -
+            3.487008e-02, -2.435662e-02, -2.129540e-02, -7.272540e-03, -4.942500e-03,
+            -7.889390e-03, -7.623410e-03, -2.435662e-02, -1.730466e-02, -
+            3.158605e-02, -1.730466e-02, -2.435662e-02, -7.623410e-03, -7.889390e-03,
+            -1.009473e-02, -9.091950e-03, -3.487008e-02, -3.158605e-02, 9.464195e-01, -
+            3.158605e-02, -3.487008e-02, -9.091950e-03, -1.009473e-02,
+            -7.889390e-03, -7.623410e-03, -2.435662e-02, -1.730466e-02, -
+            3.158605e-02, -1.730466e-02, -2.435662e-02, -7.623410e-03, -7.889390e-03,
+            -4.942500e-03, -7.272540e-03, -2.129540e-02, -2.435662e-02, -
+            3.487008e-02, -2.435662e-02, -2.129540e-02, -7.272540e-03, -4.942500e-03,
+            -1.738640e-03, -4.625150e-03, -7.272540e-03, -7.623410e-03, -
+            9.091950e-03, -7.623410e-03, -7.272540e-03, -4.625150e-03, -1.738640e-03,
+            -9.066000e-05, -1.738640e-03, -4.942500e-03, -7.889390e-03, -1.009473e-02, -7.889390e-03, -4.942500e-03, -1.738640e-03, -9.066000e-05]
+        ).reshape(1, 1, 9, 9).permute(0, 1, 3, 2))
+
+    elif n_orientations == 2:
+        filters["l"] = torch.tensor(
+            [[-4.350000e-05, 1.207800e-04, -6.771400e-04, -1.243400e-04, -8.006400e-04, -1.597040e-03, -2.516800e-04, -4.202000e-04, 1.262000e-03, -4.202000e-04, -2.516800e-04, -1.597040e-03, -8.006400e-04, -1.243400e-04, -6.771400e-04, 1.207800e-04, -4.350000e-05],
+             [1.207800e-04, 4.460600e-04, -5.814600e-04, 5.621600e-04, -1.368800e-04, 2.325540e-03, 2.889860e-03, 4.287280e-03, 5.589400e-03,
+                 4.287280e-03, 2.889860e-03, 2.325540e-03, -1.368800e-04, 5.621600e-04, -5.814600e-04, 4.460600e-04, 1.207800e-04],
+             [-6.771400e-04, -5.814600e-04, 1.460780e-03, 2.160540e-03, 3.761360e-03, 3.080980e-03, 4.112200e-03, 2.221220e-03, 5.538200e-04,
+                 2.221220e-03, 4.112200e-03, 3.080980e-03, 3.761360e-03, 2.160540e-03, 1.460780e-03, -5.814600e-04, -6.771400e-04],
+             [-1.243400e-04, 5.621600e-04, 2.160540e-03, 3.175780e-03, 3.184680e-03, -1.777480e-03, -7.431700e-03, -9.056920e-03, -
+                 9.637220e-03, -9.056920e-03, -7.431700e-03, -1.777480e-03, 3.184680e-03, 3.175780e-03, 2.160540e-03, 5.621600e-04, -1.243400e-04],
+             [-8.006400e-04, -1.368800e-04, 3.761360e-03, 3.184680e-03, -3.530640e-03, -1.260420e-02, -1.884744e-02, -1.750818e-02, -
+                 1.648568e-02, -1.750818e-02, -1.884744e-02, -1.260420e-02, -3.530640e-03, 3.184680e-03, 3.761360e-03, -1.368800e-04, -8.006400e-04],
+             [-1.597040e-03, 2.325540e-03, 3.080980e-03, -1.777480e-03, -1.260420e-02, -2.022938e-02, -1.109170e-02, 3.955660e-03, 1.438512e-02,
+                 3.955660e-03, -1.109170e-02, -2.022938e-02, -1.260420e-02, -1.777480e-03, 3.080980e-03, 2.325540e-03, -1.597040e-03],
+             [-2.516800e-04, 2.889860e-03, 4.112200e-03, -7.431700e-03, -1.884744e-02, -1.109170e-02, 2.190660e-02, 6.806584e-02, 9.058014e-02,
+                 6.806584e-02, 2.190660e-02, -1.109170e-02, -1.884744e-02, -7.431700e-03, 4.112200e-03, 2.889860e-03, -2.516800e-04],
+             [-4.202000e-04, 4.287280e-03, 2.221220e-03, -9.056920e-03, -1.750818e-02, 3.955660e-03, 6.806584e-02, 1.445500e-01, 1.773651e-01,
+                 1.445500e-01, 6.806584e-02, 3.955660e-03, -1.750818e-02, -9.056920e-03, 2.221220e-03, 4.287280e-03, -4.202000e-04],
+             [1.262000e-03, 5.589400e-03, 5.538200e-04, -9.637220e-03, -1.648568e-02, 1.438512e-02, 9.058014e-02, 1.773651e-01, 2.120374e-01,
+                 1.773651e-01, 9.058014e-02, 1.438512e-02, -1.648568e-02, -9.637220e-03, 5.538200e-04, 5.589400e-03, 1.262000e-03],
+             [-4.202000e-04, 4.287280e-03, 2.221220e-03, -9.056920e-03, -1.750818e-02, 3.955660e-03, 6.806584e-02, 1.445500e-01, 1.773651e-01,
+                 1.445500e-01, 6.806584e-02, 3.955660e-03, -1.750818e-02, -9.056920e-03, 2.221220e-03, 4.287280e-03, -4.202000e-04],
+             [-2.516800e-04, 2.889860e-03, 4.112200e-03, -7.431700e-03, -1.884744e-02, -1.109170e-02, 2.190660e-02, 6.806584e-02, 9.058014e-02,
+                 6.806584e-02, 2.190660e-02, -1.109170e-02, -1.884744e-02, -7.431700e-03, 4.112200e-03, 2.889860e-03, -2.516800e-04],
+             [-1.597040e-03, 2.325540e-03, 3.080980e-03, -1.777480e-03, -1.260420e-02, -2.022938e-02, -1.109170e-02, 3.955660e-03, 1.438512e-02,
+                 3.955660e-03, -1.109170e-02, -2.022938e-02, -1.260420e-02, -1.777480e-03, 3.080980e-03, 2.325540e-03, -1.597040e-03],
+             [-8.006400e-04, -1.368800e-04, 3.761360e-03, 3.184680e-03, -3.530640e-03, -1.260420e-02, -1.884744e-02, -1.750818e-02, -
+                 1.648568e-02, -1.750818e-02, -1.884744e-02, -1.260420e-02, -3.530640e-03, 3.184680e-03, 3.761360e-03, -1.368800e-04, -8.006400e-04],
+             [-1.243400e-04, 5.621600e-04, 2.160540e-03, 3.175780e-03, 3.184680e-03, -1.777480e-03, -7.431700e-03, -9.056920e-03, -
+                 9.637220e-03, -9.056920e-03, -7.431700e-03, -1.777480e-03, 3.184680e-03, 3.175780e-03, 2.160540e-03, 5.621600e-04, -1.243400e-04],
+             [-6.771400e-04, -5.814600e-04, 1.460780e-03, 2.160540e-03, 3.761360e-03, 3.080980e-03, 4.112200e-03, 2.221220e-03, 5.538200e-04,
+                 2.221220e-03, 4.112200e-03, 3.080980e-03, 3.761360e-03, 2.160540e-03, 1.460780e-03, -5.814600e-04, -6.771400e-04],
+             [1.207800e-04, 4.460600e-04, -5.814600e-04, 5.621600e-04, -1.368800e-04, 2.325540e-03, 2.889860e-03, 4.287280e-03, 5.589400e-03,
+                 4.287280e-03, 2.889860e-03, 2.325540e-03, -1.368800e-04, 5.621600e-04, -5.814600e-04, 4.460600e-04, 1.207800e-04],
+             [-4.350000e-05, 1.207800e-04, -6.771400e-04, -1.243400e-04, -8.006400e-04, -1.597040e-03, -2.516800e-04, -4.202000e-04, 1.262000e-03, -4.202000e-04, -2.516800e-04, -1.597040e-03, -8.006400e-04, -1.243400e-04, -6.771400e-04, 1.207800e-04, -4.350000e-05]]
+        ).reshape(1, 1, 17, 17)
+        filters["l0"] = torch.tensor(
+            [[-8.701000e-05, -1.354280e-03, -1.601260e-03, -5.033700e-04, 2.524010e-03, -5.033700e-04, -1.601260e-03, -1.354280e-03, -8.701000e-05],
+             [-1.354280e-03, 2.921580e-03, 7.522720e-03, 8.224420e-03, 1.107620e-03,
+                 8.224420e-03, 7.522720e-03, 2.921580e-03, -1.354280e-03],
+             [-1.601260e-03, 7.522720e-03, -7.061290e-03, -3.769487e-02, -
+                 3.297137e-02, -3.769487e-02, -7.061290e-03, 7.522720e-03, -1.601260e-03],
+             [-5.033700e-04, 8.224420e-03, -3.769487e-02, 4.381320e-02, 1.811603e-01,
+                 4.381320e-02, -3.769487e-02, 8.224420e-03, -5.033700e-04],
+             [2.524010e-03, 1.107620e-03, -3.297137e-02, 1.811603e-01, 4.376250e-01,
+                 1.811603e-01, -3.297137e-02, 1.107620e-03, 2.524010e-03],
+             [-5.033700e-04, 8.224420e-03, -3.769487e-02, 4.381320e-02, 1.811603e-01,
+                 4.381320e-02, -3.769487e-02, 8.224420e-03, -5.033700e-04],
+             [-1.601260e-03, 7.522720e-03, -7.061290e-03, -3.769487e-02, -
+                 3.297137e-02, -3.769487e-02, -7.061290e-03, 7.522720e-03, -1.601260e-03],
+             [-1.354280e-03, 2.921580e-03, 7.522720e-03, 8.224420e-03, 1.107620e-03,
+                 8.224420e-03, 7.522720e-03, 2.921580e-03, -1.354280e-03],
+             [-8.701000e-05, -1.354280e-03, -1.601260e-03, -5.033700e-04, 2.524010e-03, -5.033700e-04, -1.601260e-03, -1.354280e-03, -8.701000e-05]]
+        ).reshape(1, 1, 9, 9)
+        filters["h0"] = torch.tensor(
+            [[-9.570000e-04, -2.424100e-04, -1.424720e-03, -8.742600e-04, -1.166810e-03, -8.742600e-04, -1.424720e-03, -2.424100e-04, -9.570000e-04],
+             [-2.424100e-04, -4.317530e-03, 8.998600e-04, 9.156420e-03, 1.098012e-02,
+                 9.156420e-03, 8.998600e-04, -4.317530e-03, -2.424100e-04],
+             [-1.424720e-03, 8.998600e-04, 1.706347e-02, 1.094866e-02, -
+                 5.897780e-03, 1.094866e-02, 1.706347e-02, 8.998600e-04, -1.424720e-03],
+             [-8.742600e-04, 9.156420e-03, 1.094866e-02, -7.841370e-02, -
+                 1.562827e-01, -7.841370e-02, 1.094866e-02, 9.156420e-03, -8.742600e-04],
+             [-1.166810e-03, 1.098012e-02, -5.897780e-03, -1.562827e-01,
+                 7.282593e-01, -1.562827e-01, -5.897780e-03, 1.098012e-02, -1.166810e-03],
+             [-8.742600e-04, 9.156420e-03, 1.094866e-02, -7.841370e-02, -
+                 1.562827e-01, -7.841370e-02, 1.094866e-02, 9.156420e-03, -8.742600e-04],
+             [-1.424720e-03, 8.998600e-04, 1.706347e-02, 1.094866e-02, -
+                 5.897780e-03, 1.094866e-02, 1.706347e-02, 8.998600e-04, -1.424720e-03],
+             [-2.424100e-04, -4.317530e-03, 8.998600e-04, 9.156420e-03, 1.098012e-02,
+                 9.156420e-03, 8.998600e-04, -4.317530e-03, -2.424100e-04],
+             [-9.570000e-04, -2.424100e-04, -1.424720e-03, -8.742600e-04, -1.166810e-03, -8.742600e-04, -1.424720e-03, -2.424100e-04, -9.570000e-04]]
+        ).reshape(1, 1, 9, 9)
+        filters["b"] = []
+        filters["b"].append(torch.tensor(
+            [6.125880e-03, -8.052600e-03, -2.103714e-02, -1.536890e-02, -1.851466e-02, -1.536890e-02, -2.103714e-02, -8.052600e-03, 6.125880e-03,
+             -1.287416e-02, -9.611520e-03, 1.023569e-02, 6.009450e-03, 1.872620e-03, 6.009450e-03, 1.023569e-02, -
+             9.611520e-03, -1.287416e-02,
+             -5.641530e-03, 4.168400e-03, -2.382180e-02, -5.375324e-02, -
+             2.076086e-02, -5.375324e-02, -2.382180e-02, 4.168400e-03, -5.641530e-03,
+             -8.957260e-03, -1.751170e-03, -1.836909e-02, 1.265655e-01, 2.996168e-01, 1.265655e-01, -
+             1.836909e-02, -1.751170e-03, -8.957260e-03,
+             0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,
+             8.957260e-03, 1.751170e-03, 1.836909e-02, -1.265655e-01, -
+             2.996168e-01, -1.265655e-01, 1.836909e-02, 1.751170e-03, 8.957260e-03,
+             5.641530e-03, -4.168400e-03, 2.382180e-02, 5.375324e-02, 2.076086e-02, 5.375324e-02, 2.382180e-02, -
+             4.168400e-03, 5.641530e-03,
+             1.287416e-02, 9.611520e-03, -1.023569e-02, -6.009450e-03, -
+             1.872620e-03, -6.009450e-03, -1.023569e-02, 9.611520e-03, 1.287416e-02,
+             -6.125880e-03, 8.052600e-03, 2.103714e-02, 1.536890e-02, 1.851466e-02, 1.536890e-02, 2.103714e-02, 8.052600e-03, -6.125880e-03]).reshape(1, 1, 9, 9).permute(0, 1, 3, 2))
+        filters["b"].append(torch.tensor(
+            [-6.125880e-03, 1.287416e-02, 5.641530e-03, 8.957260e-03, 0.000000e+00, -8.957260e-03, -5.641530e-03, -1.287416e-02, 6.125880e-03,
+             8.052600e-03, 9.611520e-03, -4.168400e-03, 1.751170e-03, 0.000000e+00, -
+             1.751170e-03, 4.168400e-03, -9.611520e-03, -8.052600e-03,
+             2.103714e-02, -1.023569e-02, 2.382180e-02, 1.836909e-02, 0.000000e+00, -
+             1.836909e-02, -2.382180e-02, 1.023569e-02, -2.103714e-02,
+             1.536890e-02, -6.009450e-03, 5.375324e-02, -
+             1.265655e-01, 0.000000e+00, 1.265655e-01, -
+             5.375324e-02, 6.009450e-03, -1.536890e-02,
+             1.851466e-02, -1.872620e-03, 2.076086e-02, -
+             2.996168e-01, 0.000000e+00, 2.996168e-01, -
+             2.076086e-02, 1.872620e-03, -1.851466e-02,
+             1.536890e-02, -6.009450e-03, 5.375324e-02, -
+             1.265655e-01, 0.000000e+00, 1.265655e-01, -
+             5.375324e-02, 6.009450e-03, -1.536890e-02,
+             2.103714e-02, -1.023569e-02, 2.382180e-02, 1.836909e-02, 0.000000e+00, -
+             1.836909e-02, -2.382180e-02, 1.023569e-02, -2.103714e-02,
+             8.052600e-03, 9.611520e-03, -4.168400e-03, 1.751170e-03, 0.000000e+00, -
+             1.751170e-03, 4.168400e-03, -9.611520e-03, -8.052600e-03,
+             -6.125880e-03, 1.287416e-02, 5.641530e-03, 8.957260e-03, 0.000000e+00, -8.957260e-03, -5.641530e-03, -1.287416e-02, 6.125880e-03]).reshape(1, 1, 9, 9).permute(0, 1, 3, 2))
+
+    elif n_orientations == 4:
+        filters["l"] = torch.tensor([
+            [-4.3500000174E-5, 1.2078000145E-4, -6.7714002216E-4, -1.2434000382E-4, -8.0063997302E-4, -1.5970399836E-3, -2.5168000138E-4, -4.2019999819E-4,
+                1.2619999470E-3, -4.2019999819E-4, -2.5168000138E-4, -1.5970399836E-3, -8.0063997302E-4, -1.2434000382E-4, -6.7714002216E-4, 1.2078000145E-4, -4.3500000174E-5],
+            [1.2078000145E-4, 4.4606000301E-4, -5.8146001538E-4, 5.6215998484E-4, -1.3688000035E-4, 2.3255399428E-3, 2.8898599558E-3, 4.2872801423E-3, 5.5893999524E-3,
+                4.2872801423E-3, 2.8898599558E-3, 2.3255399428E-3, -1.3688000035E-4, 5.6215998484E-4, -5.8146001538E-4, 4.4606000301E-4, 1.2078000145E-4],
+            [-6.7714002216E-4, -5.8146001538E-4, 1.4607800404E-3, 2.1605400834E-3, 3.7613599561E-3, 3.0809799209E-3, 4.1121998802E-3, 2.2212199401E-3, 5.5381999118E-4,
+                2.2212199401E-3, 4.1121998802E-3, 3.0809799209E-3, 3.7613599561E-3, 2.1605400834E-3, 1.4607800404E-3, -5.8146001538E-4, -6.7714002216E-4],
+            [-1.2434000382E-4, 5.6215998484E-4, 2.1605400834E-3, 3.1757799443E-3, 3.1846798956E-3, -1.7774800071E-3, -7.4316998944E-3, -9.0569201857E-3, -
+                9.6372198313E-3, -9.0569201857E-3, -7.4316998944E-3, -1.7774800071E-3, 3.1846798956E-3, 3.1757799443E-3, 2.1605400834E-3, 5.6215998484E-4, -1.2434000382E-4],
+            [-8.0063997302E-4, -1.3688000035E-4, 3.7613599561E-3, 3.1846798956E-3, -3.5306399222E-3, -1.2604200281E-2, -1.8847439438E-2, -1.7508180812E-2, -
+                1.6485679895E-2, -1.7508180812E-2, -1.8847439438E-2, -1.2604200281E-2, -3.5306399222E-3, 3.1846798956E-3, 3.7613599561E-3, -1.3688000035E-4, -8.0063997302E-4],
+            [-1.5970399836E-3, 2.3255399428E-3, 3.0809799209E-3, -1.7774800071E-3, -1.2604200281E-2, -2.0229380578E-2, -1.1091699824E-2, 3.9556599222E-3, 1.4385120012E-2,
+                3.9556599222E-3, -1.1091699824E-2, -2.0229380578E-2, -1.2604200281E-2, -1.7774800071E-3, 3.0809799209E-3, 2.3255399428E-3, -1.5970399836E-3],
+            [-2.5168000138E-4, 2.8898599558E-3, 4.1121998802E-3, -7.4316998944E-3, -1.8847439438E-2, -1.1091699824E-2, 2.1906599402E-2, 6.8065837026E-2, 9.0580143034E-2,
+                6.8065837026E-2, 2.1906599402E-2, -1.1091699824E-2, -1.8847439438E-2, -7.4316998944E-3, 4.1121998802E-3, 2.8898599558E-3, -2.5168000138E-4],
+            [-4.2019999819E-4, 4.2872801423E-3, 2.2212199401E-3, -9.0569201857E-3, -1.7508180812E-2, 3.9556599222E-3, 6.8065837026E-2, 0.1445499808, 0.1773651242,
+                0.1445499808, 6.8065837026E-2, 3.9556599222E-3, -1.7508180812E-2, -9.0569201857E-3, 2.2212199401E-3, 4.2872801423E-3, -4.2019999819E-4],
+            [1.2619999470E-3, 5.5893999524E-3, 5.5381999118E-4, -9.6372198313E-3, -1.6485679895E-2, 1.4385120012E-2, 9.0580143034E-2, 0.1773651242, 0.2120374441,
+                0.1773651242, 9.0580143034E-2, 1.4385120012E-2, -1.6485679895E-2, -9.6372198313E-3, 5.5381999118E-4, 5.5893999524E-3, 1.2619999470E-3],
+            [-4.2019999819E-4, 4.2872801423E-3, 2.2212199401E-3, -9.0569201857E-3, -1.7508180812E-2, 3.9556599222E-3, 6.8065837026E-2, 0.1445499808, 0.1773651242,
+                0.1445499808, 6.8065837026E-2, 3.9556599222E-3, -1.7508180812E-2, -9.0569201857E-3, 2.2212199401E-3, 4.2872801423E-3, -4.2019999819E-4],
+            [-2.5168000138E-4, 2.8898599558E-3, 4.1121998802E-3, -7.4316998944E-3, -1.8847439438E-2, -1.1091699824E-2, 2.1906599402E-2, 6.8065837026E-2, 9.0580143034E-2,
+                6.8065837026E-2, 2.1906599402E-2, -1.1091699824E-2, -1.8847439438E-2, -7.4316998944E-3, 4.1121998802E-3, 2.8898599558E-3, -2.5168000138E-4],
+            [-1.5970399836E-3, 2.3255399428E-3, 3.0809799209E-3, -1.7774800071E-3, -1.2604200281E-2, -2.0229380578E-2, -1.1091699824E-2, 3.9556599222E-3, 1.4385120012E-2,
+                3.9556599222E-3, -1.1091699824E-2, -2.0229380578E-2, -1.2604200281E-2, -1.7774800071E-3, 3.0809799209E-3, 2.3255399428E-3, -1.5970399836E-3],
+            [-8.0063997302E-4, -1.3688000035E-4, 3.7613599561E-3, 3.1846798956E-3, -3.5306399222E-3, -1.2604200281E-2, -1.8847439438E-2, -1.7508180812E-2, -
+                1.6485679895E-2, -1.7508180812E-2, -1.8847439438E-2, -1.2604200281E-2, -3.5306399222E-3, 3.1846798956E-3, 3.7613599561E-3, -1.3688000035E-4, -8.0063997302E-4],
+            [-1.2434000382E-4, 5.6215998484E-4, 2.1605400834E-3, 3.1757799443E-3, 3.1846798956E-3, -1.7774800071E-3, -7.4316998944E-3, -9.0569201857E-3, -
+                9.6372198313E-3, -9.0569201857E-3, -7.4316998944E-3, -1.7774800071E-3, 3.1846798956E-3, 3.1757799443E-3, 2.1605400834E-3, 5.6215998484E-4, -1.2434000382E-4],
+            [-6.7714002216E-4, -5.8146001538E-4, 1.4607800404E-3, 2.1605400834E-3, 3.7613599561E-3, 3.0809799209E-3, 4.1121998802E-3, 2.2212199401E-3, 5.5381999118E-4,
+                2.2212199401E-3, 4.1121998802E-3, 3.0809799209E-3, 3.7613599561E-3, 2.1605400834E-3, 1.4607800404E-3, -5.8146001538E-4, -6.7714002216E-4],
+            [1.2078000145E-4, 4.4606000301E-4, -5.8146001538E-4, 5.6215998484E-4, -1.3688000035E-4, 2.3255399428E-3, 2.8898599558E-3, 4.2872801423E-3, 5.5893999524E-3,
+                4.2872801423E-3, 2.8898599558E-3, 2.3255399428E-3, -1.3688000035E-4, 5.6215998484E-4, -5.8146001538E-4, 4.4606000301E-4, 1.2078000145E-4],
+            [-4.3500000174E-5, 1.2078000145E-4, -6.7714002216E-4, -1.2434000382E-4, -8.0063997302E-4, -1.5970399836E-3, -2.5168000138E-4, -4.2019999819E-4, 1.2619999470E-3, -4.2019999819E-4, -2.5168000138E-4, -1.5970399836E-3, -8.0063997302E-4, -1.2434000382E-4, -6.7714002216E-4, 1.2078000145E-4, -4.3500000174E-5]]
+        ).reshape(1, 1, 17, 17)
+        filters["l0"] = torch.tensor([
+            [-8.7009997515E-5, -1.3542800443E-3, -1.6012600390E-3, -5.0337001448E-4,
+                2.5240099058E-3, -5.0337001448E-4, -1.6012600390E-3, -1.3542800443E-3, -8.7009997515E-5],
+            [-1.3542800443E-3, 2.9215801042E-3, 7.5227199122E-3, 8.2244202495E-3, 1.1076199589E-3,
+                8.2244202495E-3, 7.5227199122E-3, 2.9215801042E-3, -1.3542800443E-3],
+            [-1.6012600390E-3, 7.5227199122E-3, -7.0612900890E-3, -3.7694871426E-2, -
+                3.2971370965E-2, -3.7694871426E-2, -7.0612900890E-3, 7.5227199122E-3, -1.6012600390E-3],
+            [-5.0337001448E-4, 8.2244202495E-3, -3.7694871426E-2, 4.3813198805E-2, 0.1811603010,
+                4.3813198805E-2, -3.7694871426E-2, 8.2244202495E-3, -5.0337001448E-4],
+            [2.5240099058E-3, 1.1076199589E-3, -3.2971370965E-2, 0.1811603010, 0.4376249909,
+                0.1811603010, -3.2971370965E-2, 1.1076199589E-3, 2.5240099058E-3],
+            [-5.0337001448E-4, 8.2244202495E-3, -3.7694871426E-2, 4.3813198805E-2, 0.1811603010,
+                4.3813198805E-2, -3.7694871426E-2, 8.2244202495E-3, -5.0337001448E-4],
+            [-1.6012600390E-3, 7.5227199122E-3, -7.0612900890E-3, -3.7694871426E-2, -
+                3.2971370965E-2, -3.7694871426E-2, -7.0612900890E-3, 7.5227199122E-3, -1.6012600390E-3],
+            [-1.3542800443E-3, 2.9215801042E-3, 7.5227199122E-3, 8.2244202495E-3, 1.1076199589E-3,
+                8.2244202495E-3, 7.5227199122E-3, 2.9215801042E-3, -1.3542800443E-3],
+            [-8.7009997515E-5, -1.3542800443E-3, -1.6012600390E-3, -5.0337001448E-4, 2.5240099058E-3, -5.0337001448E-4, -1.6012600390E-3, -1.3542800443E-3, -8.7009997515E-5]]
+        ).reshape(1, 1, 9, 9)
+        filters["h0"] = torch.tensor([
+            [-4.0483998600E-4, -6.2596000498E-4, -3.7829999201E-5, 8.8387000142E-4, 1.5450799838E-3, 1.9235999789E-3, 2.0687500946E-3, 2.0898699295E-3,
+                2.0687500946E-3, 1.9235999789E-3, 1.5450799838E-3, 8.8387000142E-4, -3.7829999201E-5, -6.2596000498E-4, -4.0483998600E-4],
+            [-6.2596000498E-4, -3.2734998967E-4, 7.7435001731E-4, 1.5874400269E-3, 2.1750701126E-3, 2.5626500137E-3, 2.2892199922E-3, 1.9755100366E-3,
+                2.2892199922E-3, 2.5626500137E-3, 2.1750701126E-3, 1.5874400269E-3, 7.7435001731E-4, -3.2734998967E-4, -6.2596000498E-4],
+            [-3.7829999201E-5, 7.7435001731E-4, 1.1793200392E-3, 1.4050999889E-3, 2.2253401112E-3, 2.1145299543E-3, 3.3578000148E-4, -
+                8.3368999185E-4, 3.3578000148E-4, 2.1145299543E-3, 2.2253401112E-3, 1.4050999889E-3, 1.1793200392E-3, 7.7435001731E-4, -3.7829999201E-5],
+            [8.8387000142E-4, 1.5874400269E-3, 1.4050999889E-3, 1.2960999738E-3, -4.9274001503E-4, -3.1295299996E-3, -4.5751798898E-3, -
+                5.1014497876E-3, -4.5751798898E-3, -3.1295299996E-3, -4.9274001503E-4, 1.2960999738E-3, 1.4050999889E-3, 1.5874400269E-3, 8.8387000142E-4],
+            [1.5450799838E-3, 2.1750701126E-3, 2.2253401112E-3, -4.9274001503E-4, -6.3222697936E-3, -2.7556000277E-3, 5.3632198833E-3, 7.3032598011E-3,
+                5.3632198833E-3, -2.7556000277E-3, -6.3222697936E-3, -4.9274001503E-4, 2.2253401112E-3, 2.1750701126E-3, 1.5450799838E-3],
+            [1.9235999789E-3, 2.5626500137E-3, 2.1145299543E-3, -3.1295299996E-3, -2.7556000277E-3, 1.3962360099E-2, 7.8046298586E-3, -
+                9.3812197447E-3, 7.8046298586E-3, 1.3962360099E-2, -2.7556000277E-3, -3.1295299996E-3, 2.1145299543E-3, 2.5626500137E-3, 1.9235999789E-3],
+            [2.0687500946E-3, 2.2892199922E-3, 3.3578000148E-4, -4.5751798898E-3, 5.3632198833E-3, 7.8046298586E-3, -7.9501636326E-2, -
+                0.1554141641, -7.9501636326E-2, 7.8046298586E-3, 5.3632198833E-3, -4.5751798898E-3, 3.3578000148E-4, 2.2892199922E-3, 2.0687500946E-3],
+            [2.0898699295E-3, 1.9755100366E-3, -8.3368999185E-4, -5.1014497876E-3, 7.3032598011E-3, -9.3812197447E-3, -0.1554141641,
+                0.7303866148, -0.1554141641, -9.3812197447E-3, 7.3032598011E-3, -5.1014497876E-3, -8.3368999185E-4, 1.9755100366E-3, 2.0898699295E-3],
+            [2.0687500946E-3, 2.2892199922E-3, 3.3578000148E-4, -4.5751798898E-3, 5.3632198833E-3, 7.8046298586E-3, -7.9501636326E-2, -
+                0.1554141641, -7.9501636326E-2, 7.8046298586E-3, 5.3632198833E-3, -4.5751798898E-3, 3.3578000148E-4, 2.2892199922E-3, 2.0687500946E-3],
+            [1.9235999789E-3, 2.5626500137E-3, 2.1145299543E-3, -3.1295299996E-3, -2.7556000277E-3, 1.3962360099E-2, 7.8046298586E-3, -
+                9.3812197447E-3, 7.8046298586E-3, 1.3962360099E-2, -2.7556000277E-3, -3.1295299996E-3, 2.1145299543E-3, 2.5626500137E-3, 1.9235999789E-3],
+            [1.5450799838E-3, 2.1750701126E-3, 2.2253401112E-3, -4.9274001503E-4, -6.3222697936E-3, -2.7556000277E-3, 5.3632198833E-3, 7.3032598011E-3,
+                5.3632198833E-3, -2.7556000277E-3, -6.3222697936E-3, -4.9274001503E-4, 2.2253401112E-3, 2.1750701126E-3, 1.5450799838E-3],
+            [8.8387000142E-4, 1.5874400269E-3, 1.4050999889E-3, 1.2960999738E-3, -4.9274001503E-4, -3.1295299996E-3, -4.5751798898E-3, -
+                5.1014497876E-3, -4.5751798898E-3, -3.1295299996E-3, -4.9274001503E-4, 1.2960999738E-3, 1.4050999889E-3, 1.5874400269E-3, 8.8387000142E-4],
+            [-3.7829999201E-5, 7.7435001731E-4, 1.1793200392E-3, 1.4050999889E-3, 2.2253401112E-3, 2.1145299543E-3, 3.3578000148E-4, -
+                8.3368999185E-4, 3.3578000148E-4, 2.1145299543E-3, 2.2253401112E-3, 1.4050999889E-3, 1.1793200392E-3, 7.7435001731E-4, -3.7829999201E-5],
+            [-6.2596000498E-4, -3.2734998967E-4, 7.7435001731E-4, 1.5874400269E-3, 2.1750701126E-3, 2.5626500137E-3, 2.2892199922E-3, 1.9755100366E-3,
+                2.2892199922E-3, 2.5626500137E-3, 2.1750701126E-3, 1.5874400269E-3, 7.7435001731E-4, -3.2734998967E-4, -6.2596000498E-4],
+            [-4.0483998600E-4, -6.2596000498E-4, -3.7829999201E-5, 8.8387000142E-4, 1.5450799838E-3, 1.9235999789E-3, 2.0687500946E-3, 2.0898699295E-3, 2.0687500946E-3, 1.9235999789E-3, 1.5450799838E-3, 8.8387000142E-4, -3.7829999201E-5, -6.2596000498E-4, -4.0483998600E-4]]
+        ).reshape(1, 1, 15, 15)
+        filters["b"] = []
+        filters["b"].append(torch.tensor(
+            [-8.1125000725E-4, 4.4451598078E-3, 1.2316980399E-2, 1.3955879956E-2,  1.4179450460E-2, 1.3955879956E-2, 1.2316980399E-2, 4.4451598078E-3, -8.1125000725E-4,
+             3.9103501476E-3, 4.4565401040E-3, -5.8724298142E-3, -2.8760801069E-3, 8.5267601535E-3, -
+             2.8760801069E-3, -5.8724298142E-3, 4.4565401040E-3, 3.9103501476E-3,
+             1.3462699717E-3, -3.7740699481E-3, 8.2581602037E-3, 3.9442278445E-2, 5.3605638444E-2, 3.9442278445E-2, 8.2581602037E-3, -
+             3.7740699481E-3, 1.3462699717E-3,
+             7.4700999539E-4, -3.6522001028E-4, -2.2522680461E-2, -0.1105690673, -
+             0.1768419296, -0.1105690673, -2.2522680461E-2, -3.6522001028E-4, 7.4700999539E-4,
+             0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000,
+             -7.4700999539E-4, 3.6522001028E-4, 2.2522680461E-2, 0.1105690673, 0.1768419296, 0.1105690673, 2.2522680461E-2, 3.6522001028E-4, -7.4700999539E-4,
+             -1.3462699717E-3, 3.7740699481E-3, -8.2581602037E-3, -3.9442278445E-2, -
+             5.3605638444E-2, -3.9442278445E-2, -
+             8.2581602037E-3, 3.7740699481E-3, -1.3462699717E-3,
+             -3.9103501476E-3, -4.4565401040E-3, 5.8724298142E-3, 2.8760801069E-3, -
+             8.5267601535E-3, 2.8760801069E-3, 5.8724298142E-3, -
+             4.4565401040E-3, -3.9103501476E-3,
+             8.1125000725E-4, -4.4451598078E-3, -1.2316980399E-2, -1.3955879956E-2, -1.4179450460E-2, -1.3955879956E-2, -1.2316980399E-2, -4.4451598078E-3, 8.1125000725E-4]
+        ).reshape(1, 1, 9, 9).permute(0, 1, 3, 2))
+        filters["b"].append(torch.tensor(
+            [0.0000000000, -8.2846998703E-4, -5.7109999034E-5, 4.0110000555E-5, 4.6670897864E-3, 8.0871898681E-3, 1.4807609841E-2, 8.6204400286E-3, -3.1221499667E-3,
+             8.2846998703E-4, 0.0000000000, -9.7479997203E-4, -6.9718998857E-3, -
+             2.0865600090E-3, 2.3298799060E-3, -
+             4.4814897701E-3, 1.4917500317E-2, 8.6204400286E-3,
+             5.7109999034E-5, 9.7479997203E-4, 0.0000000000, -1.2145539746E-2, -
+             2.4427289143E-2, 5.0797060132E-2, 3.2785870135E-2, -
+             4.4814897701E-3, 1.4807609841E-2,
+             -4.0110000555E-5, 6.9718998857E-3, 1.2145539746E-2, 0.0000000000, -
+             0.1510555595, -8.2495503128E-2, 5.0797060132E-2, 2.3298799060E-3, 8.0871898681E-3,
+             -4.6670897864E-3, 2.0865600090E-3, 2.4427289143E-2, 0.1510555595, 0.0000000000, -
+             0.1510555595, -2.4427289143E-2, -2.0865600090E-3, 4.6670897864E-3,
+             -8.0871898681E-3, -2.3298799060E-3, -5.0797060132E-2, 8.2495503128E-2, 0.1510555595, 0.0000000000, -
+             1.2145539746E-2, -6.9718998857E-3, 4.0110000555E-5,
+             -1.4807609841E-2, 4.4814897701E-3, -3.2785870135E-2, -
+             5.0797060132E-2, 2.4427289143E-2, 1.2145539746E-2, 0.0000000000, -
+             9.7479997203E-4, -5.7109999034E-5,
+             -8.6204400286E-3, -1.4917500317E-2, 4.4814897701E-3, -
+             2.3298799060E-3, 2.0865600090E-3, 6.9718998857E-3, 9.7479997203E-4, 0.0000000000, -8.2846998703E-4,
+             3.1221499667E-3, -8.6204400286E-3, -1.4807609841E-2, -8.0871898681E-3, -4.6670897864E-3, -4.0110000555E-5, 5.7109999034E-5, 8.2846998703E-4, 0.0000000000]
+        ).reshape(1, 1, 9, 9).permute(0, 1, 3, 2))
+        filters["b"].append(torch.tensor(
+            [8.1125000725E-4, -3.9103501476E-3, -1.3462699717E-3, -7.4700999539E-4, 0.0000000000, 7.4700999539E-4, 1.3462699717E-3, 3.9103501476E-3, -8.1125000725E-4,
+             -4.4451598078E-3, -4.4565401040E-3, 3.7740699481E-3, 3.6522001028E-4, 0.0000000000, -
+             3.6522001028E-4, -3.7740699481E-3, 4.4565401040E-3, 4.4451598078E-3,
+             -1.2316980399E-2, 5.8724298142E-3, -8.2581602037E-3, 2.2522680461E-2, 0.0000000000, -
+             2.2522680461E-2, 8.2581602037E-3, -5.8724298142E-3, 1.2316980399E-2,
+             -1.3955879956E-2, 2.8760801069E-3, -3.9442278445E-2, 0.1105690673, 0.0000000000, -
+             0.1105690673, 3.9442278445E-2, -2.8760801069E-3, 1.3955879956E-2,
+             -1.4179450460E-2, -8.5267601535E-3, -5.3605638444E-2, 0.1768419296, 0.0000000000, -
+             0.1768419296, 5.3605638444E-2, 8.5267601535E-3, 1.4179450460E-2,
+             -1.3955879956E-2, 2.8760801069E-3, -3.9442278445E-2, 0.1105690673, 0.0000000000, -
+             0.1105690673, 3.9442278445E-2, -2.8760801069E-3, 1.3955879956E-2,
+             -1.2316980399E-2, 5.8724298142E-3, -8.2581602037E-3, 2.2522680461E-2, 0.0000000000, -
+             2.2522680461E-2, 8.2581602037E-3, -5.8724298142E-3, 1.2316980399E-2,
+             -4.4451598078E-3, -4.4565401040E-3, 3.7740699481E-3, 3.6522001028E-4, 0.0000000000, -
+             3.6522001028E-4, -3.7740699481E-3, 4.4565401040E-3, 4.4451598078E-3,
+             8.1125000725E-4, -3.9103501476E-3, -1.3462699717E-3, -7.4700999539E-4, 0.0000000000, 7.4700999539E-4, 1.3462699717E-3, 3.9103501476E-3, -8.1125000725E-4]
+        ).reshape(1, 1, 9, 9).permute(0, 1, 3, 2))
+        filters["b"].append(torch.tensor(
+            [3.1221499667E-3, -8.6204400286E-3, -1.4807609841E-2, -8.0871898681E-3, -4.6670897864E-3, -4.0110000555E-5, 5.7109999034E-5, 8.2846998703E-4, 0.0000000000,
+             -8.6204400286E-3, -1.4917500317E-2, 4.4814897701E-3, -
+             2.3298799060E-3, 2.0865600090E-3, 6.9718998857E-3, 9.7479997203E-4, -
+             0.0000000000, -8.2846998703E-4,
+             -1.4807609841E-2, 4.4814897701E-3, -3.2785870135E-2, -
+             5.0797060132E-2, 2.4427289143E-2, 1.2145539746E-2, 0.0000000000, -
+             9.7479997203E-4, -5.7109999034E-5,
+             -8.0871898681E-3, -2.3298799060E-3, -5.0797060132E-2, 8.2495503128E-2, 0.1510555595, -
+             0.0000000000, -1.2145539746E-2, -6.9718998857E-3, 4.0110000555E-5,
+             -4.6670897864E-3, 2.0865600090E-3, 2.4427289143E-2, 0.1510555595, 0.0000000000, -
+             0.1510555595, -2.4427289143E-2, -2.0865600090E-3, 4.6670897864E-3,
+             -4.0110000555E-5, 6.9718998857E-3, 1.2145539746E-2, 0.0000000000, -
+             0.1510555595, -8.2495503128E-2, 5.0797060132E-2, 2.3298799060E-3, 8.0871898681E-3,
+             5.7109999034E-5, 9.7479997203E-4, -0.0000000000, -1.2145539746E-2, -
+             2.4427289143E-2, 5.0797060132E-2, 3.2785870135E-2, -
+             4.4814897701E-3, 1.4807609841E-2,
+             8.2846998703E-4, -0.0000000000, -9.7479997203E-4, -6.9718998857E-3, -
+             2.0865600090E-3, 2.3298799060E-3, -
+             4.4814897701E-3, 1.4917500317E-2, 8.6204400286E-3,
+             0.0000000000, -8.2846998703E-4, -5.7109999034E-5, 4.0110000555E-5, 4.6670897864E-3, 8.0871898681E-3, 1.4807609841E-2, 8.6204400286E-3, -3.1221499667E-3]
+        ).reshape(1, 1, 9, 9).permute(0, 1, 3, 2))
+
+    elif n_orientations == 6:
+        filters["l"] = 2 * torch.tensor([
+            [0.00085404, -0.00244917, -0.00387812, -0.00944432, -
+                0.00962054, -0.00944432, -0.00387812, -0.00244917, 0.00085404],
+            [-0.00244917, -0.00523281, -0.00661117, 0.00410600, 0.01002988,
+                0.00410600, -0.00661117, -0.00523281, -0.00244917],
+            [-0.00387812, -0.00661117, 0.01396746, 0.03277038, 0.03981393,
+                0.03277038, 0.01396746, -0.00661117, -0.00387812],
+            [-0.00944432, 0.00410600, 0.03277038, 0.06426333, 0.08169618,
+                0.06426333, 0.03277038, 0.00410600, -0.00944432],
+            [-0.00962054, 0.01002988, 0.03981393, 0.08169618, 0.10096540,
+                0.08169618, 0.03981393, 0.01002988, -0.00962054],
+            [-0.00944432, 0.00410600, 0.03277038, 0.06426333, 0.08169618,
+                0.06426333, 0.03277038, 0.00410600, -0.00944432],
+            [-0.00387812, -0.00661117, 0.01396746, 0.03277038, 0.03981393,
+                0.03277038, 0.01396746, -0.00661117, -0.00387812],
+            [-0.00244917, -0.00523281, -0.00661117, 0.00410600, 0.01002988,
+                0.00410600, -0.00661117, -0.00523281, -0.00244917],
+            [0.00085404, -0.00244917, -0.00387812, -0.00944432, -0.00962054, -0.00944432, -0.00387812, -0.00244917, 0.00085404]]
+        ).reshape(1, 1, 9, 9)
+        filters["l0"] = torch.tensor([
+            [0.00341614, -0.01551246, -0.03848215, -0.01551246, 0.00341614],
+            [-0.01551246, 0.05586982, 0.15925570, 0.05586982, -0.01551246],
+            [-0.03848215, 0.15925570, 0.40304148, 0.15925570, -0.03848215],
+            [-0.01551246, 0.05586982, 0.15925570, 0.05586982, -0.01551246],
+            [0.00341614, -0.01551246, -0.03848215, -0.01551246, 0.00341614]]
+        ).reshape(1, 1, 5, 5)
+        filters["h0"] = torch.tensor([
+            [-0.00033429, -0.00113093, -0.00171484, -0.00133542, -
+                0.00080639, -0.00133542, -0.00171484, -0.00113093, -0.00033429],
+            [-0.00113093, -0.00350017, -0.00243812, 0.00631653, 0.01261227,
+                0.00631653, -0.00243812, -0.00350017, -0.00113093],
+            [-0.00171484, -0.00243812, -0.00290081, -0.00673482, -
+                0.00981051, -0.00673482, -0.00290081, -0.00243812, -0.00171484],
+            [-0.00133542, 0.00631653, -0.00673482, -0.07027679, -
+                0.11435863, -0.07027679, -0.00673482, 0.00631653, -0.00133542],
+            [-0.00080639, 0.01261227, -0.00981051, -0.11435863,
+                0.81380200, -0.11435863, -0.00981051, 0.01261227, -0.00080639],
+            [-0.00133542, 0.00631653, -0.00673482, -0.07027679, -
+                0.11435863, -0.07027679, -0.00673482, 0.00631653, -0.00133542],
+            [-0.00171484, -0.00243812, -0.00290081, -0.00673482, -
+                0.00981051, -0.00673482, -0.00290081, -0.00243812, -0.00171484],
+            [-0.00113093, -0.00350017, -0.00243812, 0.00631653, 0.01261227,
+                0.00631653, -0.00243812, -0.00350017, -0.00113093],
+            [-0.00033429, -0.00113093, -0.00171484, -0.00133542, -0.00080639, -0.00133542, -0.00171484, -0.00113093, -0.00033429]]
+        ).reshape(1, 1, 9, 9)
+        filters["b"] = []
+        filters["b"].append(torch.tensor([
+            0.00277643, 0.00496194, 0.01026699, 0.01455399, 0.01026699, 0.00496194, 0.00277643,
+            -0.00986904, -0.00893064, 0.01189859, 0.02755155, 0.01189859, -0.00893064, -0.00986904,
+            -0.01021852, -0.03075356, -0.08226445, -
+            0.11732297, -0.08226445, -0.03075356, -0.01021852,
+            0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
+            0.01021852, 0.03075356, 0.08226445, 0.11732297, 0.08226445, 0.03075356, 0.01021852,
+            0.00986904, 0.00893064, -0.01189859, -
+            0.02755155, -0.01189859, 0.00893064, 0.00986904,
+            -0.00277643, -0.00496194, -0.01026699, -0.01455399, -0.01026699, -0.00496194, -0.00277643]
+        ).reshape(1, 1, 7, 7).permute(0, 1, 3, 2))
+        filters["b"].append(torch.tensor([
+            -0.00343249, -0.00640815, -0.00073141, 0.01124321, 0.00182078, 0.00285723, 0.01166982,
+            -0.00358461, -0.01977507, -0.04084211, -
+            0.00228219, 0.03930573, 0.01161195, 0.00128000,
+            0.01047717, 0.01486305, -0.04819057, -
+            0.12227230, -0.05394139, 0.00853965, -0.00459034,
+            0.00790407, 0.04435647, 0.09454202, -0.00000000, -
+            0.09454202, -0.04435647, -0.00790407,
+            0.00459034, -0.00853965, 0.05394139, 0.12227230, 0.04819057, -0.01486305, -0.01047717,
+            -0.00128000, -0.01161195, -0.03930573, 0.00228219, 0.04084211, 0.01977507, 0.00358461,
+            -0.01166982, -0.00285723, -0.00182078, -0.01124321, 0.00073141, 0.00640815, 0.00343249]
+        ).reshape(1, 1, 7, 7).permute(0, 1, 3, 2))
+        filters["b"].append(torch.tensor([
+            0.00343249, 0.00358461, -0.01047717, -
+            0.00790407, -0.00459034, 0.00128000, 0.01166982,
+            0.00640815, 0.01977507, -0.01486305, -
+            0.04435647, 0.00853965, 0.01161195, 0.00285723,
+            0.00073141, 0.04084211, 0.04819057, -
+            0.09454202, -0.05394139, 0.03930573, 0.00182078,
+            -0.01124321, 0.00228219, 0.12227230, -
+            0.00000000, -0.12227230, -0.00228219, 0.01124321,
+            -0.00182078, -0.03930573, 0.05394139, 0.09454202, -
+            0.04819057, -0.04084211, -0.00073141,
+            -0.00285723, -0.01161195, -0.00853965, 0.04435647, 0.01486305, -0.01977507, -0.00640815,
+            -0.01166982, -0.00128000, 0.00459034, 0.00790407, 0.01047717, -0.00358461, -0.00343249]
+        ).reshape(1, 1, 7, 7).permute(0, 1, 3, 2))
+        filters["b"].append(torch.tensor(
+            [-0.00277643, 0.00986904, 0.01021852, -0.00000000, -0.01021852, -0.00986904, 0.00277643,
+             -0.00496194, 0.00893064, 0.03075356, -
+             0.00000000, -0.03075356, -0.00893064, 0.00496194,
+             -0.01026699, -0.01189859, 0.08226445, -
+             0.00000000, -0.08226445, 0.01189859, 0.01026699,
+             -0.01455399, -0.02755155, 0.11732297, -
+             0.00000000, -0.11732297, 0.02755155, 0.01455399,
+             -0.01026699, -0.01189859, 0.08226445, -
+             0.00000000, -0.08226445, 0.01189859, 0.01026699,
+             -0.00496194, 0.00893064, 0.03075356, -
+             0.00000000, -0.03075356, -0.00893064, 0.00496194,
+             -0.00277643, 0.00986904, 0.01021852, -0.00000000, -0.01021852, -0.00986904, 0.00277643]
+        ).reshape(1, 1, 7, 7).permute(0, 1, 3, 2))
+        filters["b"].append(torch.tensor([
+            -0.01166982, -0.00128000, 0.00459034, 0.00790407, 0.01047717, -0.00358461, -0.00343249,
+            -0.00285723, -0.01161195, -0.00853965, 0.04435647, 0.01486305, -0.01977507, -0.00640815,
+            -0.00182078, -0.03930573, 0.05394139, 0.09454202, -0.04819057, -0.04084211, -0.00073141,
+            -0.01124321, 0.00228219, 0.12227230, -0.00000000, -0.12227230, -0.00228219, 0.01124321,
+            0.00073141, 0.04084211, 0.04819057, -0.09454202, -0.05394139, 0.03930573, 0.00182078,
+            0.00640815, 0.01977507, -0.01486305, -0.04435647, 0.00853965, 0.01161195, 0.00285723,
+            0.00343249, 0.00358461, -0.01047717, -0.00790407, -0.00459034, 0.00128000, 0.01166982]
+        ).reshape(1, 1, 7, 7).permute(0, 1, 3, 2))
+        filters["b"].append(torch.tensor([
+            -0.01166982, -0.00285723, -0.00182078, -0.01124321, 0.00073141, 0.00640815, 0.00343249,
+            -0.00128000, -0.01161195, -0.03930573, 0.00228219, 0.04084211, 0.01977507, 0.00358461,
+            0.00459034, -0.00853965, 0.05394139, 0.12227230, 0.04819057, -0.01486305, -0.01047717,
+            0.00790407, 0.04435647, 0.09454202, -0.00000000, -0.09454202, -0.04435647, -0.00790407,
+            0.01047717, 0.01486305, -0.04819057, -0.12227230, -0.05394139, 0.00853965, -0.00459034,
+            -0.00358461, -0.01977507, -0.04084211, -0.00228219, 0.03930573, 0.01161195, 0.00128000,
+            -0.00343249, -0.00640815, -0.00073141, 0.01124321, 0.00182078, 0.00285723, 0.01166982]
+        ).reshape(1, 1, 7, 7).permute(0, 1, 3, 2))
+
+    else:
+        raise Exception(
+            "Steerable filters not implemented for %d orientations" % n_orientations)
+
+    if filter_type == "trained":
+        if size == 5:
+            # TODO maybe also train h0 and l0 filters
+            filters = crop_steerable_pyramid_filters(filters, 5)
+            filters["b"][0] = torch.tensor([
+                [-0.0356752239, -0.0223877281, -0.0009542659,  0.0244821459, 0.0322226137],
+                [-0.0593218654,  0.1245803162, -0.0023863907, -0.1230178699, 0.0589442067],
+                [-0.0281576272,  0.2976626456, -0.0020888755, -0.2953369915, 0.0284542721],
+                [-0.0586092323,  0.1251581162, -0.0024624448, -0.1227868199, 0.0587830991],
+                [-0.0327464789, -0.0223652460, -0.0042342511,  0.0245472137, 0.0359398536]
+            ]).reshape(1, 1, 5, 5)
+            filters["b"][1] = torch.tensor([
+                [3.9758663625e-02,  6.0679119080e-02,  3.0146904290e-02,  6.1198268086e-02,  3.6218870431e-02],
+                [2.3255519569e-02, -1.2505133450e-01, -2.9738345742e-01, -1.2518258393e-01,  2.3592948914e-02],
+                [-1.3602430699e-03, -1.2058277935e-04,  2.6399988565e-04, -2.3791544663e-04,  1.8450465286e-03],
+                [-2.1563466638e-02,  1.2572696805e-01,  2.9745018482e-01,  1.2458638102e-01, -2.3847281933e-02],
+                [-3.7941932678e-02, -6.1060950160e-02, -2.9489086941e-02, -6.0411967337e-02, -3.8459088653e-02]
+            ]).reshape(1, 1, 5, 5)
+
+            # Below filters were optimised on 09/02/2021
+            # 20K iterations with multiple images at more scales.
+            filters["b"][0] = torch.tensor([
+                [-4.5508436859e-02, -2.1767273545e-02, -1.9399923622e-04,  2.1200872958e-02,  4.5475799590e-02],
+                [-6.3554823399e-02,  1.2832683325e-01, -5.3858719184e-05, -1.2809979916e-01,  6.3842624426e-02],
+                [-3.4809380770e-02,  2.9954621196e-01,  2.9066693969e-05, -2.9957753420e-01,  3.4806568176e-02],
+                [-6.3934154809e-02,  1.2806062400e-01,  9.0917674243e-05, -1.2832444906e-01,  6.3572973013e-02],
+                [-4.5492250472e-02, -2.1125273779e-02,  4.2229349492e-04,  2.1804777905e-02,  4.5236673206e-02]
+            ]).reshape(1, 1, 5, 5)
+            filters["b"][1] = torch.tensor([
+                [4.8947390169e-02,  6.3575074077e-02,  3.4955859184e-02,  6.4085893333e-02,  4.9838040024e-02],
+                [2.2061849013e-02, -1.2936264277e-01, -3.0093491077e-01, -1.2997294962e-01,  2.0597217605e-02],
+                [-5.1290717238e-05, -1.7305796064e-05,  2.0256420612e-05, -1.1864109547e-04,  7.3973249528e-05],
+                [-2.0749464631e-02,  1.2988376617e-01,  3.0080935359e-01,  1.2921217084e-01, -2.2159902379e-02],
+                [-4.9614857882e-02, -6.4021714032e-02, -3.4676689655e-02, -6.3446544111e-02, -4.8282280564e-02]
+            ]).reshape(1, 1, 5, 5)
+
+            # Trained on 17/02/2021 to match the Fourier pyramid in the spatial domain.
+            filters["b"][0] = torch.tensor([
+                [3.3370e-02,  9.3934e-02, -3.5810e-04, -9.4038e-02, -3.3115e-02],
+                [1.7716e-01,  3.9378e-01,  6.8461e-05, -3.9343e-01, -1.7685e-01],
+                [2.9213e-01,  6.1042e-01,  7.0654e-04, -6.0939e-01, -2.9177e-01],
+                [1.7684e-01,  3.9392e-01,  1.0517e-03, -3.9268e-01, -1.7668e-01],
+                [3.3000e-02,  9.4029e-02,  7.3565e-04, -9.3366e-02, -3.3008e-02]
+            ]).reshape(1, 1, 5, 5) * 0.1
+
+            filters["b"][1] = torch.tensor([
+                [0.0331,  0.1763,  0.2907,  0.1753,  0.0325],
+                [0.0941,  0.3932,  0.6079,  0.3904,  0.0922],
+                [0.0008,  0.0009, -0.0010, -0.0025, -0.0015],
+                [-0.0929, -0.3919, -0.6097, -0.3944, -0.0946],
+                [-0.0328, -0.1760, -0.2915, -0.1768, -0.0333]
+            ]).reshape(1, 1, 5, 5) * 0.1
+
+        else:
+            raise Exception(
+                "Trained filters not implemented for size %d" % size)
+
+    if filter_type == "cropped":
+        filters = crop_steerable_pyramid_filters(filters, size)
+
+    return filters
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ slice_rgbd_targets(target, depth, depth_plane_positions) + +

+ + +
+ +

Slices the target RGBD image and depth map into multiple layers based on depth plane positions.

+ + +

Parameters:

+
    +
  • + target + – +
    +
                     The RGBD target tensor with shape (C, H, W).
    +
    +
    +
  • +
  • + depth + – +
    +
                     The depth map corresponding to the target image with shape (H, W).
    +
    +
    +
  • +
  • + depth_plane_positions + – +
    +
                     The positions of the depth planes used for slicing.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +targets ( Tensor +) – +
    +

    A tensor of shape (N, C, H, W) where N is the number of depth planes. Contains the sliced targets for each depth plane.

    +
    +
  • +
  • +masks ( Tensor +) – +
    +

    A tensor of shape (N, C, H, W) containing binary masks for each depth plane.

    +
    +
  • +
+ +
+ Source code in odak/learn/perception/util.py +
def slice_rgbd_targets(target, depth, depth_plane_positions):
+    """
+    Slices the target RGBD image and depth map into multiple layers based on depth plane positions.
+
+    Parameters
+    ----------
+    target                 : torch.Tensor
+                             The RGBD target tensor with shape (C, H, W).
+    depth                  : torch.Tensor
+                             The depth map corresponding to the target image with shape (H, W).
+    depth_plane_positions  : list or torch.Tensor
+                             The positions of the depth planes used for slicing.
+
+    Returns
+    -------
+    targets              : torch.Tensor
+                           A tensor of shape (N, C, H, W) where N is the number of depth planes. Contains the sliced targets for each depth plane.
+    masks                : torch.Tensor
+                           A tensor of shape (N, C, H, W) containing binary masks for each depth plane.
+    """
+    device = target.device
+    number_of_planes = len(depth_plane_positions) - 1
+    targets = torch.zeros(
+                        number_of_planes,
+                        target.shape[0],
+                        target.shape[1],
+                        target.shape[2],
+                        requires_grad = False,
+                        device = device
+                        )
+    masks = torch.zeros_like(targets, dtype = torch.int).to(device)
+    mask_zeros = torch.zeros_like(depth, dtype = torch.int)
+    mask_ones = torch.ones_like(depth, dtype = torch.int)
+    for i in range(1, number_of_planes+1):
+        for ch in range(target.shape[0]):
+            pos = depth_plane_positions[i] 
+            prev_pos = depth_plane_positions[i-1] 
+            if i <= (number_of_planes - 1):
+                condition = torch.logical_and(prev_pos <= depth, depth < pos)
+            else:
+                condition = torch.logical_and(prev_pos <= depth, depth <= pos)
+            mask = torch.where(condition, mask_ones, mask_zeros)
+            new_target = target[ch] * mask
+            targets[i-1, ch] = new_target.squeeze(0)
+            masks[i-1, ch] = mask.detach().clone() 
+    return targets, masks
+
+
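A minimal usage sketch of this helper (the import path is taken from the source path shown above; the random inputs below are purely illustrative):
import torch
+from odak.learn.perception.util import slice_rgbd_targets
+
+target = torch.rand(3, 64, 64)                     # hypothetical RGB target [C x H x W]
+depth = torch.rand(64, 64)                         # hypothetical depth map in [0, 1]
+depth_plane_positions = torch.linspace(0., 1., 5)  # five positions define four slices
+targets, masks = slice_rgbd_targets(target, depth, depth_plane_positions)
+print(targets.shape, masks.shape)                  # torch.Size([4, 3, 64, 64]) for both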
+
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/odak/learn_raytracing/index.html b/odak/learn_raytracing/index.html new file mode 100644 index 00000000..7427783e --- /dev/null +++ b/odak/learn_raytracing/index.html @@ -0,0 +1,13969 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + odak.learn.raytracing - Odak + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + +
+ + + + + + +
+ + + + +
+ +
+ + + + +
+
+ + + +
+
+
+ + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

odak.learn.raytracing

+ +
+ + + + +
+ +

odak.learn.raytracing

+

Provides necessary definitions for geometric optics. +See "General Ray tracing procedure" by G.H. Spencer and M.V.R.K. Murty for a more theoretical explanation.

+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ detector + + +

+ + +
+ + +

A class to represent a detector.

+ + + + + + +
+ Source code in odak/learn/raytracing/detector.py +
class detector():
+    """
+    A class to represent a detector.
+    """
+
+
+    def __init__(
+                 self,
+                 colors = 3,
+                 center = torch.tensor([0., 0., 0.]),
+                 tilt = torch.tensor([0., 0., 0.]),
+                 size = torch.tensor([10., 10.]),
+                 resolution = torch.tensor([100, 100]),
+                 device = torch.device('cpu')
+                ):
+        """
+        Parameters
+        ----------
+        colors         : int
+                         Number of color channels to register (e.g., RGB).
+        center         : torch.tensor
+                         Center point of the detector [3].
+        tilt           : torch.tensor
+                         Tilt angles of the surface in degrees [3].
+        size           : torch.tensor
+                         Size of the detector [2].
+        resolution     : torch.tensor
+                         Resolution of the detector.
+        device         : torch.device
+                         Device for computation (e.g., cuda, cpu).
+        """
+        self.device = device
+        self.colors = colors
+        self.resolution = resolution.to(self.device)
+        self.surface_center = center.to(self.device)
+        self.surface_tilt = tilt.to(self.device)
+        self.size = size.to(self.device)
+        self.pixel_size = torch.tensor([
+                                        self.size[0] / self.resolution[0],
+                                        self.size[1] / self.resolution[1]
+                                       ], device  = self.device)
+        self.pixel_diagonal_size = torch.sqrt(self.pixel_size[0] ** 2 + self.pixel_size[1] ** 2)
+        self.pixel_diagonal_half_size = self.pixel_diagonal_size / 2.
+        self.threshold = torch.nn.Threshold(self.pixel_diagonal_size, 1)
+        self.plane = define_plane(
+                                  point = self.surface_center,
+                                  angles = self.surface_tilt
+                                 )
+        self.pixel_locations, _, _, _ = grid_sample(
+                                                    size = self.size.tolist(),
+                                                    no = self.resolution.tolist(),
+                                                    center = self.surface_center.tolist(),
+                                                    angles = self.surface_tilt.tolist()
+                                                   )
+        self.pixel_locations = self.pixel_locations.to(self.device)
+        self.relu = torch.nn.ReLU()
+        self.clear()
+
+
+    def intersect(self, rays, color = 0):
+        """
+        Function to intersect rays with the detector
+
+
+        Parameters
+        ----------
+        rays            : torch.tensor
+                          Rays to be intersected with a detector.
+                          Expected size is [1 x 2 x 3] or [m x 2 x 3].
+        color           : int
+                          Color channel to register.
+
+        Returns
+        -------
+        points          : torch.tensor
+                          Intersection points with the image detector [k x 3].
+        """
+        normals, _ = intersect_w_surface(rays, self.plane)
+        points = normals[:, 0]
+        distances_xyz = torch.abs(points.unsqueeze(1) - self.pixel_locations.unsqueeze(0))
+        distances_x = 1e6 * self.relu( - (distances_xyz[:, :, 0] - self.pixel_size[0]))
+        distances_y = 1e6 * self.relu( - (distances_xyz[:, :, 1] - self.pixel_size[1]))
+        hit_x = torch.clamp(distances_x, min = 0., max = 1.)
+        hit_y = torch.clamp(distances_y, min = 0., max = 1.)
+        hit = hit_x * hit_y
+        image = torch.sum(hit, dim = 0)
+        self.image[color] += image.reshape(
+                                           self.image.shape[-2], 
+                                           self.image.shape[-1]
+                                          )
+        distances = torch.sum((points.unsqueeze(1) - self.pixel_locations.unsqueeze(0)) ** 2, dim = 2)
+        distance_image = distances
+#        distance_image = distances.reshape(
+#                                           -1,
+#                                           self.image.shape[-2],
+#                                           self.image.shape[-1]
+#                                          )
+        return points, image, distance_image
+
+
+    def get_image(self):
+        """
+        Function to return the detector image.
+
+        Returns
+        -------
+        image           : torch.tensor
+                          Detector image.
+        """
+        image = (self.image - self.image.min()) / (self.image.max() - self.image.min())
+        return image
+
+
+    def clear(self):
+        """
+        Internal function to clear a detector.
+        """
+        self.image = torch.zeros(
+
+                                 self.colors,
+                                 self.resolution[0],
+                                 self.resolution[1],
+                                 device = self.device,
+                                )
+
+
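A minimal usage sketch (assumed import path; `create_ray_from_two_points` is documented further below on this page):
import torch
+from odak.learn.raytracing import detector, create_ray_from_two_points
+
+det = detector(
+               colors = 1,
+               center = torch.tensor([0., 0., 10.]),
+               size = torch.tensor([10., 10.]),
+               resolution = torch.tensor([100, 100])
+              )
+rays = create_ray_from_two_points(
+                                  torch.tensor([[0., 0., 0.]]),
+                                  torch.tensor([[0., 0., 10.]])
+                                 ) # one ray from the origin towards the detector plane
+points, image, distance_image = det.intersect(rays, color = 0)
+final_image = det.get_image() # normalized accumulated image [1 x 100 x 100]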
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(colors=3, center=torch.tensor([0.0, 0.0, 0.0]), tilt=torch.tensor([0.0, 0.0, 0.0]), size=torch.tensor([10.0, 10.0]), resolution=torch.tensor([100, 100]), device=torch.device('cpu')) + +

+ + +
+ + + +

Parameters:

+
    +
  • + colors + – +
    +
             Number of color channels to register (e.g., RGB).
    +
    +
    +
  • +
  • + center + – +
    +
             Center point of the detector [3].
    +
    +
    +
  • +
  • + tilt + – +
    +
             Tilt angles of the surface in degrees [3].
    +
    +
    +
  • +
  • + size + – +
    +
             Size of the detector [2].
    +
    +
    +
  • +
  • + resolution + – +
    +
             Resolution of the detector.
    +
    +
    +
  • +
  • + device + – +
    +
             Device for computation (e.g., cuda, cpu).
    +
    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/detector.py +
def __init__(
+             self,
+             colors = 3,
+             center = torch.tensor([0., 0., 0.]),
+             tilt = torch.tensor([0., 0., 0.]),
+             size = torch.tensor([10., 10.]),
+             resolution = torch.tensor([100, 100]),
+             device = torch.device('cpu')
+            ):
+    """
+    Parameters
+    ----------
+    colors         : int
+                     Number of color channels to register (e.g., RGB).
+    center         : torch.tensor
+                     Center point of the detector [3].
+    tilt           : torch.tensor
+                     Tilt angles of the surface in degrees [3].
+    size           : torch.tensor
+                     Size of the detector [2].
+    resolution     : torch.tensor
+                     Resolution of the detector.
+    device         : torch.device
+                     Device for computation (e.g., cuda, cpu).
+    """
+    self.device = device
+    self.colors = colors
+    self.resolution = resolution.to(self.device)
+    self.surface_center = center.to(self.device)
+    self.surface_tilt = tilt.to(self.device)
+    self.size = size.to(self.device)
+    self.pixel_size = torch.tensor([
+                                    self.size[0] / self.resolution[0],
+                                    self.size[1] / self.resolution[1]
+                                   ], device  = self.device)
+    self.pixel_diagonal_size = torch.sqrt(self.pixel_size[0] ** 2 + self.pixel_size[1] ** 2)
+    self.pixel_diagonal_half_size = self.pixel_diagonal_size / 2.
+    self.threshold = torch.nn.Threshold(self.pixel_diagonal_size, 1)
+    self.plane = define_plane(
+                              point = self.surface_center,
+                              angles = self.surface_tilt
+                             )
+    self.pixel_locations, _, _, _ = grid_sample(
+                                                size = self.size.tolist(),
+                                                no = self.resolution.tolist(),
+                                                center = self.surface_center.tolist(),
+                                                angles = self.surface_tilt.tolist()
+                                               )
+    self.pixel_locations = self.pixel_locations.to(self.device)
+    self.relu = torch.nn.ReLU()
+    self.clear()
+
+
+
+ +
+ +
+ + +

+ clear() + +

+ + +
+ +

Internal function to clear a detector.

+ +
+ Source code in odak/learn/raytracing/detector.py +
def clear(self):
+    """
+    Internal function to clear a detector.
+    """
+    self.image = torch.zeros(
+
+                             self.colors,
+                             self.resolution[0],
+                             self.resolution[1],
+                             device = self.device,
+                            )
+
+
+
+ +
+ +
+ + +

+ get_image() + +

+ + +
+ +

Function to return the detector image.

+ + +

Returns:

+
    +
  • +image ( tensor +) – +
    +

    Detector image.

    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/detector.py +
def get_image(self):
+    """
+    Function to return the detector image.
+
+    Returns
+    -------
+    image           : torch.tensor
+                      Detector image.
+    """
+    image = (self.image - self.image.min()) / (self.image.max() - self.image.min())
+    return image
+
+
+
+ +
+ +
+ + +

+ intersect(rays, color=0) + +

+ + +
+ +

Function to intersect rays with the detector

+ + +

Parameters:

+
    +
  • + rays + – +
    +
              Rays to be intersected with a detector.
    +          Expected size is [1 x 2 x 3] or [m x 2 x 3].
    +
    +
    +
  • +
  • + color + – +
    +
              Color channel to register.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +points ( tensor +) – +
    +

    Intersection points with the image detector [k x 3].

    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/detector.py +
    def intersect(self, rays, color = 0):
+        """
+        Function to intersect rays with the detector
+
+
+        Parameters
+        ----------
+        rays            : torch.tensor
+                          Rays to be intersected with a detector.
+                          Expected size is [1 x 2 x 3] or [m x 2 x 3].
+        color           : int
+                          Color channel to register.
+
+        Returns
+        -------
+        points          : torch.tensor
+                          Intersection points with the image detector [k x 3].
+        """
+        normals, _ = intersect_w_surface(rays, self.plane)
+        points = normals[:, 0]
+        distances_xyz = torch.abs(points.unsqueeze(1) - self.pixel_locations.unsqueeze(0))
+        distances_x = 1e6 * self.relu( - (distances_xyz[:, :, 0] - self.pixel_size[0]))
+        distances_y = 1e6 * self.relu( - (distances_xyz[:, :, 1] - self.pixel_size[1]))
+        hit_x = torch.clamp(distances_x, min = 0., max = 1.)
+        hit_y = torch.clamp(distances_y, min = 0., max = 1.)
+        hit = hit_x * hit_y
+        image = torch.sum(hit, dim = 0)
+        self.image[color] += image.reshape(
+                                           self.image.shape[-2], 
+                                           self.image.shape[-1]
+                                          )
+        distances = torch.sum((points.unsqueeze(1) - self.pixel_locations.unsqueeze(0)) ** 2, dim = 2)
+        distance_image = distances
+#        distance_image = distances.reshape(
+#                                           -1,
+#                                           self.image.shape[-2],
+#                                           self.image.shape[-1]
+#                                          )
+        return points, image, distance_image
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ planar_mesh + + +

+ + +
+ + + + + + + +
+ Source code in odak/learn/raytracing/mesh.py +
class planar_mesh():
+
+
+    def __init__(
+                 self,
+                 size = [1., 1.],
+                 number_of_meshes = [10, 10],
+                 angles = torch.tensor([0., 0., 0.]),
+                 offset = torch.tensor([0., 0., 0.]),
+                 device = torch.device('cpu'),
+                 heights = None
+                ):
+        """
+        Definition to generate a plane with meshes.
+
+
+        Parameters
+        -----------
+        number_of_meshes  : torch.tensor
+                            Number of squares over plane.
+                            There are two triangles at each square.
+        size              : torch.tensor
+                            Size of the plane.
+        angles            : torch.tensor
+                            Rotation angles in degrees.
+        offset            : torch.tensor
+                            Offset along XYZ axes.
+                            Expected dimension is [1 x 3] or offset for each triangle [m x 3].
+                            m here refers to `2 * number_of_meshes[0]` times  `number_of_meshes[1]`.
+        device            : torch.device
+                            Computational resource to be used (e.g., cpu, cuda).
+        heights           : torch.tensor
+                            Load surface heights from a tensor.
+        """
+        self.device = device
+        self.angles = angles.to(self.device)
+        self.offset = offset.to(self.device)
+        self.size = size.to(self.device)
+        self.number_of_meshes = number_of_meshes.to(self.device)
+        self.init_heights(heights)
+
+
+    def init_heights(self, heights = None):
+        """
+        Internal function to initialize a height map.
+        Note that self.heights is a differentiable variable, and can be optimized or learned.
+        See unit test `test/test_learn_ray_detector.py` or `test/test_learn_ray_mesh.py` as examples.
+        """
+        if not isinstance(heights, type(None)):
+            self.heights = heights.to(self.device)
+            self.heights.requires_grad = True
+        else:
+            self.heights = torch.zeros(
+                                       (self.number_of_meshes[0], self.number_of_meshes[1], 1),
+                                       requires_grad = True,
+                                       device = self.device,
+                                      )
+        x = torch.linspace(-self.size[0] / 2., self.size[0] / 2., self.number_of_meshes[0], device = self.device) 
+        y = torch.linspace(-self.size[1] / 2., self.size[1] / 2., self.number_of_meshes[1], device = self.device)
+        X, Y = torch.meshgrid(x, y, indexing = 'ij')
+        self.X = X.unsqueeze(-1)
+        self.Y = Y.unsqueeze(-1)
+
+
+    def save_heights(self, filename = 'heights.pt'):
+        """
+        Function to save heights to a file.
+
+        Parameters
+        ----------
+        filename          : str
+                            Filename.
+        """
+        save_torch_tensor(filename, self.heights.detach().clone())
+
+
+    def save_heights_as_PLY(self, filename = 'mesh.ply'):
+        """
+        Function to save mesh to a PLY file.
+
+        Parameters
+        ----------
+        filename          : str
+                            Filename.
+        """
+        triangles = self.get_triangles()
+        write_PLY(triangles, filename)
+
+
+    def get_squares(self):
+        """
+        Internal function to initiate squares over a plane.
+
+        Returns
+        -------
+        squares     : torch.tensor
+                      Squares over a plane.
+                      Expected size is [m x n x 3].
+        """
+        squares = torch.cat((
+                             self.X,
+                             self.Y,
+                             self.heights
+                            ), dim = -1)
+        return squares
+
+
+    def get_triangles(self):
+        """
+        Internal function to get triangles.
+        """ 
+        squares = self.get_squares()
+        triangles = torch.zeros(2, self.number_of_meshes[0], self.number_of_meshes[1], 3, 3, device = self.device)
+        for i in range(0, self.number_of_meshes[0] - 1):
+            for j in range(0, self.number_of_meshes[1] - 1):
+                first_triangle = torch.cat((
+                                            squares[i + 1, j].unsqueeze(0),
+                                            squares[i + 1, j + 1].unsqueeze(0),
+                                            squares[i, j + 1].unsqueeze(0),
+                                           ), dim = 0)
+                second_triangle = torch.cat((
+                                             squares[i + 1, j].unsqueeze(0),
+                                             squares[i, j + 1].unsqueeze(0),
+                                             squares[i, j].unsqueeze(0),
+                                            ), dim = 0)
+                triangles[0, i, j], _, _, _ = rotate_points(first_triangle, angles = self.angles)
+                triangles[1, i, j], _, _, _ = rotate_points(second_triangle, angles = self.angles)
+        triangles = triangles.view(-1, 3, 3) + self.offset
+        return triangles 
+
+
+    def mirror(self, rays):
+        """
+        Function to bounce light rays off the meshes.
+
+        Parameters
+        ----------
+        rays              : torch.tensor
+                            Rays to be bounced.
+                            Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].
+
+        Returns
+        -------
+        reflected_rays    : torch.tensor
+                            Reflected rays.
+                            Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].
+        reflected_normals : torch.tensor
+                            Reflected normals.
+                            Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].
+
+        """
+        if len(rays.shape) == 2:
+            rays = rays.unsqueeze(0)
+        triangles = self.get_triangles()
+        reflected_rays = torch.empty((0, 2, 3), requires_grad = True, device = self.device)
+        reflected_normals = torch.empty((0, 2, 3), requires_grad = True, device = self.device)
+        for triangle in triangles:
+            _, _, intersecting_rays, intersecting_normals, check = intersect_w_triangle(
+                                                                                        rays,
+                                                                                        triangle
+                                                                                       ) 
+            triangle_reflected_rays = reflect(intersecting_rays, intersecting_normals)
+            if triangle_reflected_rays.shape[0] > 0:
+                reflected_rays = torch.cat((
+                                            reflected_rays,
+                                            triangle_reflected_rays
+                                          ))
+                reflected_normals = torch.cat((
+                                               reflected_normals,
+                                               intersecting_normals
+                                              ))
+        return reflected_rays, reflected_normals
+
+
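A minimal usage sketch (assumed import path); see the unit tests referenced in `init_heights` for end-to-end optimization examples:
import torch
+from odak.learn.raytracing import planar_mesh
+
+mesh = planar_mesh(
+                   size = torch.tensor([1., 1.]),
+                   number_of_meshes = torch.tensor([10, 10]),
+                   offset = torch.tensor([0., 0., 10.])
+                  )
+triangles = mesh.get_triangles() # [200 x 3 x 3], two triangles per square
+# mesh.heights is differentiable; mesh.mirror(rays) bounces rays off the surface.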
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(size=[1.0, 1.0], number_of_meshes=[10, 10], angles=torch.tensor([0.0, 0.0, 0.0]), offset=torch.tensor([0.0, 0.0, 0.0]), device=torch.device('cpu'), heights=None) + +

+ + +
+ +

Definition to generate a plane with meshes.

+ + +

Parameters:

+
    +
  • + number_of_meshes + – +
    +
                Number of squares over plane.
    +            There are two triangles at each square.
    +
    +
    +
  • +
  • + size + – +
    +
                Size of the plane.
    +
    +
    +
  • +
  • + angles + – +
    +
                Rotation angles in degrees.
    +
    +
    +
  • +
  • + offset + – +
    +
                Offset along XYZ axes.
    +            Expected dimension is [1 x 3] or offset for each triangle [m x 3].
    +            m here refers to `2 * number_of_meshes[0]` times  `number_of_meshes[1]`.
    +
    +
    +
  • +
  • + device + – +
    +
                Computational resource to be used (e.g., cpu, cuda).
    +
    +
    +
  • +
  • + heights + – +
    +
                Load surface heights from a tensor.
    +
    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/mesh.py +
def __init__(
+             self,
+             size = [1., 1.],
+             number_of_meshes = [10, 10],
+             angles = torch.tensor([0., 0., 0.]),
+             offset = torch.tensor([0., 0., 0.]),
+             device = torch.device('cpu'),
+             heights = None
+            ):
+    """
+    Definition to generate a plane with meshes.
+
+
+    Parameters
+    -----------
+    number_of_meshes  : torch.tensor
+                        Number of squares over plane.
+                        There are two triangles at each square.
+    size              : torch.tensor
+                        Size of the plane.
+    angles            : torch.tensor
+                        Rotation angles in degrees.
+    offset            : torch.tensor
+                        Offset along XYZ axes.
+                        Expected dimension is [1 x 3] or offset for each triangle [m x 3].
+                        m here refers to `2 * number_of_meshes[0]` times  `number_of_meshes[1]`.
+    device            : torch.device
+                        Computational resource to be used (e.g., cpu, cuda).
+    heights           : torch.tensor
+                        Load surface heights from a tensor.
+    """
+    self.device = device
+    self.angles = angles.to(self.device)
+    self.offset = offset.to(self.device)
+    self.size = size.to(self.device)
+    self.number_of_meshes = number_of_meshes.to(self.device)
+    self.init_heights(heights)
+
+
+
+ +
+ +
+ + +

+ get_squares() + +

+ + +
+ +

Internal function to initiate squares over a plane.

+ + +

Returns:

+
    +
  • +squares ( tensor +) – +
    +

    Squares over a plane. +Expected size is [m x n x 3].

    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/mesh.py +
def get_squares(self):
+    """
+    Internal function to initiate squares over a plane.
+
+    Returns
+    -------
+    squares     : torch.tensor
+                  Squares over a plane.
+                  Expected size is [m x n x 3].
+    """
+    squares = torch.cat((
+                         self.X,
+                         self.Y,
+                         self.heights
+                        ), dim = -1)
+    return squares
+
+
+
+ +
+ +
+ + +

+ get_triangles() + +

+ + +
+ +

Internal function to get triangles.

+ +
+ Source code in odak/learn/raytracing/mesh.py +
def get_triangles(self):
+    """
+    Internal function to get triangles.
+    """ 
+    squares = self.get_squares()
+    triangles = torch.zeros(2, self.number_of_meshes[0], self.number_of_meshes[1], 3, 3, device = self.device)
+    for i in range(0, self.number_of_meshes[0] - 1):
+        for j in range(0, self.number_of_meshes[1] - 1):
+            first_triangle = torch.cat((
+                                        squares[i + 1, j].unsqueeze(0),
+                                        squares[i + 1, j + 1].unsqueeze(0),
+                                        squares[i, j + 1].unsqueeze(0),
+                                       ), dim = 0)
+            second_triangle = torch.cat((
+                                         squares[i + 1, j].unsqueeze(0),
+                                         squares[i, j + 1].unsqueeze(0),
+                                         squares[i, j].unsqueeze(0),
+                                        ), dim = 0)
+            triangles[0, i, j], _, _, _ = rotate_points(first_triangle, angles = self.angles)
+            triangles[1, i, j], _, _, _ = rotate_points(second_triangle, angles = self.angles)
+    triangles = triangles.view(-1, 3, 3) + self.offset
+    return triangles 
+
+
+
+ +
+ +
+ + +

+ init_heights(heights=None) + +

+ + +
+ +

Internal function to initialize a height map. +Note that self.heights is a differentiable variable, and can be optimized or learned. +See unit test test/test_learn_ray_detector.py or test/test_learn_ray_mesh.py as examples.

+ +
+ Source code in odak/learn/raytracing/mesh.py +
def init_heights(self, heights = None):
+    """
+    Internal function to initialize a height map.
+    Note that self.heights is a differentiable variable, and can be optimized or learned.
+    See unit test `test/test_learn_ray_detector.py` or `test/test_learn_ray_mesh.py` as examples.
+    """
+    if not isinstance(heights, type(None)):
+        self.heights = heights.to(self.device)
+        self.heights.requires_grad = True
+    else:
+        self.heights = torch.zeros(
+                                   (self.number_of_meshes[0], self.number_of_meshes[1], 1),
+                                   requires_grad = True,
+                                   device = self.device,
+                                  )
+    x = torch.linspace(-self.size[0] / 2., self.size[0] / 2., self.number_of_meshes[0], device = self.device) 
+    y = torch.linspace(-self.size[1] / 2., self.size[1] / 2., self.number_of_meshes[1], device = self.device)
+    X, Y = torch.meshgrid(x, y, indexing = 'ij')
+    self.X = X.unsqueeze(-1)
+    self.Y = Y.unsqueeze(-1)
+
+
+
+ +
+ +
+ + +

+ mirror(rays) + +

+ + +
+ +

Function to bounce light rays off the meshes.

+ + +

Parameters:

+
    +
  • + rays + – +
    +
                Rays to be bounced.
    +            Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +reflected_rays ( tensor +) – +
    +

    Reflected rays. +Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].

    +
    +
  • +
  • +reflected_normals ( tensor +) – +
    +

    Reflected normals. +Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].

    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/mesh.py +
def mirror(self, rays):
+    """
+    Function to bounce light rays off the meshes.
+
+    Parameters
+    ----------
+    rays              : torch.tensor
+                        Rays to be bounced.
+                        Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].
+
+    Returns
+    -------
+    reflected_rays    : torch.tensor
+                        Reflected rays.
+                        Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].
+    reflected_normals : torch.tensor
+                        Reflected normals.
+                        Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].
+
+    """
+    if len(rays.shape) == 2:
+        rays = rays.unsqueeze(0)
+    triangles = self.get_triangles()
+    reflected_rays = torch.empty((0, 2, 3), requires_grad = True, device = self.device)
+    reflected_normals = torch.empty((0, 2, 3), requires_grad = True, device = self.device)
+    for triangle in triangles:
+        _, _, intersecting_rays, intersecting_normals, check = intersect_w_triangle(
+                                                                                    rays,
+                                                                                    triangle
+                                                                                   ) 
+        triangle_reflected_rays = reflect(intersecting_rays, intersecting_normals)
+        if triangle_reflected_rays.shape[0] > 0:
+            reflected_rays = torch.cat((
+                                        reflected_rays,
+                                        triangle_reflected_rays
+                                      ))
+            reflected_normals = torch.cat((
+                                           reflected_normals,
+                                           intersecting_normals
+                                          ))
+    return reflected_rays, reflected_normals
+
+
+
+ +
+ +
+ + +

+ save_heights(filename='heights.pt') + +

+ + +
+ +

Function to save heights to a file.

+ + +

Parameters:

+
    +
  • + filename + – +
    +
                Filename.
    +
    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/mesh.py +
def save_heights(self, filename = 'heights.pt'):
+    """
+    Function to save heights to a file.
+
+    Parameters
+    ----------
+    filename          : str
+                        Filename.
+    """
+    save_torch_tensor(filename, self.heights.detach().clone())
+
+
+
+ +
+ +
+ + +

+ save_heights_as_PLY(filename='mesh.ply') + +

+ + +
+ +

Function to save mesh to a PLY file.

+ + +

Parameters:

+
    +
  • + filename + – +
    +
                Filename.
    +
    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/mesh.py +
def save_heights_as_PLY(self, filename = 'mesh.ply'):
+    """
+    Function to save mesh to a PLY file.
+
+    Parameters
+    ----------
+    filename          : str
+                        Filename.
+    """
+    triangles = self.get_triangles()
+    write_PLY(triangles, filename)
+
+
+
+ +
+ + + +
+ +
+ +
+ + +
+ + +

+ center_of_triangle(triangle) + +

+ + +
+ +

Definition to calculate center of a triangle.

+ + +

Parameters:

+
    +
  • + triangle + – +
    +
            An array that contains three points defining a triangle (Mx3). 
    +        It can also parallel process many triangles (NxMx3).
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +centers ( tensor +) – +
    +

    Triangle centers.

    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/primitives.py +
def center_of_triangle(triangle):
+    """
+    Definition to calculate center of a triangle.
+
+    Parameters
+    ----------
+    triangle      : torch.tensor
+                    An array that contains three points defining a triangle (Mx3). 
+                    It can also parallel process many triangles (NxMx3).
+
+    Returns
+    -------
+    centers       : torch.tensor
+                    Triangle centers.
+    """
+    if len(triangle.shape) == 2:
+        triangle = triangle.view((1, 3, 3))
+    center = torch.mean(triangle, axis=1)
+    return center
+
+
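A short usage sketch (assumed import path):
import torch
+from odak.learn.raytracing import center_of_triangle
+
+triangle = torch.tensor([
+                         [-1., 0., 0.],
+                         [ 1., 0., 0.],
+                         [ 0., 1., 0.]
+                        ])
+centers = center_of_triangle(triangle) # tensor([[0.0000, 0.3333, 0.0000]])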
+
+ +
+ +
+ + +

+ create_ray(xyz, abg, direction=False) + +

+ + +
+ +

Definition to create a ray.

+ + +

Parameters:

+
    +
  • + xyz + – +
    +
           List that contains X,Y and Z start locations of a ray.
    +       Size could be [1 x 3], [3], [m x 3].
    +
    +
    +
  • +
  • + abg + – +
    +
           List that contains angles in degrees with respect to the X,Y and Z axes.
    +       Size could be [1 x 3], [3], [m x 3].
    +
    +
    +
  • +
  • + direction + – +
    +
           If set to True, cosines of `abg` are not calculated.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +ray ( tensor +) – +
    +

    Array that contains starting points and cosines of a created ray. +Size will be either [1 x 2 x 3] or [m x 2 x 3].

    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/ray.py +
def create_ray(xyz, abg, direction = False):
+    """
+    Definition to create a ray.
+
+    Parameters
+    ----------
+    xyz          : torch.tensor
+                   List that contains X,Y and Z start locations of a ray.
+                   Size could be [1 x 3], [3], [m x 3].
+    abg          : torch.tensor
+                   List that contains angles in degrees with respect to the X,Y and Z axes.
+                   Size could be [1 x 3], [3], [m x 3].
+    direction    : bool
+                   If set to True, cosines of `abg` are not calculated.
+
+    Returns
+    ----------
+    ray          : torch.tensor
+                   Array that contains starting points and cosines of a created ray.
+                   Size will be either [1 x 2 x 3] or [m x 2 x 3].
+    """
+    points = xyz
+    angles = abg
+    if len(xyz) == 1:
+        points = xyz.unsqueeze(0)
+    if len(abg) == 1:
+        angles = abg.unsqueeze(0)
+    ray = torch.zeros(points.shape[0], 2, 3, device = points.device)
+    ray[:, 0] = points
+    if direction:
+        ray[:, 1] = abg
+    else:
+        ray[:, 1] = torch.cos(torch.deg2rad(abg))
+    return ray
+
+
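A short usage sketch (assumed import path). The angles are direction angles in degrees, so their cosines form the ray direction:
import torch
+from odak.learn.raytracing import create_ray
+
+xyz = torch.tensor([
+                    [0., 0., 0.],
+                    [0., 0., 0.]
+                   ])
+abg = torch.tensor([
+                    [90., 90.,  0.], # direction cosines (0, 0, 1), i.e. along +z
+                    [60., 90., 30.]  # direction cosines (0.5, 0, ~0.87)
+                   ])
+rays = create_ray(xyz, abg) # [2 x 2 x 3]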
+
+ +
+ +
+ + +

+ create_ray_from_all_pairs(x0y0z0, x1y1z1) + +

+ + +
+ +

Creates rays from all possible pairs of points in x0y0z0 and x1y1z1.

+ + +

Parameters:

+
    +
  • + x0y0z0 + – +
    +
           Tensor that contains X, Y, and Z start locations of rays.
    +       Size should be [m x 3].
    +
    +
    +
  • +
  • + x1y1z1 + – +
    +
           Tensor that contains X, Y, and Z end locations of rays.
    +       Size should be [n x 3].
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +rays ( tensor +) – +
    +

    Array that contains starting points and cosines of a created ray(s). Size of [n*m x 2 x 3]

    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/ray.py +
def create_ray_from_all_pairs(x0y0z0, x1y1z1):
+    """
+    Creates rays from all possible pairs of points in x0y0z0 and x1y1z1.
+
+    Parameters
+    ----------
+    x0y0z0       : torch.tensor
+                   Tensor that contains X, Y, and Z start locations of rays.
+                   Size should be [m x 3].
+    x1y1z1       : torch.tensor
+                   Tensor that contains X, Y, and Z end locations of rays.
+                   Size should be [n x 3].
+
+    Returns
+    ----------
+    rays         : torch.tensor
+                   Array that contains starting points and cosines of a created ray(s). Size of [n*m x 2 x 3]
+    """
+
+    if len(x0y0z0.shape) == 1:
+        x0y0z0 = x0y0z0.unsqueeze(0)
+    if len(x1y1z1.shape) == 1:
+        x1y1z1 = x1y1z1.unsqueeze(0)
+
+    m, n = x0y0z0.shape[0], x1y1z1.shape[0]
+    start_points = x0y0z0.unsqueeze(1).expand(-1, n, -1).reshape(-1, 3)
+    end_points = x1y1z1.unsqueeze(0).expand(m, -1, -1).reshape(-1, 3)
+
+    directions = end_points - start_points
+    norms = torch.norm(directions, p=2, dim=1, keepdim=True)
+    norms[norms == 0] = float('nan')
+
+    normalized_directions = directions / norms
+
+    rays = torch.zeros(m * n, 2, 3, device=x0y0z0.device)
+    rays[:, 0, :] = start_points
+    rays[:, 1, :] = normalized_directions
+
+    return rays
+
+
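A short usage sketch (assumed import path), connecting two start points to three end points:
import torch
+from odak.learn.raytracing import create_ray_from_all_pairs
+
+starts = torch.tensor([
+                       [0., 0., 0.],
+                       [1., 0., 0.]
+                      ])
+ends = torch.tensor([
+                     [0., 0., 10.],
+                     [1., 1., 10.],
+                     [2., 0., 10.]
+                    ])
+rays = create_ray_from_all_pairs(starts, ends) # [6 x 2 x 3]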
+
+ +
+ +
+ + +

+ create_ray_from_grid_w_luminous_angle(center, size, no, tilt, num_ray_per_light, angle_limit) + +

+ + +
+ +

Generate a 2D array of lights, each emitting rays within a specified solid angle and tilt.

+ + +
+ Parameters: +

center : torch.tensor + The center point of the light array, shape [3]. +size : list[int] + The size of the light array [height, width]. +no : list[int] + The number of lights in the array [number of lights in height, number of lights in width]. +tilt : torch.tensor + The tilt angles in degrees along x, y, z axes for the rays, shape [3]. +angle_limit : float + The maximum angle in degrees from the initial direction vector within which to emit rays. +num_ray_per_light : int + The number of rays each light should emit.

+
+ +
+ Returns: +

rays : torch.tensor + Array that contains starting points and cosines of a created ray(s). Size of [n x 2 x 3]

+
+
+ Source code in odak/learn/raytracing/ray.py +
def create_ray_from_grid_w_luminous_angle(center, size, no, tilt, num_ray_per_light, angle_limit):
+    """
+    Generate a 2D array of lights, each emitting rays within a specified solid angle and tilt.
+
+    Parameters:
+    ----------
+    center              : torch.tensor
+                          The center point of the light array, shape [3].
+    size                : list[int]
+                          The size of the light array [height, width]
+    no                  : list[int]
+                          The number of lights in the array [number of lights in height, number of lights in width].
+    tilt                : torch.tensor
+                          The tilt angles in degrees along x, y, z axes for the rays, shape [3].
+    angle_limit         : float
+                          The maximum angle in degrees from the initial direction vector within which to emit rays.
+    num_ray_per_light   : int
+                          The number of rays each light should emit.
+
+    Returns:
+    ----------
+    rays : torch.tensor
+           Array that contains starting points and cosines of a created ray(s). Size of [n x 2 x 3]
+    """
+
+    samples = torch.zeros((no[0], no[1], 3))
+
+    x = torch.linspace(-size[0] / 2., size[0] / 2., no[0])
+    y = torch.linspace(-size[1] / 2., size[1] / 2., no[1])
+    X, Y = torch.meshgrid(x, y, indexing='ij')
+
+    samples[:, :, 0] = X.detach().clone()
+    samples[:, :, 1] = Y.detach().clone()
+    samples = samples.reshape((no[0]*no[1], 3))
+
+    samples, *_ = rotate_points(samples, angles=tilt)
+
+    samples = samples + center
+    angle_limit = torch.as_tensor(angle_limit)
+    cos_alpha = torch.cos(angle_limit * torch.pi / 180)
+    tilt = tilt * torch.pi / 180
+
+    theta = torch.acos(1 - 2 * torch.rand(num_ray_per_light*samples.size(0)) * (1-cos_alpha))
+    phi = 2 * torch.pi * torch.rand(num_ray_per_light*samples.size(0))  
+
+    directions = torch.stack([
+        torch.sin(theta) * torch.cos(phi),  
+        torch.sin(theta) * torch.sin(phi),  
+        torch.cos(theta)                    
+    ], dim=1)
+
+    c, s = torch.cos(tilt), torch.sin(tilt)
+
+    Rx = torch.tensor([
+        [1, 0, 0],
+        [0, c[0], -s[0]],
+        [0, s[0], c[0]]
+    ])
+
+    Ry = torch.tensor([
+        [c[1], 0, s[1]],
+        [0, 1, 0],
+        [-s[1], 0, c[1]]
+    ])
+
+    Rz = torch.tensor([
+        [c[2], -s[2], 0],
+        [s[2], c[2], 0],
+        [0, 0, 1]
+    ])
+
+    origins = samples.repeat(num_ray_per_light, 1)
+
+    directions = torch.matmul(directions, (Rz@Ry@Rx).T)
+
+
+    rays = torch.zeros(num_ray_per_light*samples.size(0), 2, 3)
+    rays[:, 0, :] = origins
+    rays[:, 1, :] = directions
+
+    return rays
+
+
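A short usage sketch (assumed import path): a 5 x 5 grid of lights, each emitting twenty random rays within a 15-degree cone:
import torch
+from odak.learn.raytracing import create_ray_from_grid_w_luminous_angle
+
+rays = create_ray_from_grid_w_luminous_angle(
+                                             center = torch.tensor([0., 0., 0.]),
+                                             size = [1., 1.],
+                                             no = [5, 5],
+                                             tilt = torch.tensor([0., 0., 0.]),
+                                             num_ray_per_light = 20,
+                                             angle_limit = 15.
+                                            ) # [500 x 2 x 3]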
+
+ +
+ +
+ + +

+ create_ray_from_point_w_luminous_angle(origin, num_ray, tilt, angle_limit) + +

+ + +
+ +

Generate rays from a point, tilted by specific angles along x, y, z axes, within a specified solid angle.

+ + +
+ Parameters: +

origin : torch.tensor + The origin point of the rays, shape [3]. +num_ray : int + The total number of rays to generate. +tilt : torch.tensor + The tilt angles in degrees along x, y, z axes, shape [3]. +angle_limit : float + The maximum angle in degrees from the initial direction vector within which to emit rays.

+
+ +
+ Returns: +

rays : torch.tensor + Array that contains starting points and cosines of a created ray(s). Size of [n x 2 x 3]

+
+
+ Source code in odak/learn/raytracing/ray.py +
def create_ray_from_point_w_luminous_angle(origin, num_ray, tilt, angle_limit):
+    """
+    Generate rays from a point, tilted by specific angles along x, y, z axes, within a specified solid angle.
+
+    Parameters:
+    ----------
+    origin      : torch.tensor
+                  The origin point of the rays, shape [3].
+    num_ray     : int
+                  The total number of rays to generate.
+    tilt        : torch.tensor
+                  The tilt angles in degrees along x, y, z axes, shape [3].
+    angle_limit : float
+                  The maximum angle in degrees from the initial direction vector within which to emit rays.
+
+    Returns:
+    ----------
+    rays : torch.tensor
+           Array that contains starting points and cosines of a created ray(s). Size of [n x 2 x 3]
+    """
+    angle_limit = torch.as_tensor(angle_limit) 
+    cos_alpha = torch.cos(angle_limit * torch.pi / 180)
+    tilt = tilt * torch.pi / 180
+
+    theta = torch.acos(1 - 2 * torch.rand(num_ray) * (1-cos_alpha))
+    phi = 2 * torch.pi * torch.rand(num_ray)  
+
+
+    directions = torch.stack([
+        torch.sin(theta) * torch.cos(phi),  
+        torch.sin(theta) * torch.sin(phi),  
+        torch.cos(theta)                    
+    ], dim=1)
+
+    c, s = torch.cos(tilt), torch.sin(tilt)
+
+    Rx = torch.tensor([
+        [1, 0, 0],
+        [0, c[0], -s[0]],
+        [0, s[0], c[0]]
+    ])
+
+    Ry = torch.tensor([
+        [c[1], 0, s[1]],
+        [0, 1, 0],
+        [-s[1], 0, c[1]]
+    ])
+
+    Rz = torch.tensor([
+        [c[2], -s[2], 0],
+        [s[2], c[2], 0],
+        [0, 0, 1]
+    ])
+
+    origins = origin.repeat(num_ray, 1)
+    directions = torch.matmul(directions, (Rz@Ry@Rx).T)
+
+
+    rays = torch.zeros(num_ray, 2, 3)
+    rays[:, 0, :] = origins
+    rays[:, 1, :] = directions
+
+    return rays
+
+
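A short usage sketch (assumed import path): a thousand random rays from a single point within a 10-degree cone around +z:
import torch
+from odak.learn.raytracing import create_ray_from_point_w_luminous_angle
+
+rays = create_ray_from_point_w_luminous_angle(
+                                              origin = torch.tensor([0., 0., 0.]),
+                                              num_ray = 1000,
+                                              tilt = torch.tensor([0., 0., 0.]),
+                                              angle_limit = 10.
+                                             ) # [1000 x 2 x 3]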
+
+ +
+ +
+ + +

+ create_ray_from_two_points(x0y0z0, x1y1z1) + +

+ + +
+ +

Definition to create a ray from two given points. Note that both inputs must match in shape.

+ + +

Parameters:

+
    +
  • + x0y0z0 + – +
    +
           List that contains X,Y and Z start locations of a ray.
    +       Size could be [1 x 3], [3], [m x 3].
    +
    +
    +
  • +
  • + x1y1z1 + – +
    +
           List that contains X,Y and Z ending locations of a ray or batch of rays.
    +       Size could be [1 x 3], [3], [m x 3].
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +ray ( tensor +) – +
    +

    Array that contains starting points and cosines of a created ray(s).

    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/ray.py +
def create_ray_from_two_points(x0y0z0, x1y1z1):
+    """
+    Definition to create a ray from two given points. Note that both inputs must match in shape.
+
+    Parameters
+    ----------
+    x0y0z0       : torch.tensor
+                   List that contains X,Y and Z start locations of a ray.
+                   Size could be [1 x 3], [3], [m x 3].
+    x1y1z1       : torch.tensor
+                   List that contains X,Y and Z ending locations of a ray or batch of rays.
+                   Size could be [1 x 3], [3], [m x 3].
+
+    Returns
+    ----------
+    ray          : torch.tensor
+                   Array that contains starting points and cosines of a created ray(s).
+    """
+    if len(x0y0z0.shape) == 1:
+        x0y0z0 = x0y0z0.unsqueeze(0)
+    if len(x1y1z1.shape) == 1:
+        x1y1z1 = x1y1z1.unsqueeze(0)
+    xdiff = x1y1z1[:, 0] - x0y0z0[:, 0]
+    ydiff = x1y1z1[:, 1] - x0y0z0[:, 1]
+    zdiff = x1y1z1[:, 2] - x0y0z0[:, 2]
+    s = (xdiff ** 2 + ydiff ** 2 + zdiff ** 2) ** 0.5
+    s[s == 0] = float('nan')
+    cosines = torch.zeros_like(x0y0z0 * x1y1z1)
+    cosines[:, 0] = xdiff / s
+    cosines[:, 1] = ydiff / s
+    cosines[:, 2] = zdiff / s
+    ray = torch.zeros(xdiff.shape[0], 2, 3, device = x0y0z0.device)
+    ray[:, 0] = x0y0z0
+    ray[:, 1] = cosines
+    return ray
+
+
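A short usage sketch (assumed import path):
import torch
+from odak.learn.raytracing import create_ray_from_two_points
+
+start = torch.tensor([[0., 0., 0.]])
+end = torch.tensor([[0., 0., 10.]])
+rays = create_ray_from_two_points(start, end) # [1 x 2 x 3], direction cosines (0, 0, 1)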
+
+ +
+ +
+ + +

+ define_circle(center, radius, angles) + +

+ + +
+ +

Definition to describe a circle in a single variable packed form.

+ + +

Parameters:

+
    +
  • + center + – +
    +
      Center of a circle to be defined in 3D space.
    +
    +
    +
  • +
  • + radius + – +
    +
      Radius of a circle to be defined.
    +
    +
    +
  • +
  • + angles + – +
    +
      Angular tilt of a circle represented by rotations about x, y, and z axes.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +circle ( list +) – +
    +

    Single variable packed form.

    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/primitives.py +
def define_circle(center, radius, angles):
+    """
+    Definition to describe a circle in a single variable packed form.
+
+    Parameters
+    ----------
+    center  : torch.Tensor
+              Center of a circle to be defined in 3D space.
+    radius  : float
+              Radius of a circle to be defined.
+    angles  : torch.Tensor
+              Angular tilt of a circle represented by rotations about x, y, and z axes.
+
+    Returns
+    ----------
+    circle  : list
+              Single variable packed form.
+    """
+    points = define_plane(center, angles=angles)
+    circle = [
+        points,
+        center,
+        torch.tensor([radius])
+    ]
+    return circle
+
+
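A short usage sketch (assumed import path); the returned list packs the plane points, the center and the radius:
import torch
+from odak.learn.raytracing import define_circle
+
+circle = define_circle(
+                       center = torch.tensor([0., 0., 10.]),
+                       radius = 5.,
+                       angles = torch.tensor([0., 0., 0.])
+                      )
+points, center, radius = circle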
+
+ +
+ +
+ + +

+ define_plane(point, angles=torch.tensor([0.0, 0.0, 0.0])) + +

+ + +
+ +

Definition to generate a plane from a center point and rotation angles.

+ + +

Parameters:

+
    +
  • + point + – +
    +
           A point that is at the center of a plane.
    +
    +
    +
  • +
  • + angles + – +
    +
           Rotation angles in degrees.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +plane ( tensor +) – +
    +

    Points defining plane.

    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/primitives.py +
def define_plane(point, angles = torch.tensor([0., 0., 0.])):
+    """ 
+    Definition to generate a plane from a center point and rotation angles.
+
+    Parameters
+    ----------
+    point        : torch.tensor
+                   A point that is at the center of a plane.
+    angles       : torch.tensor
+                   Rotation angles in degrees.
+
+    Returns
+    ----------
+    plane        : torch.tensor
+                   Points defining plane.
+    """
+    plane = torch.tensor([
+                          [10., 10., 0.],
+                          [0., 10., 0.],
+                          [0.,  0., 0.]
+                         ], device = point.device)
+    for i in range(0, plane.shape[0]):
+        plane[i], _, _, _ = rotate_points(plane[i], angles = angles.to(point.device))
+        plane[i] = plane[i] + point
+    return plane
+
+
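A short usage sketch (assumed import path): a plane centered at z = 10 and tilted 45 degrees about the y axis:
import torch
+from odak.learn.raytracing import define_plane
+
+plane = define_plane(
+                     point = torch.tensor([0., 0., 10.]),
+                     angles = torch.tensor([0., 45., 0.])
+                    ) # [3 x 3], three points defining the plane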
+
+ +
+ +
+ + +

+ define_plane_mesh(number_of_meshes=[10, 10], size=[1.0, 1.0], angles=torch.tensor([0.0, 0.0, 0.0]), offset=torch.tensor([[0.0, 0.0, 0.0]])) + +

+ + +
+ +

Definition to generate a plane with meshes.

+ + +

Parameters:

+
    +
  • + number_of_meshes + – +
    +
                Number of squares over plane.
    +            There are two triangles at each square.
    +
    +
    +
  • +
  • + size + – +
    +
                Size of the plane.
    +
    +
    +
  • +
  • + angles + – +
    +
                Rotation angles in degrees.
    +
    +
    +
  • +
  • + offset + – +
    +
                Offset along XYZ axes.
    +            Expected dimension is [1 x 3] or offset for each triangle [m x 3].
    +            m here refers to `2 * number_of_meshes[0]` times  `number_of_meshes[1]`.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +triangles ( tensor +) – +
    +

    Triangles [m x 3 x 3], where m is 2 * number_of_meshes[0] times number_of_meshes[1].

    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/primitives.py +
33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
+72
+73
+74
+75
+76
+77
+78
+79
def define_plane_mesh(
+                      number_of_meshes = [10, 10], 
+                      size = [1., 1.], 
+                      angles = torch.tensor([0., 0., 0.]), 
+                      offset = torch.tensor([[0., 0., 0.]])
+                     ):
+    """
+    Definition to generate a plane with meshes.
+
+
+    Parameters
+    -----------
+    number_of_meshes  : torch.tensor
+                        Number of squares over plane.
+                        There are two triangles at each square.
+    size              : list
+                        Size of the plane.
+    angles            : torch.tensor
+                        Rotation angles in degrees.
+    offset            : torch.tensor
+                        Offset along XYZ axes.
+                        Expected dimension is [1 x 3] or offset for each triangle [m x 3].
+                        m here refers to `2 * number_of_meshes[0]` times  `number_of_meshes[1]`. 
+
+    Returns
+    -------
+    triangles         : torch.tensor
+                        Triangles [m x 3 x 3], where m is `2 * number_of_meshes[0]` times  `number_of_meshes[1]`.
+    """
+    triangles = torch.zeros(2, number_of_meshes[0], number_of_meshes[1], 3, 3)
+    step = [size[0] / number_of_meshes[0], size[1] / number_of_meshes[1]]
+    for i in range(0, number_of_meshes[0] - 1):
+        for j in range(0, number_of_meshes[1] - 1):
+            first_triangle = torch.tensor([
+                                           [       -size[0] / 2. + step[0] * i,       -size[1] / 2. + step[0] * j, 0.],
+                                           [ -size[0] / 2. + step[0] * (i + 1),       -size[1] / 2. + step[0] * j, 0.],
+                                           [       -size[0] / 2. + step[0] * i, -size[1] / 2. + step[0] * (j + 1), 0.]
+                                          ])
+            second_triangle = torch.tensor([
+                                            [ -size[0] / 2. + step[0] * (i + 1), -size[1] / 2. + step[0] * (j + 1), 0.],
+                                            [ -size[0] / 2. + step[0] * (i + 1),       -size[1] / 2. + step[0] * j, 0.],
+                                            [       -size[0] / 2. + step[0] * i, -size[1] / 2. + step[0] * (j + 1), 0.]
+                                           ])
+            triangles[0, i, j], _, _, _ = rotate_points(first_triangle, angles = angles)
+            triangles[1, i, j], _, _, _ = rotate_points(second_triangle, angles = angles)
+    triangles = triangles.view(-1, 3, 3) + offset
+    return triangles
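
A minimal usage sketch (not part of the library source), assuming define_plane_mesh is importable from odak.learn.raytracing as the source path above suggests:

import torch
from odak.learn.raytracing import define_plane_mesh

triangles = define_plane_mesh(
                              number_of_meshes = [20, 20],
                              size = [10., 10.],
                              angles = torch.tensor([0., 0., 0.]),
                              offset = torch.tensor([[0., 0., 50.]])
                             )
# triangles is [2 * 20 * 20, 3, 3]; each entry is one triangle of the meshed plane.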

define_sphere(center=torch.tensor([[0.0, 0.0, 0.0]]), radius=torch.tensor([1.0]))

Definition to define a sphere.

Parameters:

  • center – Center of the sphere(s) along XYZ axes. Expected size is [3], [1, 3] or [m, 3].
  • radius – Radius of that sphere(s). Expected size is [1], [1, 1], [m] or [m, 1].

Returns:

  • parameters (tensor) – Parameters of defined sphere(s), packing the center and the radius. Expected size is [1, 4] or [m, 4].

Source code in odak/learn/raytracing/primitives.py
def define_sphere(center = torch.tensor([[0., 0., 0.]]), radius = torch.tensor([1.])):
+    """
+    Definition to define a sphere.
+
+    Parameters
+    ----------
+    center      : torch.tensor
+                  Center of the sphere(s) along XYZ axes.
+                  Expected size is [3], [1, 3] or [m, 3].
+    radius      : torch.tensor
+                  Radius of that sphere(s).
+                  Expected size is [1], [1, 1], [m] or [m, 1].
+
+    Returns
+    -------
+    parameters  : torch.tensor
+                  Parameters of defined sphere(s).
+                  Expected size is [1, 4] or [m x 4].
+    """
+    if len(radius.shape) == 1:
+        radius = radius.unsqueeze(0)
+    if len(center.shape) == 1:
+        center = center.unsqueeze(0)
+    parameters = torch.cat((center, radius), dim = 1)
+    return parameters
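
A minimal usage sketch (not part of the library source), assuming define_sphere is importable from odak.learn.raytracing as the source path above suggests:

import torch
from odak.learn.raytracing import define_sphere

sphere = define_sphere(
                       center = torch.tensor([[0., 0., 10.]]),
                       radius = torch.tensor([[3.]])
                      )
# sphere is [1 x 4]: X, Y, Z of the center followed by the radius.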

distance_between_two_points(point1, point2)

Definition to calculate distance between two given points.

Parameters:

  • point1 – First point in X,Y,Z.
  • point2 – Second point in X,Y,Z.

Returns:

  • distance (Tensor) – Distance in between given two points.

Source code in odak/learn/tools/vector.py
def distance_between_two_points(point1, point2):
+    """
+    Definition to calculate distance between two given points.
+
+    Parameters
+    ----------
+    point1      : torch.Tensor
+                  First point in X,Y,Z.
+    point2      : torch.Tensor
+                  Second point in X,Y,Z.
+
+    Returns
+    ----------
+    distance    : torch.Tensor
+                  Distance in between given two points.
+    """
+    point1 = torch.tensor(point1) if not isinstance(point1, torch.Tensor) else point1
+    point2 = torch.tensor(point2) if not isinstance(point2, torch.Tensor) else point2
+
+    if len(point1.shape) == 1 and len(point2.shape) == 1:
+        distance = torch.sqrt(torch.sum((point1 - point2) ** 2))
+    elif len(point1.shape) == 2 or len(point2.shape) == 2:
+        distance = torch.sqrt(torch.sum((point1 - point2) ** 2, dim=-1))
+
+    return distance
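
A minimal usage sketch (not part of the library source), assuming distance_between_two_points is importable from odak.learn.tools as the source path above suggests:

import torch
from odak.learn.tools import distance_between_two_points

point1 = torch.tensor([[0., 0., 0.], [1., 1., 1.]])
point2 = torch.tensor([[0., 0., 10.], [1., 1., 4.]])
distances = distance_between_two_points(point1, point2)   # tensor([10., 3.])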

get_sphere_normal_torch(point, sphere)

Definition to get a normal of a point on a given sphere.

Parameters:

  • point – Point on sphere in X,Y,Z.
  • sphere – Center defined in X,Y,Z and radius.

Returns:

  • normal_vector (tensor) – Normal vector.

Source code in odak/learn/raytracing/boundary.py
def get_sphere_normal_torch(point, sphere):
+    """
+    Definition to get a normal of a point on a given sphere.
+
+    Parameters
+    ----------
+    point         : torch.tensor
+                    Point on sphere in X,Y,Z.
+    sphere        : torch.tensor
+                    Center defined in X,Y,Z and radius.
+
+    Returns
+    ----------
+    normal_vector : torch.tensor
+                    Normal vector.
+    """
+    if len(point.shape) == 1:
+        point = point.reshape((1, 3))
+    normal_vector = create_ray_from_two_points(point, sphere[0:3])
+    return normal_vector

get_triangle_normal(triangle, triangle_center=None)

Definition to calculate surface normal of a triangle.

Parameters:

  • triangle – Set of points in X,Y and Z to define a planar surface (3 x 3). It can also be a list of triangles (m x 3 x 3).
  • triangle_center (tensor, default: None) – Center point of the given triangle. See odak.learn.raytracing.center_of_triangle for more. In many scenarios you can accelerate things by precomputing triangle centers.

Returns:

  • normal (tensor) – Surface normal at the point of intersection.

Source code in odak/learn/raytracing/boundary.py
def get_triangle_normal(triangle, triangle_center=None):
+    """
+    Definition to calculate surface normal of a triangle.
+
+    Parameters
+    ----------
+    triangle        : torch.tensor
+                      Set of points in X,Y and Z to define a planar surface (3,3). It can also be list of triangles (mx3x3).
+    triangle_center : torch.tensor
+                      Center point of the given triangle. See odak.learn.raytracing.center_of_triangle for more. In many scenarios you can accelerate things by precomputing triangle centers.
+
+    Returns
+    ----------
+    normal          : torch.tensor
+                      Surface normal at the point of intersection.
+    """
+    if len(triangle.shape) == 2:
+        triangle = triangle.view((1, 3, 3))
+    normal = torch.zeros((triangle.shape[0], 2, 3)).to(triangle.device)
+    direction = torch.linalg.cross(
+                                   triangle[:, 0] - triangle[:, 1], 
+                                   triangle[:, 2] - triangle[:, 1]
+                                  )
+    if type(triangle_center) == type(None):
+        normal[:, 0] = center_of_triangle(triangle)
+    else:
+        normal[:, 0] = triangle_center
+    normal[:, 1] = direction / torch.sum(direction, axis=1)[0]
+    if normal.shape[0] == 1:
+        normal = normal.view((2, 3))
+    return normal
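
A minimal usage sketch (not part of the library source), assuming get_triangle_normal is importable from odak.learn.raytracing as the source path above suggests:

import torch
from odak.learn.raytracing import get_triangle_normal

triangle = torch.tensor([[
                          [-1., -1., 0.],
                          [ 1., -1., 0.],
                          [ 0.,  1., 0.]
                        ]])
normal = get_triangle_normal(triangle)   # [2 x 3] for a single triangle: center and normal direction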

grid_sample(no=[10, 10], size=[100.0, 100.0], center=[0.0, 0.0, 0.0], angles=[0.0, 0.0, 0.0])

Definition to generate samples over a surface.

Parameters:

  • no – Number of samples.
  • size – Physical size of the surface.
  • center – Center location of the surface.
  • angles – Tilt of the surface.

Returns:

  • samples (tensor) – Samples generated.
  • rotx (tensor) – Rotation matrix at X axis.
  • roty (tensor) – Rotation matrix at Y axis.
  • rotz (tensor) – Rotation matrix at Z axis.

Source code in odak/learn/tools/sample.py
def grid_sample(
+                no = [10, 10],
+                size = [100., 100.], 
+                center = [0., 0., 0.], 
+                angles = [0., 0., 0.]):
+    """
+    Definition to generate samples over a surface.
+
+    Parameters
+    ----------
+    no          : list
+                  Number of samples.
+    size        : list
+                  Physical size of the surface.
+    center      : list
+                  Center location of the surface.
+    angles      : list
+                  Tilt of the surface.
+
+    Returns
+    -------
+    samples     : torch.tensor
+                  Samples generated.
+    rotx        : torch.tensor
+                  Rotation matrix at X axis.
+    roty        : torch.tensor
+                  Rotation matrix at Y axis.
+    rotz        : torch.tensor
+                  Rotation matrix at Z axis.
+    """
+    center = torch.tensor(center)
+    angles = torch.tensor(angles)
+    size = torch.tensor(size)
+    samples = torch.zeros((no[0], no[1], 3))
+    x = torch.linspace(-size[0] / 2., size[0] / 2., no[0])
+    y = torch.linspace(-size[1] / 2., size[1] / 2., no[1])
+    X, Y = torch.meshgrid(x, y, indexing='ij')
+    samples[:, :, 0] = X.detach().clone()
+    samples[:, :, 1] = Y.detach().clone()
+    samples = samples.reshape((samples.shape[0] * samples.shape[1], samples.shape[2]))
+    samples, rotx, roty, rotz = rotate_points(samples, angles = angles, offset = center)
+    return samples, rotx, roty, rotz
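
A minimal usage sketch (not part of the library source), assuming grid_sample is importable from odak.learn.tools as the source path above suggests:

from odak.learn.tools import grid_sample

samples, rotx, roty, rotz = grid_sample(
                                        no = [5, 5],
                                        size = [10., 10.],
                                        center = [0., 0., 100.],
                                        angles = [0., 0., 0.]
                                       )
# samples is [25 x 3]: XYZ locations distributed over the tilted surface.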

intersect_w_circle(ray, circle)

Definition to find intersection point of a ray with a circle. Returns distance as zero if there isn't an intersection.

Parameters:

  • ray – A vector/ray.
  • circle – A list that contains (0) set of points in X,Y and Z to define the plane of a circle, (1) circle center, and (2) circle radius.

Returns:

  • normal (Tensor) – Surface normal at the point of intersection.
  • distance (Tensor) – Distance in between a starting point of a ray and the intersection point with the given circle.

Source code in odak/learn/raytracing/boundary.py
def intersect_w_circle(ray, circle):
+    """
+    Definition to find intersection point of a ray with a circle. 
+    Returns distance as zero if there isn't an intersection.
+
+    Parameters
+    ----------
+    ray          : torch.Tensor
+                   A vector/ray.
+    circle       : list
+                   A list that contains (0) Set of points in X,Y and Z to define plane of a circle, (1) circle center, and (2) circle radius.
+
+    Returns
+    ----------
+    normal       : torch.Tensor
+                   Surface normal at the point of intersection.
+    distance     : torch.Tensor
+                   Distance in between a starting point of a ray and the intersection point with a given triangle.
+    """
+    normal, distance = intersect_w_surface(ray, circle[0])
+
+    if len(normal.shape) == 2:
+        normal = normal.unsqueeze(0)
+
+    distance_to_center = distance_between_two_points(normal[:, 0], circle[1])
+    mask = distance_to_center > circle[2]
+    distance[mask] = 0
+
+    if len(ray.shape) == 2:
+        normal = normal.squeeze(0)
+
+    return normal, distance
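
A minimal end-to-end sketch (not part of the library source), assuming these helpers are importable from odak.learn.raytracing as the source paths above suggest:

import torch
from odak.learn.raytracing import create_ray_from_two_points, define_circle, intersect_w_circle

rays = create_ray_from_two_points(
                                  torch.tensor([[0., 0., 0.], [4., 0., 0.]]),
                                  torch.tensor([[0., 0., 10.], [4., 0., 10.]])
                                 )
circle = define_circle(
                       center = torch.tensor([[0., 0., 10.]]),
                       radius = 2.,
                       angles = torch.tensor([0., 0., 0.])
                      )
normals, distances = intersect_w_circle(rays, circle)
# distances is zero for the second ray, which lands outside the circle of radius two.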

intersect_w_sphere(ray, sphere, learning_rate=0.2, number_of_steps=5000, error_threshold=0.01)

Definition to find the intersection between ray(s) and sphere(s).

Parameters:

  • ray – Input ray(s). Expected size is [1 x 2 x 3] or [m x 2 x 3].
  • sphere – Input sphere. Expected size is [1 x 4].
  • learning_rate – Learning rate used in the optimizer for finding the propagation distances of the rays.
  • number_of_steps – Number of steps used in the optimizer.
  • error_threshold – The error threshold that will help deciding intersection or no intersection.

Returns:

  • intersecting_ray (tensor) – Ray(s) that intersect with the given sphere. Expected size is [n x 2 x 3], where n is the number of intersecting rays.
  • intersecting_normal (tensor) – Normal(s) for the ray(s) intersecting with the given sphere. Expected size is [n x 2 x 3], where n is the number of intersecting rays.

Source code in odak/learn/raytracing/boundary.py
def intersect_w_sphere(ray, sphere, learning_rate = 2e-1, number_of_steps = 5000, error_threshold = 1e-2):
+    """
+    Definition to find the intersection between ray(s) and sphere(s).
+
+    Parameters
+    ----------
+    ray                 : torch.tensor
+                          Input ray(s).
+                          Expected size is [1 x 2 x 3] or [m x 2 x 3].
+    sphere              : torch.tensor
+                          Input sphere.
+                          Expected size is [1 x 4].
+    learning_rate       : float
+                          Learning rate used in the optimizer for finding the propagation distances of the rays.
+    number_of_steps     : int
+                          Number of steps used in the optimizer.
+    error_threshold     : float
+                          The error threshold that will help deciding intersection or no intersection.
+
+    Returns
+    -------
+    intersecting_ray    : torch.tensor
+                          Ray(s) that intersecting with the given sphere.
+                          Expected size is [n x 2 x 3], where n could be any real number.
+    intersecting_normal : torch.tensor
+                          Normal(s) for the ray(s) intersecting with the given sphere
+                          Expected size is [n x 2 x 3], where n could be any real number.
+
+    """
+    if len(ray.shape) == 2:
+        ray = ray.unsqueeze(0)
+    if len(sphere.shape) == 1:
+        sphere = sphere.unsqueeze(0)
+    distance = torch.zeros(ray.shape[0], device = ray.device, requires_grad = True)
+    loss_l2 = torch.nn.MSELoss(reduction = 'sum')
+    optimizer = torch.optim.AdamW([distance], lr = learning_rate)    
+    t = tqdm(range(number_of_steps), leave = False, dynamic_ncols = True)
+    for step in t:
+        optimizer.zero_grad()
+        propagated_ray = propagate_ray(ray, distance)
+        test = torch.abs((propagated_ray[:, 0, 0] - sphere[:, 0]) ** 2 + (propagated_ray[:, 0, 1] - sphere[:, 1]) ** 2 + (propagated_ray[:, 0, 2] - sphere[:, 2]) ** 2 - sphere[:, 3] ** 2)
+        loss = loss_l2(
+                       test,
+                       torch.zeros_like(test)
+                      )
+        loss.backward(retain_graph = True)
+        optimizer.step()
+        t.set_description('Sphere intersection loss: {}'.format(loss.item()))
+    check = test < error_threshold
+    intersecting_ray = propagate_ray(ray[check == True], distance[check == True])
+    intersecting_normal = create_ray_from_two_points(
+                                                     sphere[:, 0:3],
+                                                     intersecting_ray[:, 0]
+                                                    )
+    return intersecting_ray, intersecting_normal, distance, check
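
A minimal usage sketch (not part of the library source), assuming these helpers are importable from odak.learn.raytracing; note that the call runs a small optimization loop internally, so it takes a moment:

import torch
from odak.learn.raytracing import create_ray_from_two_points, define_sphere, intersect_w_sphere

rays = create_ray_from_two_points(
                                  torch.tensor([[0., 0., 0.]]),
                                  torch.tensor([[0., 0., 10.]])
                                 )
sphere = define_sphere(center = torch.tensor([[0., 0., 10.]]), radius = torch.tensor([[3.]]))
intersecting_rays, intersecting_normals, distance, check = intersect_w_sphere(rays, sphere)
# check flags which rays hit the sphere; distance holds the optimized propagation distances.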

intersect_w_surface(ray, points)

Definition to find the intersection point in between a surface and a ray. For more see: http://geomalgorithms.com/a06-_intersect-2.html

Parameters:

  • ray – A vector/ray.
  • points – Set of points in X,Y and Z to define a planar surface.

Returns:

  • normal (tensor) – Surface normal at the point of intersection.
  • distance (float) – Distance in between the starting point of a ray and its intersection with a planar surface.

Source code in odak/learn/raytracing/boundary.py
def intersect_w_surface(ray, points):
+    """
+    Definition to find intersection point inbetween a surface and a ray. For more see: http://geomalgorithms.com/a06-_intersect-2.html
+
+    Parameters
+    ----------
+    ray          : torch.tensor
+                   A vector/ray.
+    points       : torch.tensor
+                   Set of points in X,Y and Z to define a planar surface.
+
+    Returns
+    ----------
+    normal       : torch.tensor
+                   Surface normal at the point of intersection.
+    distance     : float
+                   Distance in between starting point of a ray with it's intersection with a planar surface.
+    """
+    normal = get_triangle_normal(points)
+    if len(ray.shape) == 2:
+        ray = ray.unsqueeze(0)
+    if len(points.shape) == 2:
+        points = points.unsqueeze(0)
+    if len(normal.shape) == 2:
+        normal = normal.unsqueeze(0)
+    f = normal[:, 0] - ray[:, 0]
+    distance = (torch.mm(normal[:, 1], f.T) / torch.mm(normal[:, 1], ray[:, 1].T)).T
+    new_normal = torch.zeros_like(ray)
+    new_normal[:, 0] = ray[:, 0] + distance * ray[:, 1]
+    new_normal[:, 1] = normal[:, 1]
+    new_normal = torch.nan_to_num(
+                                  new_normal,
+                                  nan = float('nan'),
+                                  posinf = float('nan'),
+                                  neginf = float('nan')
+                                 )
+    distance = torch.nan_to_num(
+                                distance,
+                                nan = float('nan'),
+                                posinf = float('nan'),
+                                neginf = float('nan')
+                               )
+    return new_normal, distance
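
A minimal usage sketch (not part of the library source), assuming these helpers are importable from odak.learn.raytracing as the source paths above suggest:

import torch
from odak.learn.raytracing import create_ray_from_two_points, define_plane, intersect_w_surface

rays = create_ray_from_two_points(
                                  torch.tensor([[0., 0., 0.]]),
                                  torch.tensor([[0., 0., 10.]])
                                 )
plane_points = define_plane(torch.tensor([0., 0., 10.]))
normal, distance = intersect_w_surface(rays, plane_points)
# normal[:, 0] holds the intersection points, normal[:, 1] the surface normal cosines.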

intersect_w_surface_batch(ray, triangle)

Definition to find intersection points between a batch of rays and a batch of planar surfaces (triangles).

Parameters:

  • ray – A vector/ray (2 x 3). It can also be a list of rays (n x 2 x 3).
  • triangle – Set of points in X,Y and Z to define a planar surface. It can also be a list of triangles (m x 3 x 3).

Returns:

  • normal (tensor) – Surface normal at the point of intersection (m x n x 2 x 3).
  • distance (tensor) – Distance in between the starting point of a ray and its intersection with a planar surface (m x n).

Source code in odak/learn/raytracing/boundary.py
def intersect_w_surface_batch(ray, triangle):
+    """
+    Parameters
+    ----------
+    ray          : torch.tensor
+                   A vector/ray (2 x 3). It can also be a list of rays (n x 2 x 3).
+    triangle     : torch.tensor
+                   Set of points in X,Y and Z to define a planar surface. It can also be a list of triangles (m x 3 x 3).
+
+    Returns
+    ----------
+    normal       : torch.tensor
+                   Surface normal at the point of intersection (m x n x 2 x 3).
+    distance     : torch.tensor
+                   Distance in between starting point of a ray with it's intersection with a planar surface (m x n).
+    """
+    normal = get_triangle_normal(triangle)
+    if len(ray.shape) == 2:
+        ray = ray.unsqueeze(0)
+    if len(triangle.shape) == 2:
+        triangle = triangle.unsqueeze(0)
+    if len(normal.shape) == 2:
+        normal = normal.unsqueeze(0)
+
+    f = normal[:, None, 0] - ray[None, :, 0]
+    distance = (torch.bmm(normal[:, None, 1], f.permute(0, 2, 1)).squeeze(1) / torch.mm(normal[:, 1], ray[:, 1].T)).T
+
+    new_normal = torch.zeros((triangle.shape[0], )+ray.shape)
+    new_normal[:, :, 0] = ray[None, :, 0] + (distance[:, :, None] * ray[:, None, 1]).permute(1, 0, 2)
+    new_normal[:, :, 1] = normal[:, None, 1]
+    new_normal = torch.nan_to_num(
+                                  new_normal,
+                                  nan = float('nan'),
+                                  posinf = float('nan'),
+                                  neginf = float('nan')
+                                 )
+    distance = torch.nan_to_num(
+                                distance,
+                                nan = float('nan'),
+                                posinf = float('nan'),
+                                neginf = float('nan')
+                               )
+    return new_normal, distance.T

intersect_w_triangle(ray, triangle)

Definition to find intersection point of a ray with a triangle.

Parameters:

  • ray – A ray [1 x 2 x 3] or a batch of rays [m x 2 x 3].
  • triangle – Set of points in X,Y and Z to define a single triangle [1 x 3 x 3].

Returns:

  • normal (tensor) – Surface normal at the point of intersection with the surface of the triangle. This could also involve surface normals that are not on the triangle. Expected size is [1 x 2 x 3] or [m x 2 x 3] depending on the input.
  • distance (float) – Distance in between a starting point of a ray and the intersection point with a given triangle. Expected size is [1 x 1] or [m x 1] depending on the input.
  • intersecting_ray (tensor) – Rays that intersect with the triangle plane and lie on the triangle. Expected size is [1 x 2 x 3] or [m x 2 x 3] depending on the input.
  • intersecting_normal (tensor) – Normals that intersect with the triangle plane and lie on the triangle. Expected size is [1 x 2 x 3] or [m x 2 x 3] depending on the input.
  • check (tensor) – A boolean for each input ray, testing whether the ray lands on the given triangle. Expected size is [1] or [m].

Source code in odak/learn/raytracing/boundary.py
def intersect_w_triangle(ray, triangle):
+    """
+    Definition to find intersection point of a ray with a triangle. 
+
+    Parameters
+    ----------
+    ray                 : torch.tensor
+                          A ray [1 x 2 x 3] or a batch of ray [m x 2 x 3].
+    triangle            : torch.tensor
+                          Set of points in X,Y and Z to define a single triangle [1 x 3 x 3].
+
+    Returns
+    ----------
+    normal              : torch.tensor
+                          Surface normal at the point of intersection with the surface of triangle.
+                          This could also involve surface normals that are not on the triangle.
+                          Expected size is [1 x 2 x 3] or [m x 2 x 3] depending on the input.
+    distance            : float
+                          Distance in between a starting point of a ray and the intersection point with a given triangle.
+                          Expected size is [1 x 1] or [m x 1] depending on the input.
+    intersecting_ray    : torch.tensor
+                          Rays that intersect with the triangle plane and on the triangle.
+                          Expected size is [1 x 2 x 3] or [m x 2 x 3] depending on the input.
+    intersecting_normal : torch.tensor
+                          Normals that intersect with the triangle plane and on the triangle.
+                          Expected size is [1 x 2 x 3] or [m x 2 x 3] depending on the input.
+    check               : torch.tensor
+                          A list that provides a bool as True or False for each ray used as input.
+                          A test to see is a ray could be on the given triangle.
+                          Expected size is [1] or [m].
+    """
+    if len(triangle.shape) == 2:
+       triangle = triangle.unsqueeze(0)
+    if len(ray.shape) == 2:
+       ray = ray.unsqueeze(0)
+    normal, distance = intersect_w_surface(ray, triangle)
+    check = is_it_on_triangle(normal[:, 0], triangle)
+    intersecting_ray = ray.unsqueeze(0)
+    intersecting_ray = intersecting_ray.repeat(triangle.shape[0], 1, 1, 1)
+    intersecting_ray = intersecting_ray[check == True]
+    intersecting_normal = normal.unsqueeze(0)
+    intersecting_normal = intersecting_normal.repeat(triangle.shape[0], 1, 1, 1)
+    intersecting_normal = intersecting_normal[check ==  True]
+    return normal, distance, intersecting_ray, intersecting_normal, check
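
A minimal usage sketch (not part of the library source), assuming these helpers are importable from odak.learn.raytracing as the source paths above suggest:

import torch
from odak.learn.raytracing import create_ray_from_two_points, intersect_w_triangle

triangle = torch.tensor([[
                          [-5., -5., 10.],
                          [ 5., -5., 10.],
                          [ 0.,  5., 10.]
                        ]])
rays = create_ray_from_two_points(
                                  torch.tensor([[0., 0., 0.], [20., 0., 0.]]),
                                  torch.tensor([[0., 0., 10.], [20., 0., 10.]])
                                 )
normal, distance, intersecting_rays, intersecting_normals, check = intersect_w_triangle(rays, triangle)
# check is True for the first ray and False for the second one, which misses the triangle.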

intersect_w_triangle_batch(ray, triangle)

Definition to find intersection points of rays with triangles. Returns False for each variable if a ray doesn't intersect with the given triangles.

Parameters:

  • ray – Vectors/rays (n x 2 x 3).
  • triangle – Set of points in X,Y and Z to define triangles (m x 3 x 3).

Returns:

  • normal (tensor) – Surface normal at the point of intersection (m x n x 2 x 3).
  • distance (List) – Distance in between the starting point of a ray and its intersection with a planar surface (m x n).
  • intersect_ray (List) – List of intersecting rays (k x 2 x 3) where k <= n.
  • intersect_normal (List) – List of intersecting normals (k x 2 x 3) where k <= n * m.
  • check (tensor) – Boolean tensor (m x n) indicating whether each ray intersects with a triangle or not.

Source code in odak/learn/raytracing/boundary.py
def intersect_w_triangle_batch(ray, triangle):
+    """
+    Definition to find intersection points of rays with triangles. Returns False for each variable if the rays doesn't intersect with given triangles.
+
+    Parameters
+    ----------
+    ray          : torch.tensor
+                   vectors/rays (n x 2 x 3).
+    triangle     : torch.tensor
+                   Set of points in X,Y and Z to define triangles (m x 3 x 3).
+
+    Returns
+    ----------
+    normal          : torch.tensor
+                      Surface normal at the point of intersection (m x n x 2 x 3).
+    distance        : List
+                      Distance in between starting point of a ray with it's intersection with a planar surface (m x n).
+    intersect_ray   : List
+                      List of intersecting rays (k x 2 x 3) where k <= n.
+    intersect_normal: List
+                      List of intersecting normals (k x 2 x 3) where k <= n*m.
+    check           : torch.tensor
+                      Boolean tensor (m x n) indicating whether each ray intersects with a triangle or not.
+    """
+    if len(triangle.shape) == 2:
+       triangle = triangle.unsqueeze(0)
+    if len(ray.shape) == 2:
+       ray = ray.unsqueeze(0)
+
+    normal, distance = intersect_w_surface_batch(ray, triangle)
+
+    check = is_it_on_triangle_batch(normal[:, :, 0], triangle)
+
+    flat_check = check.flatten()
+    flat_normal = normal.view(-1, normal.size(-2), normal.size(-1))
+    flat_ray = ray.repeat(normal.size(0), 1, 1)
+    flat_distance = distance.flatten()
+
+    filtered_normal = torch.masked_select(flat_normal, flat_check.unsqueeze(-1).unsqueeze(-1).repeat(1, 2, 3))
+    filtered_ray = torch.masked_select(flat_ray, flat_check.unsqueeze(-1).unsqueeze(-1).repeat(1, 2, 3))
+    filtered_distnace = torch.masked_select(flat_distance, flat_check)
+
+    check_count = check.sum(dim=1).tolist()
+    split_size_ray_and_normal = [count * 2 * 3 for count in check_count]
+    split_size_distance = [count for count in check_count]
+
+    normal_grouped = torch.split(filtered_normal, split_size_ray_and_normal)
+    ray_grouped = torch.split(filtered_ray, split_size_ray_and_normal)
+    distance_grouped = torch.split(filtered_distnace, split_size_distance)
+
+    intersecting_normal = [g.view(-1, 2, 3) for g in normal_grouped if g.numel() > 0]
+    intersecting_ray = [g.view(-1, 2, 3) for g in ray_grouped if g.numel() > 0]
+    new_distance = [g for g in distance_grouped if g.numel() > 0]
+
+    return normal, new_distance, intersecting_ray, intersecting_normal, check

is_it_on_triangle(point_to_check, triangle)

Definition to check if a given point is inside a triangle. If the given point is inside a defined triangle, this definition returns True. For more details, visit: https://blackpawn.com/texts/pointinpoly/.

Parameters:

  • point_to_check – Point(s) to check. Expected size is [3], [1 x 3] or [m x 3].
  • triangle – Triangle described with three points. Expected size is [3 x 3], [1 x 3 x 3] or [m x 3 x 3].

Returns:

  • result (tensor) – Is it on a triangle? A boolean result per point. Expected size is [1] or [m] depending on the input.

Source code in odak/learn/raytracing/primitives.py
def is_it_on_triangle(point_to_check, triangle):
+    """
+    Definition to check if a given point is inside a triangle. 
+    If the given point is inside a defined triangle, this definition returns True.
+    For more details, visit: [https://blackpawn.com/texts/pointinpoly/](https://blackpawn.com/texts/pointinpoly/).
+
+    Parameters
+    ----------
+    point_to_check  : torch.tensor
+                      Point(s) to check.
+                      Expected size is [3], [1 x 3] or [m x 3].
+    triangle        : torch.tensor
+                      Triangle described with three points.
+                      Expected size is [3 x 3], [1 x 3 x 3] or [m x 3 x3].
+
+    Returns
+    -------
+    result          : torch.tensor
+                      Is it on a triangle? Returns NaN if condition not satisfied.
+                      Expected size is [1] or [m] depending on the input.
+    """
+    if len(point_to_check.shape) == 1:
+        point_to_check = point_to_check.unsqueeze(0)
+    if len(triangle.shape) == 2:
+        triangle = triangle.unsqueeze(0)
+    v0 = triangle[:, 2] - triangle[:, 0]
+    v1 = triangle[:, 1] - triangle[:, 0]
+    v2 = point_to_check - triangle[:, 0]
+    if len(v0.shape) == 1:
+        v0 = v0.unsqueeze(0)
+    if len(v1.shape) == 1:
+        v1 = v1.unsqueeze(0)
+    if len(v2.shape) == 1:
+        v2 = v2.unsqueeze(0)
+    dot00 = torch.mm(v0, v0.T)
+    dot01 = torch.mm(v0, v1.T)
+    dot02 = torch.mm(v0, v2.T) 
+    dot11 = torch.mm(v1, v1.T)
+    dot12 = torch.mm(v1, v2.T)
+    invDenom = 1. / (dot00 * dot11 - dot01 * dot01)
+    u = (dot11 * dot02 - dot01 * dot12) * invDenom
+    v = (dot00 * dot12 - dot01 * dot02) * invDenom
+    result = (u >= 0.) & (v >= 0.) & ((u + v) < 1)
+    return result

is_it_on_triangle_batch(point_to_check, triangle)

Definition to check if given points are inside triangles. If the given points are inside defined triangles, this definition returns True.

Parameters:

  • point_to_check – Points to check (m x n x 3).
  • triangle – Triangles (m x 3 x 3).

Returns:

  • result (tensor) – Boolean tensor (m x n) indicating whether each point lies inside the corresponding triangle.

Source code in odak/learn/raytracing/primitives.py
def is_it_on_triangle_batch(point_to_check, triangle):
+    """
+    Definition to check if given points are inside triangles. If the given points are inside defined triangles, this definition returns True.
+
+    Parameters
+    ----------
+    point_to_check  : torch.tensor
+                      Points to check (m x n x 3).
+    triangle        : torch.tensor 
+                      Triangles (m x 3 x 3).
+
+    Returns
+    ----------
+    result          : torch.tensor (m x n)
+
+    """
+    if len(point_to_check.shape) == 1:
+        point_to_check = point_to_check.unsqueeze(0)
+    if len(triangle.shape) == 2:
+        triangle = triangle.unsqueeze(0)
+    v0 = triangle[:, 2] - triangle[:, 0]
+    v1 = triangle[:, 1] - triangle[:, 0]
+    v2 = point_to_check - triangle[:, None, 0]
+    if len(v0.shape) == 1:
+        v0 = v0.unsqueeze(0)
+    if len(v1.shape) == 1:
+        v1 = v1.unsqueeze(0)
+    if len(v2.shape) == 1:
+        v2 = v2.unsqueeze(0)
+
+    dot00 = torch.bmm(v0.unsqueeze(1), v0.unsqueeze(1).permute(0, 2, 1)).squeeze(1)
+    dot01 = torch.bmm(v0.unsqueeze(1), v1.unsqueeze(1).permute(0, 2, 1)).squeeze(1)
+    dot02 = torch.bmm(v0.unsqueeze(1), v2.permute(0, 2, 1)).squeeze(1)
+    dot11 = torch.bmm(v1.unsqueeze(1), v1.unsqueeze(1).permute(0, 2, 1)).squeeze(1)
+    dot12 = torch.bmm(v1.unsqueeze(1), v2.permute(0, 2, 1)).squeeze(1)
+    invDenom = 1. / (dot00 * dot11 - dot01 * dot01)
+    u = (dot11 * dot02 - dot01 * dot12) * invDenom
+    v = (dot00 * dot12 - dot01 * dot02) * invDenom
+    result = (u >= 0.) & (v >= 0.) & ((u + v) < 1)
+
+    return result

propagate_ray(ray, distance)

Definition to propagate a ray at a certain given distance.

Parameters:

  • ray – A ray with a size of [2 x 3], [1 x 2 x 3] or a batch of rays with [m x 2 x 3].
  • distance – Distance with a size of [1], [1, m] or distances with a size of [m], [1, m].

Returns:

  • new_ray (tensor) – Propagated ray with a size of [1 x 2 x 3] or batch of rays with [m x 2 x 3].

Source code in odak/learn/raytracing/ray.py
def propagate_ray(ray, distance):
+    """
+    Definition to propagate a ray at a certain given distance.
+
+    Parameters
+    ----------
+    ray        : torch.tensor
+                 A ray with a size of [2 x 3], [1 x 2 x 3] or a batch of rays with [m x 2 x 3].
+    distance   : torch.tensor
+                 Distance with a size of [1], [1, m] or distances with a size of [m], [1, m].
+
+    Returns
+    ----------
+    new_ray    : torch.tensor
+                 Propagated ray with a size of [1 x 2 x 3] or batch of rays with [m x 2 x 3].
+    """
+    if len(ray.shape) == 2:
+        ray = ray.unsqueeze(0)
+    if len(distance.shape) == 2:
+        distance = distance.squeeze(-1)
+    new_ray = torch.zeros_like(ray)
+    new_ray[:, 0, 0] = distance * ray[:, 1, 0] + ray[:, 0, 0]
+    new_ray[:, 0, 1] = distance * ray[:, 1, 1] + ray[:, 0, 1]
+    new_ray[:, 0, 2] = distance * ray[:, 1, 2] + ray[:, 0, 2]
+    return new_ray
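
A minimal usage sketch (not part of the library source), assuming propagate_ray is importable from odak.learn.raytracing as the source path above suggests:

import torch
from odak.learn.raytracing import create_ray_from_two_points, propagate_ray

rays = create_ray_from_two_points(
                                  torch.tensor([[0., 0., 0.]]),
                                  torch.tensor([[0., 0., 1.]])
                                 )
new_rays = propagate_ray(rays, torch.tensor([10.]))
# new_rays[:, 0] holds the origins advanced by ten units along the direction cosines.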

reflect(input_ray, normal)

Definition to reflect an incoming ray from a surface defined by a surface normal. Uses the method described in G.H. Spencer and M.V.R.K. Murty, "General Ray-Tracing Procedure", 1961.

Parameters:

  • input_ray – A ray or rays. Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].
  • normal – A surface normal(s). Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].

Returns:

  • output_ray (tensor) – Array that contains starting points and cosines of a reflected ray. Expected size is [1 x 2 x 3] or [m x 2 x 3].

Source code in odak/learn/raytracing/boundary.py
def reflect(input_ray, normal):
+    """ 
+    Definition to reflect an incoming ray from a surface defined by a surface normal. 
+    Used method described in G.H. Spencer and M.V.R.K. Murty, "General Ray-Tracing Procedure", 1961.
+
+
+    Parameters
+    ----------
+    input_ray    : torch.tensor
+                   A ray or rays.
+                   Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].
+    normal       : torch.tensor
+                   A surface normal(s).
+                   Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].
+
+    Returns
+    ----------
+    output_ray   : torch.tensor
+                   Array that contains starting points and cosines of a reflected ray.
+                   Expected size is [1 x 2 x 3] or [m x 2 x 3].
+    """
+    if len(input_ray.shape) == 2:
+        input_ray = input_ray.unsqueeze(0)
+    if len(normal.shape) == 2:
+        normal = normal.unsqueeze(0)
+    mu = 1
+    div = normal[:, 1, 0]**2 + normal[:, 1, 1]**2 + normal[:, 1, 2]**2 + 1e-8
+    a = mu * (input_ray[:, 1, 0] * normal[:, 1, 0] + input_ray[:, 1, 1] * normal[:, 1, 1] + input_ray[:, 1, 2] * normal[:, 1, 2]) / div
+    a = a.unsqueeze(1)
+    n = int(torch.amax(torch.tensor([normal.shape[0], input_ray.shape[0]])))
+    output_ray = torch.zeros((n, 2, 3)).to(input_ray.device)
+    output_ray[:, 0] = normal[:, 0]
+    output_ray[:, 1] = input_ray[:, 1] - 2 * a * normal[:, 1]
+    return output_ray
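
A minimal usage sketch (not part of the library source), assuming these helpers are importable from odak.learn.raytracing as the source paths above suggest:

import torch
from odak.learn.raytracing import create_ray_from_two_points, define_plane, intersect_w_surface, reflect

rays = create_ray_from_two_points(
                                  torch.tensor([[0., 0., 0.]]),
                                  torch.tensor([[1., 0., 10.]])
                                 )
plane_points = define_plane(torch.tensor([0., 0., 10.]))
normal, _ = intersect_w_surface(rays, plane_points)
reflected_rays = reflect(rays, normal)   # [m x 2 x 3]: starts at the hit point with a mirrored direction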

refract(vector, normvector, n1, n2, error=0.01)

Definition to refract an incoming ray. Uses the method described in G.H. Spencer and M.V.R.K. Murty, "General Ray-Tracing Procedure", 1961.

Parameters:

  • vector – Incoming ray. Expected size is [2, 3], [1, 2, 3] or [m, 2, 3].
  • normvector – Normal vector. Expected size is [2, 3], [1, 2, 3] or [m, 2, 3].
  • n1 – Refractive index of the incoming medium.
  • n2 – Refractive index of the outgoing medium.
  • error – Desired error.

Returns:

  • output (tensor) – Refracted ray. Expected size is [1, 2, 3].

Source code in odak/learn/raytracing/boundary.py
def refract(vector, normvector, n1, n2, error = 0.01):
+    """
+    Definition to refract an incoming ray.
+    Used method described in G.H. Spencer and M.V.R.K. Murty, "General Ray-Tracing Procedure", 1961.
+
+
+    Parameters
+    ----------
+    vector         : torch.tensor
+                     Incoming ray.
+                     Expected size is [2, 3], [1, 2, 3] or [m, 2, 3].
+    normvector     : torch.tensor
+                     Normal vector.
+                     Expected size is [2, 3], [1, 2, 3] or [m, 2, 3]].
+    n1             : float
+                     Refractive index of the incoming medium.
+    n2             : float
+                     Refractive index of the outgoing medium.
+    error          : float 
+                     Desired error.
+
+    Returns
+    -------
+    output         : torch.tensor
+                     Refracted ray.
+                     Expected size is [1, 2, 3]
+    """
+    if len(vector.shape) == 2:
+        vector = vector.unsqueeze(0)
+    if len(normvector.shape) == 2:
+        normvector = normvector.unsqueeze(0)
+    mu    = n1 / n2
+    div   = normvector[:, 1, 0] ** 2  + normvector[:, 1, 1] ** 2 + normvector[:, 1, 2] ** 2
+    a     = mu * (vector[:, 1, 0] * normvector[:, 1, 0] + vector[:, 1, 1] * normvector[:, 1, 1] + vector[:, 1, 2] * normvector[:, 1, 2]) / div
+    b     = (mu ** 2 - 1) / div
+    to    = - b * 0.5 / a
+    num   = 0
+    eps   = torch.ones(vector.shape[0], device = vector.device) * error * 2
+    while len(eps[eps > error]) > 0:
+       num   += 1
+       oldto  = to
+       v      = to ** 2 + 2 * a * to + b
+       deltav = 2 * (to + a)
+       to     = to - v / deltav
+       eps    = abs(oldto - to)
+    output = torch.zeros_like(vector)
+    output[:, 0, 0] = normvector[:, 0, 0]
+    output[:, 0, 1] = normvector[:, 0, 1]
+    output[:, 0, 2] = normvector[:, 0, 2]
+    output[:, 1, 0] = mu * vector[:, 1, 0] + to * normvector[:, 1, 0]
+    output[:, 1, 1] = mu * vector[:, 1, 1] + to * normvector[:, 1, 1]
+    output[:, 1, 2] = mu * vector[:, 1, 2] + to * normvector[:, 1, 2]
+    return output
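
A minimal usage sketch (not part of the library source), assuming these helpers are importable from odak.learn.raytracing as the source paths above suggest; the refractive indices below are arbitrary example values for air and glass:

import torch
from odak.learn.raytracing import create_ray_from_two_points, define_plane, intersect_w_surface, refract

rays = create_ray_from_two_points(
                                  torch.tensor([[0., 0., 0.]]),
                                  torch.tensor([[1., 0., 10.]])
                                 )
surface_points = define_plane(torch.tensor([0., 0., 10.]))
normal, _ = intersect_w_surface(rays, surface_points)
refracted_rays = refract(rays, normal, n1 = 1.0, n2 = 1.51)   # bend the ray entering the denser medium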

rotate_points(point, angles=torch.tensor([[0, 0, 0]]), mode='XYZ', origin=torch.tensor([[0, 0, 0]]), offset=torch.tensor([[0, 0, 0]]))

Definition to rotate a given point. Note that rotation is always with respect to 0,0,0.

Parameters:

  • point – A point with size of [3] or [1, 3] or [m, 3].
  • angles – Rotation angles in degrees.
  • mode – Rotation mode determines ordering of the rotations at each axis. There are XYZ, XZY, YXZ, ZXY and ZYX modes.
  • origin – Reference point for a rotation. Expected size is [3] or [1, 3].
  • offset – Shift with the given offset. Expected size is [3] or [1, 3] or [m, 3].

Returns:

  • result (tensor) – Result of the rotation [1 x 3] or [m x 3].
  • rotx (tensor) – Rotation matrix along X axis [3 x 3].
  • roty (tensor) – Rotation matrix along Y axis [3 x 3].
  • rotz (tensor) – Rotation matrix along Z axis [3 x 3].

Source code in odak/learn/tools/transformation.py
def rotate_points(
+                 point,
+                 angles = torch.tensor([[0, 0, 0]]), 
+                 mode='XYZ', 
+                 origin = torch.tensor([[0, 0, 0]]), 
+                 offset = torch.tensor([[0, 0, 0]])
+                ):
+    """
+    Definition to rotate a given point. Note that rotation is always with respect to 0,0,0.
+
+    Parameters
+    ----------
+    point        : torch.tensor
+                   A point with size of [3] or [1, 3] or [m, 3].
+    angles       : torch.tensor
+                   Rotation angles in degrees. 
+    mode         : str
+                   Rotation mode determines ordering of the rotations at each axis.
+                   There are XYZ,YXZ,ZXY and ZYX modes.
+    origin       : torch.tensor
+                   Reference point for a rotation.
+                   Expected size is [3] or [1, 3].
+    offset       : torch.tensor
+                   Shift with the given offset.
+                   Expected size is [3] or [1, 3] or [m, 3].
+
+    Returns
+    ----------
+    result       : torch.tensor
+                   Result of the rotation [1 x 3] or [m x 3].
+    rotx         : torch.tensor
+                   Rotation matrix along X axis [3 x 3].
+    roty         : torch.tensor
+                   Rotation matrix along Y axis [3 x 3].
+    rotz         : torch.tensor
+                   Rotation matrix along Z axis [3 x 3].
+    """
+    origin = origin.to(point.device)
+    offset = offset.to(point.device)
+    if len(point.shape) == 1:
+        point = point.unsqueeze(0)
+    if len(angles.shape) == 1:
+        angles = angles.unsqueeze(0)
+    rotx = rotmatx(angles[:, 0])
+    roty = rotmaty(angles[:, 1])
+    rotz = rotmatz(angles[:, 2])
+    new_point = (point - origin).T
+    if mode == 'XYZ':
+        result = torch.mm(rotz, torch.mm(roty, torch.mm(rotx, new_point))).T
+    elif mode == 'XZY':
+        result = torch.mm(roty, torch.mm(rotz, torch.mm(rotx, new_point))).T
+    elif mode == 'YXZ':
+        result = torch.mm(rotz, torch.mm(rotx, torch.mm(roty, new_point))).T
+    elif mode == 'ZXY':
+        result = torch.mm(roty, torch.mm(rotx, torch.mm(rotz, new_point))).T
+    elif mode == 'ZYX':
+        result = torch.mm(rotx, torch.mm(roty, torch.mm(rotz, new_point))).T
+    result += origin
+    result += offset
+    return result, rotx, roty, rotz
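
A minimal usage sketch (not part of the library source), assuming rotate_points is importable from odak.learn.tools as the source path above suggests:

import torch
from odak.learn.tools import rotate_points

points = torch.tensor([[1., 0., 0.], [0., 1., 0.]])
rotated, rotx, roty, rotz = rotate_points(points, angles = torch.tensor([0., 0., 90.]))
# rotated holds the two points turned 90 degrees about the Z axis; rotx, roty, rotz are the 3 x 3 matrices.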

same_side(p1, p2, a, b)

Definition to figure out which side of a line a point is on, with respect to a reference point. See http://www.blackpawn.com/texts/pointinpoly/ for more. If p1 and p2 are on the same side, this definition returns True.

Parameters:

  • p1 – Point(s) to check.
  • p2 – This is the point to check against.
  • a – First point that forms the line.
  • b – Second point that forms the line.

Source code in odak/learn/tools/vector.py
def same_side(p1, p2, a, b):
+    """
+    Definition to figure which side a point is on with respect to a line and a point. See http://www.blackpawn.com/texts/pointinpoly/ for more. If p1 and p2 are on the sameside, this definition returns True.
+
+    Parameters
+    ----------
+    p1          : list
+                  Point(s) to check.
+    p2          : list
+                  This is the point check against.
+    a           : list
+                  First point that forms the line.
+    b           : list
+                  Second point that forms the line.
+    """
+    ba = torch.subtract(b, a)
+    p1a = torch.subtract(p1, a)
+    p2a = torch.subtract(p2, a)
+    cp1 = torch.cross(ba, p1a)
+    cp2 = torch.cross(ba, p2a)
+    test = torch.dot(cp1, cp2)
+    if len(p1.shape) > 1:
+        return test >= 0
+    if test >= 0:
+        return True
+    return False

save_torch_tensor(fn, tensor)

Definition to save a torch tensor.

Parameters:

  • fn – Filename.
  • tensor – Torch tensor to be saved.

Source code in odak/learn/tools/file.py
def save_torch_tensor(fn, tensor):
+    """
+    Definition to save a torch tensor.
+
+
+    Parameters
+    ----------
+    fn           : str
+                   Filename.
+    tensor       : torch.tensor
+                   Torch tensor to be saved.
+    """ 
+    torch.save(tensor, expanduser(fn))
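
A minimal usage sketch (not part of the library source), assuming save_torch_tensor is importable from odak.learn.tools as the source path above suggests:

import torch
from odak.learn.tools import save_torch_tensor

rays = torch.randn(10, 2, 3)
save_torch_tensor('rays.pt', rays)   # wraps torch.save and expands a leading ~ in the filename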

write_PLY(triangles, savefn='output.ply')

Definition to generate a PLY file from given points.

Parameters:

  • triangles – List of triangles with the size of Mx3x3.
  • savefn – Filename for a PLY file.

Source code in odak/tools/asset.py
def write_PLY(triangles, savefn = 'output.ply'):
+    """
+    Definition to generate a PLY file from given points.
+
+    Parameters
+    ----------
+    triangles   : ndarray
+                  List of triangles with the size of Mx3x3.
+    savefn      : string
+                  Filename for a PLY file.
+    """
+    tris = []
+    pnts = []
+    color = [255, 255, 255]
+    for tri_id in range(triangles.shape[0]):
+        tris.append(
+            (
+                [3*tri_id, 3*tri_id+1, 3*tri_id+2],
+                color[0],
+                color[1],
+                color[2]
+            )
+        )
+        for i in range(0, 3):
+            pnts.append(
+                (
+                    float(triangles[tri_id][i][0]),
+                    float(triangles[tri_id][i][1]),
+                    float(triangles[tri_id][i][2])
+                )
+            )
+    tris = np.asarray(tris, dtype=[
+                          ('vertex_indices', 'i4', (3,)), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')])
+    pnts = np.asarray(pnts, dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])
+    # Save mesh.
+    el1 = PlyElement.describe(pnts, 'vertex', comments=['Vertex data'])
+    el2 = PlyElement.describe(tris, 'face', comments=['Face data'])
+    PlyData([el1, el2], text="True").write(savefn)
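
A minimal usage sketch (not part of the library source), assuming write_PLY is importable from odak.tools as the source path above suggests and the plyfile dependency is installed:

from odak.learn.raytracing import define_plane_mesh
from odak.tools import write_PLY

triangles = define_plane_mesh(number_of_meshes = [10, 10], size = [5., 5.])
write_PLY(triangles.detach().cpu().numpy(), savefn = 'plane_mesh.ply')   # writes 200 triangles to disk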
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ get_sphere_normal_torch(point, sphere) + +

+ + +
+ +

Definition to get a normal of a point on a given sphere.

+ + +

Parameters:

+
    +
  • + point + – +
    +
            Point on sphere in X,Y,Z.
    +
    +
    +
  • +
  • + sphere + – +
    +
            Center defined in X,Y,Z and radius.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +normal_vector ( tensor +) – +
    +

    Normal vector.

    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/boundary.py +
def get_sphere_normal_torch(point, sphere):
+    """
+    Definition to get a normal of a point on a given sphere.
+
+    Parameters
+    ----------
+    point         : torch.tensor
+                    Point on sphere in X,Y,Z.
+    sphere        : torch.tensor
+                    Center defined in X,Y,Z and radius.
+
+    Returns
+    ----------
+    normal_vector : torch.tensor
+                    Normal vector.
+    """
+    if len(point.shape) == 1:
+        point = point.reshape((1, 3))
+    normal_vector = create_ray_from_two_points(point, sphere[0:3])
+    return normal_vector
+
+
+
+ +
+ +
+ + +

+ get_triangle_normal(triangle, triangle_center=None) + +

+ + +
+ +

Definition to calculate surface normal of a triangle.

+ + +

Parameters:

+
    +
  • + triangle + – +
    +
              Set of points in X,Y and Z to define a planar surface (3,3). It can also be list of triangles (mx3x3).
    +
    +
    +
  • +
  • + triangle_center + (tensor, default: + None +) + – +
    +
              Center point of the given triangle. See odak.learn.raytracing.center_of_triangle for more. In many scenarios you can accelerate things by precomputing triangle centers.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +normal ( tensor +) – +
    +

    Surface normal at the point of intersection.

    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/boundary.py +
def get_triangle_normal(triangle, triangle_center=None):
+    """
+    Definition to calculate surface normal of a triangle.
+
+    Parameters
+    ----------
+    triangle        : torch.tensor
+                      Set of points in X,Y and Z to define a planar surface (3,3). It can also be list of triangles (mx3x3).
+    triangle_center : torch.tensor
+                      Center point of the given triangle. See odak.learn.raytracing.center_of_triangle for more. In many scenarios you can accelerate things by precomputing triangle centers.
+
+    Returns
+    ----------
+    normal          : torch.tensor
+                      Surface normal at the point of intersection.
+    """
+    if len(triangle.shape) == 2:
+        triangle = triangle.view((1, 3, 3))
+    normal = torch.zeros((triangle.shape[0], 2, 3)).to(triangle.device)
+    direction = torch.linalg.cross(
+                                   triangle[:, 0] - triangle[:, 1], 
+                                   triangle[:, 2] - triangle[:, 1]
+                                  )
+    if type(triangle_center) == type(None):
+        normal[:, 0] = center_of_triangle(triangle)
+    else:
+        normal[:, 0] = triangle_center
+    normal[:, 1] = direction / torch.sum(direction, axis=1)[0]
+    if normal.shape[0] == 1:
+        normal = normal.view((2, 3))
+    return normal
+
+
+
+ +
+ +
+ + +

+ intersect_w_circle(ray, circle) + +

+ + +
+ +

Definition to find intersection point of a ray with a circle. +Returns distance as zero if there isn't an intersection.

+ + +

Parameters:

+
    +
  • + ray + – +
    +
           A vector/ray.
    +
    +
    +
  • +
  • + circle + – +
    +
           A list that contains (0) Set of points in X,Y and Z to define plane of a circle, (1) circle center, and (2) circle radius.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +normal ( Tensor +) – +
    +

    Surface normal at the point of intersection.

    +
    +
  • +
  • +distance ( Tensor +) – +
    +

    Distance in between a starting point of a ray and the intersection point with a given triangle.

    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/boundary.py +
def intersect_w_circle(ray, circle):
+    """
+    Definition to find intersection point of a ray with a circle. 
+    Returns distance as zero if there isn't an intersection.
+
+    Parameters
+    ----------
+    ray          : torch.Tensor
+                   A vector/ray.
+    circle       : list
+                   A list that contains (0) Set of points in X,Y and Z to define plane of a circle, (1) circle center, and (2) circle radius.
+
+    Returns
+    ----------
+    normal       : torch.Tensor
+                   Surface normal at the point of intersection.
+    distance     : torch.Tensor
+                   Distance in between a starting point of a ray and the intersection point with a given triangle.
+    """
+    normal, distance = intersect_w_surface(ray, circle[0])
+
+    if len(normal.shape) == 2:
+        normal = normal.unsqueeze(0)
+
+    distance_to_center = distance_between_two_points(normal[:, 0], circle[1])
+    mask = distance_to_center > circle[2]
+    distance[mask] = 0
+
+    if len(ray.shape) == 2:
+        normal = normal.squeeze(0)
+
+    return normal, distance
+
+
+
+ +
+ +
+ + +

+ intersect_w_sphere(ray, sphere, learning_rate=0.2, number_of_steps=5000, error_threshold=0.01) + +

+ + +
+ +

Definition to find the intersection between ray(s) and sphere(s).

+ + +

Parameters:

+
    +
  • + ray + – +
    +
                  Input ray(s).
    +              Expected size is [1 x 2 x 3] or [m x 2 x 3].
    +
    +
    +
  • +
  • + sphere + – +
    +
                  Input sphere.
    +              Expected size is [1 x 4].
    +
    +
    +
  • +
  • + learning_rate + – +
    +
                  Learning rate used in the optimizer for finding the propagation distances of the rays.
    +
    +
    +
  • +
  • + number_of_steps + – +
    +
                  Number of steps used in the optimizer.
    +
    +
    +
  • +
  • + error_threshold + – +
    +
              The error threshold used to decide whether or not there is an intersection.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +intersecting_ray ( tensor +) – +
    +

    Ray(s) intersecting with the given sphere. +Expected size is [n x 2 x 3], where n is the number of intersecting rays.

    +
    +
  • +
  • +intersecting_normal ( tensor +) – +
    +

    Normal(s) for the ray(s) intersecting with the given sphere. +Expected size is [n x 2 x 3], where n is the number of intersecting rays.

    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/boundary.py +
def intersect_w_sphere(ray, sphere, learning_rate = 2e-1, number_of_steps = 5000, error_threshold = 1e-2):
+    """
+    Definition to find the intersection between ray(s) and sphere(s).
+
+    Parameters
+    ----------
+    ray                 : torch.tensor
+                          Input ray(s).
+                          Expected size is [1 x 2 x 3] or [m x 2 x 3].
+    sphere              : torch.tensor
+                          Input sphere.
+                          Expected size is [1 x 4].
+    learning_rate       : float
+                          Learning rate used in the optimizer for finding the propagation distances of the rays.
+    number_of_steps     : int
+                          Number of steps used in the optimizer.
+    error_threshold     : float
+                          The error threshold that will help deciding intersection or no intersection.
+
+    Returns
+    -------
+    intersecting_ray    : torch.tensor
+                          Ray(s) that intersecting with the given sphere.
+                          Expected size is [n x 2 x 3], where n could be any real number.
+    intersecting_normal : torch.tensor
+                          Normal(s) for the ray(s) intersecting with the given sphere
+                          Expected size is [n x 2 x 3], where n could be any real number.
+
+    """
+    if len(ray.shape) == 2:
+        ray = ray.unsqueeze(0)
+    if len(sphere.shape) == 1:
+        sphere = sphere.unsqueeze(0)
+    distance = torch.zeros(ray.shape[0], device = ray.device, requires_grad = True)
+    loss_l2 = torch.nn.MSELoss(reduction = 'sum')
+    optimizer = torch.optim.AdamW([distance], lr = learning_rate)    
+    t = tqdm(range(number_of_steps), leave = False, dynamic_ncols = True)
+    for step in t:
+        optimizer.zero_grad()
+        propagated_ray = propagate_ray(ray, distance)
+        test = torch.abs((propagated_ray[:, 0, 0] - sphere[:, 0]) ** 2 + (propagated_ray[:, 0, 1] - sphere[:, 1]) ** 2 + (propagated_ray[:, 0, 2] - sphere[:, 2]) ** 2 - sphere[:, 3] ** 2)
+        loss = loss_l2(
+                       test,
+                       torch.zeros_like(test)
+                      )
+        loss.backward(retain_graph = True)
+        optimizer.step()
+        t.set_description('Sphere intersection loss: {}'.format(loss.item()))
+    check = test < error_threshold
+    intersecting_ray = propagate_ray(ray[check == True], distance[check == True])
+    intersecting_normal = create_ray_from_two_points(
+                                                     sphere[:, 0:3],
+                                                     intersecting_ray[:, 0]
+                                                    )
+    return intersecting_ray, intersecting_normal, distance, check
+
+
+
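Example use (a sketch with an assumed import path; note that the source above returns four values, including the optimized distances and a boolean check, and number_of_steps is lowered here only to keep the sketch quick):

import torch
from odak.learn.raytracing import intersect_w_sphere  # assumed import path

rays = torch.tensor([
                     [[ 0., 0., 0.], [0., 0., 1.]],   # passes through the sphere
                     [[20., 0., 0.], [0., 0., 1.]]    # misses the sphere
                    ])
sphere = torch.tensor([[0., 0., 10., 3.]])            # [1 x 4]: center (0, 0, 10) and radius 3
intersecting_rays, intersecting_normals, distances, check = intersect_w_sphere(
                                                                               rays,
                                                                               sphere,
                                                                               number_of_steps = 500
                                                                              )
# check flags each input ray; only rays marked True appear in intersecting_rays.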
+ +
+ +
+ + +

+ intersect_w_surface(ray, points) + +

+ + +
+ +

Definition to find the intersection point between a surface and a ray. For more see: http://geomalgorithms.com/a06-_intersect-2.html

+ + +

Parameters:

+
    +
  • + ray + – +
    +
           A vector/ray.
    +
    +
    +
  • +
  • + points + – +
    +
           Set of points in X,Y and Z to define a planar surface.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +normal ( tensor +) – +
    +

    Surface normal at the point of intersection.

    +
    +
  • +
  • +distance ( float +) – +
    +

    Distance between the starting point of a ray and its intersection with a planar surface.

    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/boundary.py +
def intersect_w_surface(ray, points):
+    """
+    Definition to find intersection point inbetween a surface and a ray. For more see: http://geomalgorithms.com/a06-_intersect-2.html
+
+    Parameters
+    ----------
+    ray          : torch.tensor
+                   A vector/ray.
+    points       : torch.tensor
+                   Set of points in X,Y and Z to define a planar surface.
+
+    Returns
+    ----------
+    normal       : torch.tensor
+                   Surface normal at the point of intersection.
+    distance     : float
+                   Distance in between starting point of a ray with it's intersection with a planar surface.
+    """
+    normal = get_triangle_normal(points)
+    if len(ray.shape) == 2:
+        ray = ray.unsqueeze(0)
+    if len(points.shape) == 2:
+        points = points.unsqueeze(0)
+    if len(normal.shape) == 2:
+        normal = normal.unsqueeze(0)
+    f = normal[:, 0] - ray[:, 0]
+    distance = (torch.mm(normal[:, 1], f.T) / torch.mm(normal[:, 1], ray[:, 1].T)).T
+    new_normal = torch.zeros_like(ray)
+    new_normal[:, 0] = ray[:, 0] + distance * ray[:, 1]
+    new_normal[:, 1] = normal[:, 1]
+    new_normal = torch.nan_to_num(
+                                  new_normal,
+                                  nan = float('nan'),
+                                  posinf = float('nan'),
+                                  neginf = float('nan')
+                                 )
+    distance = torch.nan_to_num(
+                                distance,
+                                nan = float('nan'),
+                                posinf = float('nan'),
+                                neginf = float('nan')
+                               )
+    return new_normal, distance
+
+
+
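Example use (a minimal sketch; the import path is assumed, and define_plane is documented further below):

import torch
from odak.learn.raytracing import define_plane, intersect_w_surface  # assumed import path

plane = define_plane(point = torch.tensor([0., 0., 10.]))   # planar surface centered at z = 10
ray = torch.tensor([[[0., 0., 0.], [0., 0., 1.]]])          # [1 x 2 x 3] ray along +z
normal, distance = intersect_w_surface(ray, plane)
# normal[:, 0] is the intersection point, normal[:, 1] the surface normal; distance is [1 x 1].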
+ +
+ +
+ + +

+ intersect_w_surface_batch(ray, triangle) + +

+ + +
+ + + +

Parameters:

+
    +
  • + ray + – +
    +
           A vector/ray (2 x 3). It can also be a list of rays (n x 2 x 3).
    +
    +
    +
  • +
  • + triangle + – +
    +
           Set of points in X,Y and Z to define a planar surface. It can also be a list of triangles (m x 3 x 3).
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +normal ( tensor +) – +
    +

    Surface normal at the point of intersection (m x n x 2 x 3).

    +
    +
  • +
  • +distance ( tensor +) – +
    +

    Distance between the starting point of a ray and its intersection with a planar surface (m x n).

    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/boundary.py +
def intersect_w_surface_batch(ray, triangle):
+    """
+    Parameters
+    ----------
+    ray          : torch.tensor
+                   A vector/ray (2 x 3). It can also be a list of rays (n x 2 x 3).
+    triangle     : torch.tensor
+                   Set of points in X,Y and Z to define a planar surface. It can also be a list of triangles (m x 3 x 3).
+
+    Returns
+    ----------
+    normal       : torch.tensor
+                   Surface normal at the point of intersection (m x n x 2 x 3).
+    distance     : torch.tensor
+                   Distance in between starting point of a ray with it's intersection with a planar surface (m x n).
+    """
+    normal = get_triangle_normal(triangle)
+    if len(ray.shape) == 2:
+        ray = ray.unsqueeze(0)
+    if len(triangle.shape) == 2:
+        triangle = triangle.unsqueeze(0)
+    if len(normal.shape) == 2:
+        normal = normal.unsqueeze(0)
+
+    f = normal[:, None, 0] - ray[None, :, 0]
+    distance = (torch.bmm(normal[:, None, 1], f.permute(0, 2, 1)).squeeze(1) / torch.mm(normal[:, 1], ray[:, 1].T)).T
+
+    new_normal = torch.zeros((triangle.shape[0], )+ray.shape)
+    new_normal[:, :, 0] = ray[None, :, 0] + (distance[:, :, None] * ray[:, None, 1]).permute(1, 0, 2)
+    new_normal[:, :, 1] = normal[:, None, 1]
+    new_normal = torch.nan_to_num(
+                                  new_normal,
+                                  nan = float('nan'),
+                                  posinf = float('nan'),
+                                  neginf = float('nan')
+                                 )
+    distance = torch.nan_to_num(
+                                distance,
+                                nan = float('nan'),
+                                posinf = float('nan'),
+                                neginf = float('nan')
+                               )
+    return new_normal, distance.T
+
+
+
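Example use (a sketch pairing every ray with every triangle plane; the import path is assumed):

import torch
from odak.learn.raytracing import intersect_w_surface_batch  # assumed import path

triangles = torch.tensor([
                          [[-1., -1., 10.], [1., -1., 10.], [0., 1., 10.]],
                          [[-1., -1., 20.], [1., -1., 20.], [0., 1., 20.]]
                         ])                                   # [m x 3 x 3] with m = 2
rays = torch.tensor([
                     [[0.0, 0.0, 0.], [0., 0., 1.]],
                     [[0.1, 0.1, 0.], [0., 0., 1.]]
                    ])                                        # [n x 2 x 3] with n = 2
normal, distance = intersect_w_surface_batch(rays, triangles)
# normal is [m x n x 2 x 3]; distance is [m x n], one entry per triangle-ray pair.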
+ +
+ +
+ + +

+ intersect_w_triangle(ray, triangle) + +

+ + +
+ +

Definition to find intersection point of a ray with a triangle.

+ + +

Parameters:

+
    +
  • + ray + – +
    +
                  A ray [1 x 2 x 3] or a batch of ray [m x 2 x 3].
    +
    +
    +
  • +
  • + triangle + – +
    +
                  Set of points in X,Y and Z to define a single triangle [1 x 3 x 3].
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +normal ( tensor +) – +
    +

    Surface normal at the point of intersection with the plane of the triangle. +This can also include intersection points that fall outside the triangle itself. +Expected size is [1 x 2 x 3] or [m x 2 x 3] depending on the input.

    +
    +
  • +
  • +distance ( float +) – +
    +

    Distance between the starting point of a ray and its intersection point with the given triangle. +Expected size is [1 x 1] or [m x 1] depending on the input.

    +
    +
  • +
  • +intersecting_ray ( tensor +) – +
    +

    Rays whose intersection points with the triangle's plane lie on the triangle. +Expected size is [1 x 2 x 3] or [m x 2 x 3] depending on the input.

    +
    +
  • +
  • +intersecting_normal ( tensor +) – +
    +

    Normals at the intersection points that lie on the triangle. +Expected size is [1 x 2 x 3] or [m x 2 x 3] depending on the input.

    +
    +
  • +
  • +check ( tensor +) – +
    +

    A list that provides a bool as True or False for each ray used as input, +testing whether the ray lands on the given triangle. +Expected size is [1] or [m].

    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/boundary.py +
def intersect_w_triangle(ray, triangle):
+    """
+    Definition to find intersection point of a ray with a triangle. 
+
+    Parameters
+    ----------
+    ray                 : torch.tensor
+                          A ray [1 x 2 x 3] or a batch of ray [m x 2 x 3].
+    triangle            : torch.tensor
+                          Set of points in X,Y and Z to define a single triangle [1 x 3 x 3].
+
+    Returns
+    ----------
+    normal              : torch.tensor
+                          Surface normal at the point of intersection with the surface of triangle.
+                          This could also involve surface normals that are not on the triangle.
+                          Expected size is [1 x 2 x 3] or [m x 2 x 3] depending on the input.
+    distance            : float
+                          Distance in between a starting point of a ray and the intersection point with a given triangle.
+                          Expected size is [1 x 1] or [m x 1] depending on the input.
+    intersecting_ray    : torch.tensor
+                          Rays that intersect with the triangle plane and on the triangle.
+                          Expected size is [1 x 2 x 3] or [m x 2 x 3] depending on the input.
+    intersecting_normal : torch.tensor
+                          Normals that intersect with the triangle plane and on the triangle.
+                          Expected size is [1 x 2 x 3] or [m x 2 x 3] depending on the input.
+    check               : torch.tensor
+                          A list that provides a bool as True or False for each ray used as input.
+                          A test to see is a ray could be on the given triangle.
+                          Expected size is [1] or [m].
+    """
+    if len(triangle.shape) == 2:
+       triangle = triangle.unsqueeze(0)
+    if len(ray.shape) == 2:
+       ray = ray.unsqueeze(0)
+    normal, distance = intersect_w_surface(ray, triangle)
+    check = is_it_on_triangle(normal[:, 0], triangle)
+    intersecting_ray = ray.unsqueeze(0)
+    intersecting_ray = intersecting_ray.repeat(triangle.shape[0], 1, 1, 1)
+    intersecting_ray = intersecting_ray[check == True]
+    intersecting_normal = normal.unsqueeze(0)
+    intersecting_normal = intersecting_normal.repeat(triangle.shape[0], 1, 1, 1)
+    intersecting_normal = intersecting_normal[check ==  True]
+    return normal, distance, intersecting_ray, intersecting_normal, check
+
+
+
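Example use (a sketch with an assumed import path):

import torch
from odak.learn.raytracing import intersect_w_triangle  # assumed import path

triangle = torch.tensor([[
                          [-5., -5., 10.],
                          [ 5., -5., 10.],
                          [ 0.,  5., 10.]
                        ]])                                   # [1 x 3 x 3]
rays = torch.tensor([
                     [[ 0., 0., 0.], [0., 0., 1.]],           # lands inside the triangle
                     [[50., 0., 0.], [0., 0., 1.]]            # hits the plane but outside the triangle
                    ])
normal, distance, hit_rays, hit_normals, check = intersect_w_triangle(rays, triangle)
# check flags which rays land on the triangle; hit_rays and hit_normals keep only those.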
+ +
+ +
+ + +

+ intersect_w_triangle_batch(ray, triangle) + +

+ + +
+ +

Definition to find intersection points of rays with triangles. Returns False for each variable if the rays don't intersect with the given triangles.

+ + +

Parameters:

+
    +
  • + ray + – +
    +
           vectors/rays (n x 2 x 3).
    +
    +
    +
  • +
  • + triangle + – +
    +
           Set of points in X,Y and Z to define triangles (m x 3 x 3).
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +normal ( tensor +) – +
    +

    Surface normal at the point of intersection (m x n x 2 x 3).

    +
    +
  • +
  • +distance ( List +) – +
    +

    Distance between the starting point of a ray and its intersection with a planar surface (m x n).

    +
    +
  • +
  • +intersect_ray ( List +) – +
    +

    List of intersecting rays (k x 2 x 3) where k <= n.

    +
    +
  • +
  • +intersect_normal ( List +) – +
    +

    List of intersecting normals (k x 2 x 3) where k <= n*m.

    +
    +
  • +
  • +check ( tensor +) – +
    +

    Boolean tensor (m x n) indicating whether each ray intersects with a triangle or not.

    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/boundary.py +
def intersect_w_triangle_batch(ray, triangle):
+    """
+    Definition to find intersection points of rays with triangles. Returns False for each variable if the rays doesn't intersect with given triangles.
+
+    Parameters
+    ----------
+    ray          : torch.tensor
+                   vectors/rays (n x 2 x 3).
+    triangle     : torch.tensor
+                   Set of points in X,Y and Z to define triangles (m x 3 x 3).
+
+    Returns
+    ----------
+    normal          : torch.tensor
+                      Surface normal at the point of intersection (m x n x 2 x 3).
+    distance        : List
+                      Distance in between starting point of a ray with it's intersection with a planar surface (m x n).
+    intersect_ray   : List
+                      List of intersecting rays (k x 2 x 3) where k <= n.
+    intersect_normal: List
+                      List of intersecting normals (k x 2 x 3) where k <= n*m.
+    check           : torch.tensor
+                      Boolean tensor (m x n) indicating whether each ray intersects with a triangle or not.
+    """
+    if len(triangle.shape) == 2:
+       triangle = triangle.unsqueeze(0)
+    if len(ray.shape) == 2:
+       ray = ray.unsqueeze(0)
+
+    normal, distance = intersect_w_surface_batch(ray, triangle)
+
+    check = is_it_on_triangle_batch(normal[:, :, 0], triangle)
+
+    flat_check = check.flatten()
+    flat_normal = normal.view(-1, normal.size(-2), normal.size(-1))
+    flat_ray = ray.repeat(normal.size(0), 1, 1)
+    flat_distance = distance.flatten()
+
+    filtered_normal = torch.masked_select(flat_normal, flat_check.unsqueeze(-1).unsqueeze(-1).repeat(1, 2, 3))
+    filtered_ray = torch.masked_select(flat_ray, flat_check.unsqueeze(-1).unsqueeze(-1).repeat(1, 2, 3))
+    filtered_distnace = torch.masked_select(flat_distance, flat_check)
+
+    check_count = check.sum(dim=1).tolist()
+    split_size_ray_and_normal = [count * 2 * 3 for count in check_count]
+    split_size_distance = [count for count in check_count]
+
+    normal_grouped = torch.split(filtered_normal, split_size_ray_and_normal)
+    ray_grouped = torch.split(filtered_ray, split_size_ray_and_normal)
+    distance_grouped = torch.split(filtered_distnace, split_size_distance)
+
+    intersecting_normal = [g.view(-1, 2, 3) for g in normal_grouped if g.numel() > 0]
+    intersecting_ray = [g.view(-1, 2, 3) for g in ray_grouped if g.numel() > 0]
+    new_distance = [g for g in distance_grouped if g.numel() > 0]
+
+    return normal, new_distance, intersecting_ray, intersecting_normal, check
+
+
+
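Example use (a sketch with an assumed import path):

import torch
from odak.learn.raytracing import intersect_w_triangle_batch  # assumed import path

triangles = torch.tensor([
                          [[-5., -5., 10.], [5., -5., 10.], [0., 5., 10.]],
                          [[-5., -5., 20.], [5., -5., 20.], [0., 5., 20.]]
                         ])                                    # [m x 3 x 3]
rays = torch.tensor([
                     [[ 0., 0., 0.], [0., 0., 1.]],
                     [[50., 0., 0.], [0., 0., 1.]]
                    ])                                         # [n x 2 x 3]
normal, distances, hit_rays, hit_normals, check = intersect_w_triangle_batch(rays, triangles)
# check is an [m x n] boolean map; distances, hit_rays and hit_normals come back grouped per triangle.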
+ +
+ +
+ + +

+ reflect(input_ray, normal) + +

+ + +
+ +

Definition to reflect an incoming ray from a surface defined by a surface normal. +Uses the method described in G.H. Spencer and M.V.R.K. Murty, "General Ray-Tracing Procedure", 1961.

+ + +

Parameters:

+
    +
  • + input_ray + – +
    +
           A ray or rays.
    +       Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].
    +
    +
    +
  • +
  • + normal + – +
    +
           A surface normal(s).
    +       Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +output_ray ( tensor +) – +
    +

    Array that contains starting points and cosines of a reflected ray. +Expected size is [1 x 2 x 3] or [m x 2 x 3].

    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/boundary.py +
def reflect(input_ray, normal):
+    """ 
+    Definition to reflect an incoming ray from a surface defined by a surface normal. 
+    Used method described in G.H. Spencer and M.V.R.K. Murty, "General Ray-Tracing Procedure", 1961.
+
+
+    Parameters
+    ----------
+    input_ray    : torch.tensor
+                   A ray or rays.
+                   Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].
+    normal       : torch.tensor
+                   A surface normal(s).
+                   Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].
+
+    Returns
+    ----------
+    output_ray   : torch.tensor
+                   Array that contains starting points and cosines of a reflected ray.
+                   Expected size is [1 x 2 x 3] or [m x 2 x 3].
+    """
+    if len(input_ray.shape) == 2:
+        input_ray = input_ray.unsqueeze(0)
+    if len(normal.shape) == 2:
+        normal = normal.unsqueeze(0)
+    mu = 1
+    div = normal[:, 1, 0]**2 + normal[:, 1, 1]**2 + normal[:, 1, 2]**2 + 1e-8
+    a = mu * (input_ray[:, 1, 0] * normal[:, 1, 0] + input_ray[:, 1, 1] * normal[:, 1, 1] + input_ray[:, 1, 2] * normal[:, 1, 2]) / div
+    a = a.unsqueeze(1)
+    n = int(torch.amax(torch.tensor([normal.shape[0], input_ray.shape[0]])))
+    output_ray = torch.zeros((n, 2, 3)).to(input_ray.device)
+    output_ray[:, 0] = normal[:, 0]
+    output_ray[:, 1] = input_ray[:, 1] - 2 * a * normal[:, 1]
+    return output_ray
+
+
+
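Example use (a sketch chaining intersect_w_surface and reflect to bounce a ray off a plane; the import path is assumed):

import torch
from odak.learn.raytracing import define_plane, intersect_w_surface, reflect  # assumed import path

plane = define_plane(point = torch.tensor([0., 0., 10.]))
ray = torch.tensor([[[0., 0., 0.], [0., 0., 1.]]])
surface_normal, _ = intersect_w_surface(ray, plane)
reflected_ray = reflect(ray, surface_normal)
# reflected_ray[:, 0] starts at the hit point; reflected_ray[:, 1] is the direction mirrored about the normal.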
+ +
+ +
+ + +

+ refract(vector, normvector, n1, n2, error=0.01) + +

+ + +
+ +

Definition to refract an incoming ray. +Uses the method described in G.H. Spencer and M.V.R.K. Murty, "General Ray-Tracing Procedure", 1961.

+ + +

Parameters:

+
    +
  • + vector + – +
    +
             Incoming ray.
    +         Expected size is [2, 3], [1, 2, 3] or [m, 2, 3].
    +
    +
    +
  • +
  • + normvector + – +
    +
             Normal vector.
    +         Expected size is [2, 3], [1, 2, 3] or [m, 2, 3]].
    +
    +
    +
  • +
  • + n1 + – +
    +
             Refractive index of the incoming medium.
    +
    +
    +
  • +
  • + n2 + – +
    +
             Refractive index of the outgoing medium.
    +
    +
    +
  • +
  • + error + – +
    +
             Desired error.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +output ( tensor +) – +
    +

    Refracted ray. +Expected size is [1, 2, 3] or [m, 2, 3], matching the input ray.

    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/boundary.py +
def refract(vector, normvector, n1, n2, error = 0.01):
+    """
+    Definition to refract an incoming ray.
+    Used method described in G.H. Spencer and M.V.R.K. Murty, "General Ray-Tracing Procedure", 1961.
+
+
+    Parameters
+    ----------
+    vector         : torch.tensor
+                     Incoming ray.
+                     Expected size is [2, 3], [1, 2, 3] or [m, 2, 3].
+    normvector     : torch.tensor
+                     Normal vector.
+                     Expected size is [2, 3], [1, 2, 3] or [m, 2, 3]].
+    n1             : float
+                     Refractive index of the incoming medium.
+    n2             : float
+                     Refractive index of the outgoing medium.
+    error          : float 
+                     Desired error.
+
+    Returns
+    -------
+    output         : torch.tensor
+                     Refracted ray.
+                     Expected size is [1, 2, 3]
+    """
+    if len(vector.shape) == 2:
+        vector = vector.unsqueeze(0)
+    if len(normvector.shape) == 2:
+        normvector = normvector.unsqueeze(0)
+    mu    = n1 / n2
+    div   = normvector[:, 1, 0] ** 2  + normvector[:, 1, 1] ** 2 + normvector[:, 1, 2] ** 2
+    a     = mu * (vector[:, 1, 0] * normvector[:, 1, 0] + vector[:, 1, 1] * normvector[:, 1, 1] + vector[:, 1, 2] * normvector[:, 1, 2]) / div
+    b     = (mu ** 2 - 1) / div
+    to    = - b * 0.5 / a
+    num   = 0
+    eps   = torch.ones(vector.shape[0], device = vector.device) * error * 2
+    while len(eps[eps > error]) > 0:
+       num   += 1
+       oldto  = to
+       v      = to ** 2 + 2 * a * to + b
+       deltav = 2 * (to + a)
+       to     = to - v / deltav
+       eps    = abs(oldto - to)
+    output = torch.zeros_like(vector)
+    output[:, 0, 0] = normvector[:, 0, 0]
+    output[:, 0, 1] = normvector[:, 0, 1]
+    output[:, 0, 2] = normvector[:, 0, 2]
+    output[:, 1, 0] = mu * vector[:, 1, 0] + to * normvector[:, 1, 0]
+    output[:, 1, 1] = mu * vector[:, 1, 1] + to * normvector[:, 1, 1]
+    output[:, 1, 2] = mu * vector[:, 1, 2] + to * normvector[:, 1, 2]
+    return output
+
+
+
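Example use (a sketch refracting a tilted ray from air, n1 = 1.0, into glass, n2 = 1.5; the import path is assumed):

import torch
from odak.learn.raytracing import define_plane, intersect_w_surface, refract  # assumed import path

plane = define_plane(point = torch.tensor([0., 0., 10.]))
direction = torch.tensor([0., 0.3, 1.0])
direction = direction / direction.norm()               # direction cosines must be unit length
ray = torch.zeros(1, 2, 3)
ray[0, 1] = direction                                   # the ray starts at the origin
surface_normal, _ = intersect_w_surface(ray, plane)
refracted_ray = refract(ray, surface_normal, n1 = 1.0, n2 = 1.5)
# refracted_ray starts at the intersection point with its direction bent toward the surface normal.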
+ +
+ + + +
+ +
+ +
+ +
+ + + + +
+ + +

A class to represent a detector.

+ + + + + + +
+ Source code in odak/learn/raytracing/detector.py +
class detector():
+    """
+    A class to represent a detector.
+    """
+
+
+    def __init__(
+                 self,
+                 colors = 3,
+                 center = torch.tensor([0., 0., 0.]),
+                 tilt = torch.tensor([0., 0., 0.]),
+                 size = torch.tensor([10., 10.]),
+                 resolution = torch.tensor([100, 100]),
+                 device = torch.device('cpu')
+                ):
+        """
+        Parameters
+        ----------
+        colors         : int
+                         Number of color channels to register (e.g., RGB).
+        center         : torch.tensor
+                         Center point of the detector [3].
+        tilt           : torch.tensor
+                         Tilt angles of the surface in degrees [3].
+        size           : torch.tensor
+                         Size of the detector [2].
+        resolution     : torch.tensor
+                         Resolution of the detector.
+        device         : torch.device
+                         Device for computation (e.g., cuda, cpu).
+        """
+        self.device = device
+        self.colors = colors
+        self.resolution = resolution.to(self.device)
+        self.surface_center = center.to(self.device)
+        self.surface_tilt = tilt.to(self.device)
+        self.size = size.to(self.device)
+        self.pixel_size = torch.tensor([
+                                        self.size[0] / self.resolution[0],
+                                        self.size[1] / self.resolution[1]
+                                       ], device  = self.device)
+        self.pixel_diagonal_size = torch.sqrt(self.pixel_size[0] ** 2 + self.pixel_size[1] ** 2)
+        self.pixel_diagonal_half_size = self.pixel_diagonal_size / 2.
+        self.threshold = torch.nn.Threshold(self.pixel_diagonal_size, 1)
+        self.plane = define_plane(
+                                  point = self.surface_center,
+                                  angles = self.surface_tilt
+                                 )
+        self.pixel_locations, _, _, _ = grid_sample(
+                                                    size = self.size.tolist(),
+                                                    no = self.resolution.tolist(),
+                                                    center = self.surface_center.tolist(),
+                                                    angles = self.surface_tilt.tolist()
+                                                   )
+        self.pixel_locations = self.pixel_locations.to(self.device)
+        self.relu = torch.nn.ReLU()
+        self.clear()
+
+
+    def intersect(self, rays, color = 0):
+        """
+        Function to intersect rays with the detector
+
+
+        Parameters
+        ----------
+        rays            : torch.tensor
+                          Rays to be intersected with a detector.
+                          Expected size is [1 x 2 x 3] or [m x 2 x 3].
+        color           : int
+                          Color channel to register.
+
+        Returns
+        -------
+        points          : torch.tensor
+                          Intersection points with the image detector [k x 3].
+        """
+        normals, _ = intersect_w_surface(rays, self.plane)
+        points = normals[:, 0]
+        distances_xyz = torch.abs(points.unsqueeze(1) - self.pixel_locations.unsqueeze(0))
+        distances_x = 1e6 * self.relu( - (distances_xyz[:, :, 0] - self.pixel_size[0]))
+        distances_y = 1e6 * self.relu( - (distances_xyz[:, :, 1] - self.pixel_size[1]))
+        hit_x = torch.clamp(distances_x, min = 0., max = 1.)
+        hit_y = torch.clamp(distances_y, min = 0., max = 1.)
+        hit = hit_x * hit_y
+        image = torch.sum(hit, dim = 0)
+        self.image[color] += image.reshape(
+                                           self.image.shape[-2], 
+                                           self.image.shape[-1]
+                                          )
+        distances = torch.sum((points.unsqueeze(1) - self.pixel_locations.unsqueeze(0)) ** 2, dim = 2)
+        distance_image = distances
+#        distance_image = distances.reshape(
+#                                           -1,
+#                                           self.image.shape[-2],
+#                                           self.image.shape[-1]
+#                                          )
+        return points, image, distance_image
+
+
+    def get_image(self):
+        """
+        Function to return the detector image.
+
+        Returns
+        -------
+        image           : torch.tensor
+                          Detector image.
+        """
+        image = (self.image - self.image.min()) / (self.image.max() - self.image.min())
+        return image
+
+
+    def clear(self):
+        """
+        Internal function to clear a detector.
+        """
+        self.image = torch.zeros(
+
+                                 self.colors,
+                                 self.resolution[0],
+                                 self.resolution[1],
+                                 device = self.device,
+                                )
+
+
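Example use (a sketch registering a single ray on a one-channel detector; the import path is assumed):

import torch
from odak.learn.raytracing import detector  # assumed import path

det = detector(
               colors = 1,
               center = torch.tensor([0., 0., 10.]),
               size = torch.tensor([10., 10.]),
               resolution = torch.tensor([100, 100])
              )
rays = torch.tensor([[[0., 0., 0.], [0., 0., 1.]]])     # a single ray aimed at the detector plane
points, image, distance_image = det.intersect(rays, color = 0)
accumulated = det.get_image()                           # normalized image, [colors x 100 x 100]
det.clear()                                             # reset the accumulated image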
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(colors=3, center=torch.tensor([0.0, 0.0, 0.0]), tilt=torch.tensor([0.0, 0.0, 0.0]), size=torch.tensor([10.0, 10.0]), resolution=torch.tensor([100, 100]), device=torch.device('cpu')) + +

+ + +
+ + + +

Parameters:

+
    +
  • + colors + – +
    +
             Number of color channels to register (e.g., RGB).
    +
    +
    +
  • +
  • + center + – +
    +
             Center point of the detector [3].
    +
    +
    +
  • +
  • + tilt + – +
    +
             Tilt angles of the surface in degrees [3].
    +
    +
    +
  • +
  • + size + – +
    +
             Size of the detector [2].
    +
    +
    +
  • +
  • + resolution + – +
    +
             Resolution of the detector.
    +
    +
    +
  • +
  • + device + – +
    +
             Device for computation (e.g., cuda, cpu).
    +
    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/detector.py +
def __init__(
+             self,
+             colors = 3,
+             center = torch.tensor([0., 0., 0.]),
+             tilt = torch.tensor([0., 0., 0.]),
+             size = torch.tensor([10., 10.]),
+             resolution = torch.tensor([100, 100]),
+             device = torch.device('cpu')
+            ):
+    """
+    Parameters
+    ----------
+    colors         : int
+                     Number of color channels to register (e.g., RGB).
+    center         : torch.tensor
+                     Center point of the detector [3].
+    tilt           : torch.tensor
+                     Tilt angles of the surface in degrees [3].
+    size           : torch.tensor
+                     Size of the detector [2].
+    resolution     : torch.tensor
+                     Resolution of the detector.
+    device         : torch.device
+                     Device for computation (e.g., cuda, cpu).
+    """
+    self.device = device
+    self.colors = colors
+    self.resolution = resolution.to(self.device)
+    self.surface_center = center.to(self.device)
+    self.surface_tilt = tilt.to(self.device)
+    self.size = size.to(self.device)
+    self.pixel_size = torch.tensor([
+                                    self.size[0] / self.resolution[0],
+                                    self.size[1] / self.resolution[1]
+                                   ], device  = self.device)
+    self.pixel_diagonal_size = torch.sqrt(self.pixel_size[0] ** 2 + self.pixel_size[1] ** 2)
+    self.pixel_diagonal_half_size = self.pixel_diagonal_size / 2.
+    self.threshold = torch.nn.Threshold(self.pixel_diagonal_size, 1)
+    self.plane = define_plane(
+                              point = self.surface_center,
+                              angles = self.surface_tilt
+                             )
+    self.pixel_locations, _, _, _ = grid_sample(
+                                                size = self.size.tolist(),
+                                                no = self.resolution.tolist(),
+                                                center = self.surface_center.tolist(),
+                                                angles = self.surface_tilt.tolist()
+                                               )
+    self.pixel_locations = self.pixel_locations.to(self.device)
+    self.relu = torch.nn.ReLU()
+    self.clear()
+
+
+
+ +
+ +
+ + +

+ clear() + +

+ + +
+ +

Internal function to clear a detector.

+ +
+ Source code in odak/learn/raytracing/detector.py +
def clear(self):
+    """
+    Internal function to clear a detector.
+    """
+    self.image = torch.zeros(
+
+                             self.colors,
+                             self.resolution[0],
+                             self.resolution[1],
+                             device = self.device,
+                            )
+
+
+
+ +
+ +
+ + +

+ get_image() + +

+ + +
+ +

Function to return the detector image.

+ + +

Returns:

+
    +
  • +image ( tensor +) – +
    +

    Detector image.

    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/detector.py +
def get_image(self):
+    """
+    Function to return the detector image.
+
+    Returns
+    -------
+    image           : torch.tensor
+                      Detector image.
+    """
+    image = (self.image - self.image.min()) / (self.image.max() - self.image.min())
+    return image
+
+
+
+ +
+ +
+ + +

+ intersect(rays, color=0) + +

+ + +
+ +

Function to intersect rays with the detector.

+ + +

Parameters:

+
    +
  • + rays + – +
    +
              Rays to be intersected with a detector.
    +          Expected size is [1 x 2 x 3] or [m x 2 x 3].
    +
    +
    +
  • +
  • + color + – +
    +
              Color channel to register.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +points ( tensor +) – +
    +

    Intersection points with the image detector [k x 3].

    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/detector.py +
    def intersect(self, rays, color = 0):
+        """
+        Function to intersect rays with the detector
+
+
+        Parameters
+        ----------
+        rays            : torch.tensor
+                          Rays to be intersected with a detector.
+                          Expected size is [1 x 2 x 3] or [m x 2 x 3].
+        color           : int
+                          Color channel to register.
+
+        Returns
+        -------
+        points          : torch.tensor
+                          Intersection points with the image detector [k x 3].
+        """
+        normals, _ = intersect_w_surface(rays, self.plane)
+        points = normals[:, 0]
+        distances_xyz = torch.abs(points.unsqueeze(1) - self.pixel_locations.unsqueeze(0))
+        distances_x = 1e6 * self.relu( - (distances_xyz[:, :, 0] - self.pixel_size[0]))
+        distances_y = 1e6 * self.relu( - (distances_xyz[:, :, 1] - self.pixel_size[1]))
+        hit_x = torch.clamp(distances_x, min = 0., max = 1.)
+        hit_y = torch.clamp(distances_y, min = 0., max = 1.)
+        hit = hit_x * hit_y
+        image = torch.sum(hit, dim = 0)
+        self.image[color] += image.reshape(
+                                           self.image.shape[-2], 
+                                           self.image.shape[-1]
+                                          )
+        distances = torch.sum((points.unsqueeze(1) - self.pixel_locations.unsqueeze(0)) ** 2, dim = 2)
+        distance_image = distances
+#        distance_image = distances.reshape(
+#                                           -1,
+#                                           self.image.shape[-2],
+#                                           self.image.shape[-1]
+#                                          )
+        return points, image, distance_image
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ planar_mesh + + +

+ + +
+ + + + + + + +
+ Source code in odak/learn/raytracing/mesh.py +
class planar_mesh():
+
+
+    def __init__(
+                 self,
+                 size = [1., 1.],
+                 number_of_meshes = [10, 10],
+                 angles = torch.tensor([0., 0., 0.]),
+                 offset = torch.tensor([0., 0., 0.]),
+                 device = torch.device('cpu'),
+                 heights = None
+                ):
+        """
+        Definition to generate a plane with meshes.
+
+
+        Parameters
+        -----------
+        number_of_meshes  : torch.tensor
+                            Number of squares over plane.
+                            There are two triangles at each square.
+        size              : torch.tensor
+                            Size of the plane.
+        angles            : torch.tensor
+                            Rotation angles in degrees.
+        offset            : torch.tensor
+                            Offset along XYZ axes.
+                            Expected dimension is [1 x 3] or offset for each triangle [m x 3].
+                            m here refers to `2 * number_of_meshes[0]` times  `number_of_meshes[1]`.
+        device            : torch.device
+                            Computational resource to be used (e.g., cpu, cuda).
+        heights           : torch.tensor
+                            Load surface heights from a tensor.
+        """
+        self.device = device
+        self.angles = angles.to(self.device)
+        self.offset = offset.to(self.device)
+        self.size = size.to(self.device)
+        self.number_of_meshes = number_of_meshes.to(self.device)
+        self.init_heights(heights)
+
+
+    def init_heights(self, heights = None):
+        """
+        Internal function to initialize a height map.
+        Note that self.heights is a differentiable variable, and can be optimized or learned.
+        See unit test `test/test_learn_ray_detector.py` or `test/test_learn_ray_mesh.py` as examples.
+        """
+        if not isinstance(heights, type(None)):
+            self.heights = heights.to(self.device)
+            self.heights.requires_grad = True
+        else:
+            self.heights = torch.zeros(
+                                       (self.number_of_meshes[0], self.number_of_meshes[1], 1),
+                                       requires_grad = True,
+                                       device = self.device,
+                                      )
+        x = torch.linspace(-self.size[0] / 2., self.size[0] / 2., self.number_of_meshes[0], device = self.device) 
+        y = torch.linspace(-self.size[1] / 2., self.size[1] / 2., self.number_of_meshes[1], device = self.device)
+        X, Y = torch.meshgrid(x, y, indexing = 'ij')
+        self.X = X.unsqueeze(-1)
+        self.Y = Y.unsqueeze(-1)
+
+
+    def save_heights(self, filename = 'heights.pt'):
+        """
+        Function to save heights to a file.
+
+        Parameters
+        ----------
+        filename          : str
+                            Filename.
+        """
+        save_torch_tensor(filename, self.heights.detach().clone())
+
+
+    def save_heights_as_PLY(self, filename = 'mesh.ply'):
+        """
+        Function to save mesh to a PLY file.
+
+        Parameters
+        ----------
+        filename          : str
+                            Filename.
+        """
+        triangles = self.get_triangles()
+        write_PLY(triangles, filename)
+
+
+    def get_squares(self):
+        """
+        Internal function to initiate squares over a plane.
+
+        Returns
+        -------
+        squares     : torch.tensor
+                      Squares over a plane.
+                      Expected size is [m x n x 3].
+        """
+        squares = torch.cat((
+                             self.X,
+                             self.Y,
+                             self.heights
+                            ), dim = -1)
+        return squares
+
+
+    def get_triangles(self):
+        """
+        Internal function to get triangles.
+        """ 
+        squares = self.get_squares()
+        triangles = torch.zeros(2, self.number_of_meshes[0], self.number_of_meshes[1], 3, 3, device = self.device)
+        for i in range(0, self.number_of_meshes[0] - 1):
+            for j in range(0, self.number_of_meshes[1] - 1):
+                first_triangle = torch.cat((
+                                            squares[i + 1, j].unsqueeze(0),
+                                            squares[i + 1, j + 1].unsqueeze(0),
+                                            squares[i, j + 1].unsqueeze(0),
+                                           ), dim = 0)
+                second_triangle = torch.cat((
+                                             squares[i + 1, j].unsqueeze(0),
+                                             squares[i, j + 1].unsqueeze(0),
+                                             squares[i, j].unsqueeze(0),
+                                            ), dim = 0)
+                triangles[0, i, j], _, _, _ = rotate_points(first_triangle, angles = self.angles)
+                triangles[1, i, j], _, _, _ = rotate_points(second_triangle, angles = self.angles)
+        triangles = triangles.view(-1, 3, 3) + self.offset
+        return triangles 
+
+
+    def mirror(self, rays):
+        """
+        Function to bounce light rays off the meshes.
+
+        Parameters
+        ----------
+        rays              : torch.tensor
+                            Rays to be bounced.
+                            Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].
+
+        Returns
+        -------
+        reflected_rays    : torch.tensor
+                            Reflected rays.
+                            Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].
+        reflected_normals : torch.tensor
+                            Reflected normals.
+                            Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].
+
+        """
+        if len(rays.shape) == 2:
+            rays = rays.unsqueeze(0)
+        triangles = self.get_triangles()
+        reflected_rays = torch.empty((0, 2, 3), requires_grad = True, device = self.device)
+        reflected_normals = torch.empty((0, 2, 3), requires_grad = True, device = self.device)
+        for triangle in triangles:
+            _, _, intersecting_rays, intersecting_normals, check = intersect_w_triangle(
+                                                                                        rays,
+                                                                                        triangle
+                                                                                       ) 
+            triangle_reflected_rays = reflect(intersecting_rays, intersecting_normals)
+            if triangle_reflected_rays.shape[0] > 0:
+                reflected_rays = torch.cat((
+                                            reflected_rays,
+                                            triangle_reflected_rays
+                                          ))
+                reflected_normals = torch.cat((
+                                               reflected_normals,
+                                               intersecting_normals
+                                              ))
+        return reflected_rays, reflected_normals
+
+
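Example use (a sketch bouncing a ray off a small mesh; the import path is assumed, and size and number_of_meshes are passed here as tensors since the constructor calls .to(device) on them):

import torch
from odak.learn.raytracing import planar_mesh  # assumed import path

mesh = planar_mesh(
                   size = torch.tensor([10., 10.]),
                   number_of_meshes = torch.tensor([10, 10]),
                   offset = torch.tensor([0., 0., 10.])
                  )
rays = torch.tensor([[[0.1, 0.2, 0.], [0., 0., 1.]]])   # a single ray toward the mesh plane
reflected_rays, reflected_normals = mesh.mirror(rays)
# mesh.heights is created with requires_grad = True, so it can be optimized to shape the surface.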
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(size=[1.0, 1.0], number_of_meshes=[10, 10], angles=torch.tensor([0.0, 0.0, 0.0]), offset=torch.tensor([0.0, 0.0, 0.0]), device=torch.device('cpu'), heights=None) + +

+ + +
+ +

Definition to generate a plane with meshes.

+ + +

Parameters:

+
    +
  • + number_of_meshes + – +
    +
                Number of squares over plane.
    +            There are two triangles at each square.
    +
    +
    +
  • +
  • + size + – +
    +
                Size of the plane.
    +
    +
    +
  • +
  • + angles + – +
    +
                Rotation angles in degrees.
    +
    +
    +
  • +
  • + offset + – +
    +
                Offset along XYZ axes.
    +            Expected dimension is [1 x 3] or offset for each triangle [m x 3].
    +            m here refers to `2 * number_of_meshes[0]` times  `number_of_meshes[1]`.
    +
    +
    +
  • +
  • + device + – +
    +
                Computational resource to be used (e.g., cpu, cuda).
    +
    +
    +
  • +
  • + heights + – +
    +
                Load surface heights from a tensor.
    +
    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/mesh.py +
def __init__(
+             self,
+             size = [1., 1.],
+             number_of_meshes = [10, 10],
+             angles = torch.tensor([0., 0., 0.]),
+             offset = torch.tensor([0., 0., 0.]),
+             device = torch.device('cpu'),
+             heights = None
+            ):
+    """
+    Definition to generate a plane with meshes.
+
+
+    Parameters
+    -----------
+    number_of_meshes  : torch.tensor
+                        Number of squares over plane.
+                        There are two triangles at each square.
+    size              : torch.tensor
+                        Size of the plane.
+    angles            : torch.tensor
+                        Rotation angles in degrees.
+    offset            : torch.tensor
+                        Offset along XYZ axes.
+                        Expected dimension is [1 x 3] or offset for each triangle [m x 3].
+                        m here refers to `2 * number_of_meshes[0]` times  `number_of_meshes[1]`.
+    device            : torch.device
+                        Computational resource to be used (e.g., cpu, cuda).
+    heights           : torch.tensor
+                        Load surface heights from a tensor.
+    """
+    self.device = device
+    self.angles = angles.to(self.device)
+    self.offset = offset.to(self.device)
+    self.size = size.to(self.device)
+    self.number_of_meshes = number_of_meshes.to(self.device)
+    self.init_heights(heights)
+
+
+
+ +
+ +
+ + +

+ get_squares() + +

+ + +
+ +

Internal function to initiate squares over a plane.

+ + +

Returns:

+
    +
  • +squares ( tensor +) – +
    +

    Squares over a plane. +Expected size is [m x n x 3].

    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/mesh.py +
def get_squares(self):
+    """
+    Internal function to initiate squares over a plane.
+
+    Returns
+    -------
+    squares     : torch.tensor
+                  Squares over a plane.
+                  Expected size is [m x n x 3].
+    """
+    squares = torch.cat((
+                         self.X,
+                         self.Y,
+                         self.heights
+                        ), dim = -1)
+    return squares
+
+
+
+ +
+ +
+ + +

+ get_triangles() + +

+ + +
+ +

Internal function to get triangles.

+ +
+ Source code in odak/learn/raytracing/mesh.py +
def get_triangles(self):
+    """
+    Internal function to get triangles.
+    """ 
+    squares = self.get_squares()
+    triangles = torch.zeros(2, self.number_of_meshes[0], self.number_of_meshes[1], 3, 3, device = self.device)
+    for i in range(0, self.number_of_meshes[0] - 1):
+        for j in range(0, self.number_of_meshes[1] - 1):
+            first_triangle = torch.cat((
+                                        squares[i + 1, j].unsqueeze(0),
+                                        squares[i + 1, j + 1].unsqueeze(0),
+                                        squares[i, j + 1].unsqueeze(0),
+                                       ), dim = 0)
+            second_triangle = torch.cat((
+                                         squares[i + 1, j].unsqueeze(0),
+                                         squares[i, j + 1].unsqueeze(0),
+                                         squares[i, j].unsqueeze(0),
+                                        ), dim = 0)
+            triangles[0, i, j], _, _, _ = rotate_points(first_triangle, angles = self.angles)
+            triangles[1, i, j], _, _, _ = rotate_points(second_triangle, angles = self.angles)
+    triangles = triangles.view(-1, 3, 3) + self.offset
+    return triangles 
+
+
+
+ +
+ +
+ + +

+ init_heights(heights=None) + +

+ + +
+ +

Internal function to initialize a height map. +Note that self.heights is a differentiable variable, and can be optimized or learned. +See unit test test/test_learn_ray_detector.py or test/test_learn_ray_mesh.py as examples.

+ +
+ Source code in odak/learn/raytracing/mesh.py +
def init_heights(self, heights = None):
+    """
+    Internal function to initialize a height map.
+    Note that self.heights is a differentiable variable, and can be optimized or learned.
+    See unit test `test/test_learn_ray_detector.py` or `test/test_learn_ray_mesh.py` as examples.
+    """
+    if not isinstance(heights, type(None)):
+        self.heights = heights.to(self.device)
+        self.heights.requires_grad = True
+    else:
+        self.heights = torch.zeros(
+                                   (self.number_of_meshes[0], self.number_of_meshes[1], 1),
+                                   requires_grad = True,
+                                   device = self.device,
+                                  )
+    x = torch.linspace(-self.size[0] / 2., self.size[0] / 2., self.number_of_meshes[0], device = self.device) 
+    y = torch.linspace(-self.size[1] / 2., self.size[1] / 2., self.number_of_meshes[1], device = self.device)
+    X, Y = torch.meshgrid(x, y, indexing = 'ij')
+    self.X = X.unsqueeze(-1)
+    self.Y = Y.unsqueeze(-1)
+
+
+
+ +
+ +
+ + +

+ mirror(rays) + +

+ + +
+ +

Function to bounce light rays off the meshes.

+ + +

Parameters:

+
    +
  • + rays + – +
    +
                Rays to be bounced.
    +            Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +reflected_rays ( tensor +) – +
    +

    Reflected rays. +Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].

    +
    +
  • +
  • +reflected_normals ( tensor +) – +
    +

    Reflected normals. +Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].

    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/mesh.py +
def mirror(self, rays):
+    """
+    Function to bounce light rays off the meshes.
+
+    Parameters
+    ----------
+    rays              : torch.tensor
+                        Rays to be bounced.
+                        Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].
+
+    Returns
+    -------
+    reflected_rays    : torch.tensor
+                        Reflected rays.
+                        Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].
+    reflected_normals : torch.tensor
+                        Reflected normals.
+                        Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].
+
+    """
+    if len(rays.shape) == 2:
+        rays = rays.unsqueeze(0)
+    triangles = self.get_triangles()
+    reflected_rays = torch.empty((0, 2, 3), requires_grad = True, device = self.device)
+    reflected_normals = torch.empty((0, 2, 3), requires_grad = True, device = self.device)
+    for triangle in triangles:
+        _, _, intersecting_rays, intersecting_normals, check = intersect_w_triangle(
+                                                                                    rays,
+                                                                                    triangle
+                                                                                   ) 
+        triangle_reflected_rays = reflect(intersecting_rays, intersecting_normals)
+        if triangle_reflected_rays.shape[0] > 0:
+            reflected_rays = torch.cat((
+                                        reflected_rays,
+                                        triangle_reflected_rays
+                                      ))
+            reflected_normals = torch.cat((
+                                           reflected_normals,
+                                           intersecting_normals
+                                          ))
+    return reflected_rays, reflected_normals
+
+
+
+ +
+ +
+ + +

+ save_heights(filename='heights.pt') + +

+ + +
+ +

Function to save heights to a file.

+ + +

Parameters:

+
    +
  • + filename + – +
    +
                Filename.
    +
    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/mesh.py +
def save_heights(self, filename = 'heights.pt'):
+    """
+    Function to save heights to a file.
+
+    Parameters
+    ----------
+    filename          : str
+                        Filename.
+    """
+    save_torch_tensor(filename, self.heights.detach().clone())
+
+
+
+ +
+ +
+ + +

+ save_heights_as_PLY(filename='mesh.ply') + +

+ + +
+ +

Function to save mesh to a PLY file.

+ + +

Parameters:

+
    +
  • + filename + – +
    +
                Filename.
    +
    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/mesh.py +
def save_heights_as_PLY(self, filename = 'mesh.ply'):
+    """
+    Function to save mesh to a PLY file.
+
+    Parameters
+    ----------
+    filename          : str
+                        Filename.
+    """
+    triangles = self.get_triangles()
+    write_PLY(triangles, filename)
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + +
+ +
+ +
+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ center_of_triangle(triangle) + +

+ + +
+ +

Definition to calculate center of a triangle.

+ + +

Parameters:

+
    +
  • + triangle + – +
    +
            An array that contains three points defining a triangle (Mx3). 
    +        It can also parallel process many triangles (NxMx3).
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +centers ( tensor +) – +
    +

    Triangle centers.

    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/primitives.py +
def center_of_triangle(triangle):
+    """
+    Definition to calculate center of a triangle.
+
+    Parameters
+    ----------
+    triangle      : torch.tensor
+                    An array that contains three points defining a triangle (Mx3). 
+                    It can also parallel process many triangles (NxMx3).
+
+    Returns
+    -------
+    centers       : torch.tensor
+                    Triangle centers.
+    """
+    if len(triangle.shape) == 2:
+        triangle = triangle.view((1, 3, 3))
+    center = torch.mean(triangle, axis=1)
+    return center
+
+
+
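Example use (a minimal sketch with an assumed import path):

import torch
from odak.learn.raytracing import center_of_triangle  # assumed import path

triangles = torch.tensor([
                          [[0., 0., 0.], [3., 0., 0.], [0., 3., 0.]],
                          [[0., 0., 1.], [3., 0., 1.], [0., 3., 1.]]
                         ])                                    # [N x 3 x 3]
centers = center_of_triangle(triangles)                        # [N x 3]: the mean of each triangle's vertices
# centers[0] is (1., 1., 0.) and centers[1] is (1., 1., 1.).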
+ +
+ +
+ + +

+ define_circle(center, radius, angles) + +

+ + +
+ +

Definition to describe a circle in a single variable packed form.

+ + +

Parameters:

+
    +
  • + center + – +
    +
      Center of a circle to be defined in 3D space.
    +
    +
    +
  • +
  • + radius + – +
    +
      Radius of a circle to be defined.
    +
    +
    +
  • +
  • + angles + – +
    +
      Angular tilt of a circle represented by rotations about x, y, and z axes.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +circle ( list +) – +
    +

    Single variable packed form.

    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/primitives.py +
def define_circle(center, radius, angles):
+    """
+    Definition to describe a circle in a single variable packed form.
+
+    Parameters
+    ----------
+    center  : torch.Tensor
+              Center of a circle to be defined in 3D space.
+    radius  : float
+              Radius of a circle to be defined.
+    angles  : torch.Tensor
+              Angular tilt of a circle represented by rotations about x, y, and z axes.
+
+    Returns
+    ----------
+    circle  : list
+              Single variable packed form.
+    """
+    points = define_plane(center, angles=angles)
+    circle = [
+        points,
+        center,
+        torch.tensor([radius])
+    ]
+    return circle
+
+
+
+ +
+ +
+ + +

+ define_plane(point, angles=torch.tensor([0.0, 0.0, 0.0])) + +

+ + +
+ +

Definition to generate a plane from a center point and rotation angles.

+ + +

Parameters:

+
    +
  • + point + – +
    +
           A point that is at the center of a plane.
    +
    +
    +
  • +
  • + angles + – +
    +
           Rotation angles in degrees.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +plane ( tensor +) – +
    +

    Points defining plane.

    +
    +
  • +
+ +
+ Source code in odak/learn/raytracing/primitives.py +
def define_plane(point, angles = torch.tensor([0., 0., 0.])):
+    """ 
+    Definition to generate a rotation matrix along X axis.
+
+    Parameters
+    ----------
+    point        : torch.tensor
+                   A point that is at the center of a plane.
+    angles       : torch.tensor
+                   Rotation angles in degrees.
+
+    Returns
+    ----------
+    plane        : torch.tensor
+                   Points defining plane.
+    """
+    plane = torch.tensor([
+                          [10., 10., 0.],
+                          [0., 10., 0.],
+                          [0.,  0., 0.]
+                         ], device = point.device)
+    for i in range(0, plane.shape[0]):
+        plane[i], _, _, _ = rotate_points(plane[i], angles = angles.to(point.device))
+        plane[i] = plane[i] + point
+    return plane
define_plane_mesh(number_of_meshes=[10, 10], size=[1.0, 1.0], angles=torch.tensor([0.0, 0.0, 0.0]), offset=torch.tensor([[0.0, 0.0, 0.0]]))

Definition to generate a plane with meshes.

Source code in odak/learn/raytracing/primitives.py:
def define_plane_mesh(
+                      number_of_meshes = [10, 10], 
+                      size = [1., 1.], 
+                      angles = torch.tensor([0., 0., 0.]), 
+                      offset = torch.tensor([[0., 0., 0.]])
+                     ):
+    """
+    Definition to generate a plane with meshes.
+
+
+    Parameters
+    -----------
+    number_of_meshes  : torch.tensor
+                        Number of squares over plane.
+                        There are two triangles at each square.
+    size              : list
+                        Size of the plane.
+    angles            : torch.tensor
+                        Rotation angles in degrees.
+    offset            : torch.tensor
+                        Offset along XYZ axes.
+                        Expected dimension is [1 x 3] or offset for each triangle [m x 3].
+                        m here refers to `2 * number_of_meshes[0]` times  `number_of_meshes[1]`. 
+
+    Returns
+    -------
+    triangles         : torch.tensor
+                        Triangles [m x 3 x 3], where m is `2 * number_of_meshes[0]` times  `number_of_meshes[1]`.
+    """
+    triangles = torch.zeros(2, number_of_meshes[0], number_of_meshes[1], 3, 3)
+    step = [size[0] / number_of_meshes[0], size[1] / number_of_meshes[1]]
+    for i in range(0, number_of_meshes[0] - 1):
+        for j in range(0, number_of_meshes[1] - 1):
+            first_triangle = torch.tensor([
+                                           [       -size[0] / 2. + step[0] * i,       -size[1] / 2. + step[0] * j, 0.],
+                                           [ -size[0] / 2. + step[0] * (i + 1),       -size[1] / 2. + step[0] * j, 0.],
+                                           [       -size[0] / 2. + step[0] * i, -size[1] / 2. + step[0] * (j + 1), 0.]
+                                          ])
+            second_triangle = torch.tensor([
+                                            [ -size[0] / 2. + step[0] * (i + 1), -size[1] / 2. + step[0] * (j + 1), 0.],
+                                            [ -size[0] / 2. + step[0] * (i + 1),       -size[1] / 2. + step[0] * j, 0.],
+                                            [       -size[0] / 2. + step[0] * i, -size[1] / 2. + step[0] * (j + 1), 0.]
+                                           ])
+            triangles[0, i, j], _, _, _ = rotate_points(first_triangle, angles = angles)
+            triangles[1, i, j], _, _, _ = rotate_points(second_triangle, angles = angles)
+    triangles = triangles.view(-1, 3, 3) + offset
+    return triangles
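A short sketch of generating a triangulated plane; the import path is an assumption based on the source location above:

import torch
from odak.learn.raytracing import define_plane_mesh  # assumed import path

triangles = define_plane_mesh(number_of_meshes = [10, 10], size = [1., 1.])
print(triangles.shape)  # torch.Size([200, 3, 3]), two triangles per square of the mesh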
define_sphere(center=torch.tensor([[0.0, 0.0, 0.0]]), radius=torch.tensor([1.0]))

Definition to define a sphere.

Source code in odak/learn/raytracing/primitives.py:
def define_sphere(center = torch.tensor([[0., 0., 0.]]), radius = torch.tensor([1.])):
+    """
+    Definition to define a sphere.
+
+    Parameters
+    ----------
+    center      : torch.tensor
+                  Center of the sphere(s) along XYZ axes.
+                  Expected size is [3], [1, 3] or [m, 3].
+    radius      : torch.tensor
+                  Radius of that sphere(s).
+                  Expected size is [1], [1, 1], [m] or [m, 1].
+
+    Returns
+    -------
+    parameters  : torch.tensor
+                  Parameters of defined sphere(s).
+                  Expected size is [1, 3] or [m x 3].
+    """
+    if len(radius.shape) == 1:
+        radius = radius.unsqueeze(0)
+    if len(center.shape) == 1:
+        center = center.unsqueeze(1)
+    parameters = torch.cat((center, radius), dim = 1)
+    return parameters
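A minimal usage sketch; the import path is assumed from the source location shown above:

import torch
from odak.learn.raytracing import define_sphere  # assumed import path

sphere = define_sphere(center = torch.tensor([[0., 0., 10.]]), radius = torch.tensor([[3.]]))
print(sphere)  # tensor([[ 0.,  0., 10.,  3.]]), packed as [center_x, center_y, center_z, radius]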
is_it_on_triangle(point_to_check, triangle)

Definition to check if a given point is inside a triangle. If the given point is inside a defined triangle, this definition returns True. For more details, visit: https://blackpawn.com/texts/pointinpoly/.

Source code in odak/learn/raytracing/primitives.py:
def is_it_on_triangle(point_to_check, triangle):
+    """
+    Definition to check if a given point is inside a triangle. 
+    If the given point is inside a defined triangle, this definition returns True.
+    For more details, visit: [https://blackpawn.com/texts/pointinpoly/](https://blackpawn.com/texts/pointinpoly/).
+
+    Parameters
+    ----------
+    point_to_check  : torch.tensor
+                      Point(s) to check.
+                      Expected size is [3], [1 x 3] or [m x 3].
+    triangle        : torch.tensor
+                      Triangle described with three points.
+                      Expected size is [3 x 3], [1 x 3 x 3] or [m x 3 x3].
+
+    Returns
+    -------
+    result          : torch.tensor
+                      Is it on a triangle? Returns NaN if condition not satisfied.
+                      Expected size is [1] or [m] depending on the input.
+    """
+    if len(point_to_check.shape) == 1:
+        point_to_check = point_to_check.unsqueeze(0)
+    if len(triangle.shape) == 2:
+        triangle = triangle.unsqueeze(0)
+    v0 = triangle[:, 2] - triangle[:, 0]
+    v1 = triangle[:, 1] - triangle[:, 0]
+    v2 = point_to_check - triangle[:, 0]
+    if len(v0.shape) == 1:
+        v0 = v0.unsqueeze(0)
+    if len(v1.shape) == 1:
+        v1 = v1.unsqueeze(0)
+    if len(v2.shape) == 1:
+        v2 = v2.unsqueeze(0)
+    dot00 = torch.mm(v0, v0.T)
+    dot01 = torch.mm(v0, v1.T)
+    dot02 = torch.mm(v0, v2.T) 
+    dot11 = torch.mm(v1, v1.T)
+    dot12 = torch.mm(v1, v2.T)
+    invDenom = 1. / (dot00 * dot11 - dot01 * dot01)
+    u = (dot11 * dot02 - dot01 * dot12) * invDenom
+    v = (dot00 * dot12 - dot01 * dot02) * invDenom
+    result = (u >= 0.) & (v >= 0.) & ((u + v) < 1)
+    return result
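A short point-in-triangle check, assuming the function is importable from odak.learn.raytracing as the source path suggests:

import torch
from odak.learn.raytracing import is_it_on_triangle  # assumed import path

triangle = torch.tensor([[0., 0., 0.],
                         [2., 0., 0.],
                         [0., 2., 0.]])
point = torch.tensor([0.5, 0.5, 0.])
print(is_it_on_triangle(point, triangle))  # tensor([[True]])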
is_it_on_triangle_batch(point_to_check, triangle)

Definition to check if given points are inside triangles. If the given points are inside defined triangles, this definition returns True.

Source code in odak/learn/raytracing/primitives.py:
def is_it_on_triangle_batch(point_to_check, triangle):
+    """
+    Definition to check if given points are inside triangles. If the given points are inside defined triangles, this definition returns True.
+
+    Parameters
+    ----------
+    point_to_check  : torch.tensor
+                      Points to check (m x n x 3).
+    triangle        : torch.tensor 
+                      Triangles (m x 3 x 3).
+
+    Returns
+    ----------
+    result          : torch.tensor (m x n)
+
+    """
+    if len(point_to_check.shape) == 1:
+        point_to_check = point_to_check.unsqueeze(0)
+    if len(triangle.shape) == 2:
+        triangle = triangle.unsqueeze(0)
+    v0 = triangle[:, 2] - triangle[:, 0]
+    v1 = triangle[:, 1] - triangle[:, 0]
+    v2 = point_to_check - triangle[:, None, 0]
+    if len(v0.shape) == 1:
+        v0 = v0.unsqueeze(0)
+    if len(v1.shape) == 1:
+        v1 = v1.unsqueeze(0)
+    if len(v2.shape) == 1:
+        v2 = v2.unsqueeze(0)
+
+    dot00 = torch.bmm(v0.unsqueeze(1), v0.unsqueeze(1).permute(0, 2, 1)).squeeze(1)
+    dot01 = torch.bmm(v0.unsqueeze(1), v1.unsqueeze(1).permute(0, 2, 1)).squeeze(1)
+    dot02 = torch.bmm(v0.unsqueeze(1), v2.permute(0, 2, 1)).squeeze(1)
+    dot11 = torch.bmm(v1.unsqueeze(1), v1.unsqueeze(1).permute(0, 2, 1)).squeeze(1)
+    dot12 = torch.bmm(v1.unsqueeze(1), v2.permute(0, 2, 1)).squeeze(1)
+    invDenom = 1. / (dot00 * dot11 - dot01 * dot01)
+    u = (dot11 * dot02 - dot01 * dot12) * invDenom
+    v = (dot00 * dot12 - dot01 * dot02) * invDenom
+    result = (u >= 0.) & (v >= 0.) & ((u + v) < 1)
+
+    return result
create_ray(xyz, abg, direction=False)

Definition to create a ray.

Source code in odak/learn/raytracing/ray.py:
def create_ray(xyz, abg, direction = False):
+    """
+    Definition to create a ray.
+
+    Parameters
+    ----------
+    xyz          : torch.tensor
+                   List that contains X,Y and Z start locations of a ray.
+                   Size could be [1 x 3], [3], [m x 3].
+    abg          : torch.tensor
+                   List that contains angles in degrees with respect to the X,Y and Z axes.
+                   Size could be [1 x 3], [3], [m x 3].
+    direction    : bool
+                   If set to True, cosines of `abg` are not calculated and `abg` is used directly as the direction.
+
+    Returns
+    ----------
+    ray          : torch.tensor
+                   Array that contains starting points and cosines of a created ray.
+                   Size will be either [1 x 3] or [m x 3].
+    """
+    points = xyz
+    angles = abg
+    if len(xyz) == 1:
+        points = xyz.unsqueeze(0)
+    if len(abg) == 1:
+        angles = abg.unsqueeze(0)
+    ray = torch.zeros(points.shape[0], 2, 3, device = points.device)
+    ray[:, 0] = points
+    if direction:
+        ray[:, 1] = abg
+    else:
+        ray[:, 1] = torch.cos(torch.deg2rad(abg))
+    return ray
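A minimal sketch creating two rays from origins and angles; the import path odak.learn.raytracing is an assumption based on the source location above:

import torch
from odak.learn.raytracing import create_ray  # assumed import path

origins = torch.tensor([[0., 0., 0.],
                        [1., 0., 0.]])
angles = torch.tensor([[90., 90., 0.],
                       [90., 90., 0.]])  # direction cosines [0, 0, 1], i.e. rays along +z
rays = create_ray(origins, angles)
print(rays.shape)  # torch.Size([2, 2, 3])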
create_ray_from_all_pairs(x0y0z0, x1y1z1)

Creates rays from all possible pairs of points in x0y0z0 and x1y1z1.

Source code in odak/learn/raytracing/ray.py:
def create_ray_from_all_pairs(x0y0z0, x1y1z1):
+    """
+    Creates rays from all possible pairs of points in x0y0z0 and x1y1z1.
+
+    Parameters
+    ----------
+    x0y0z0       : torch.tensor
+                   Tensor that contains X, Y, and Z start locations of rays.
+                   Size should be [m x 3].
+    x1y1z1       : torch.tensor
+                   Tensor that contains X, Y, and Z end locations of rays.
+                   Size should be [n x 3].
+
+    Returns
+    ----------
+    rays         : torch.tensor
+                   Array that contains starting points and cosines of a created ray(s). Size of [n*m x 2 x 3]
+    """
+
+    if len(x0y0z0.shape) == 1:
+        x0y0z0 = x0y0z0.unsqueeze(0)
+    if len(x1y1z1.shape) == 1:
+        x1y1z1 = x1y1z1.unsqueeze(0)
+
+    m, n = x0y0z0.shape[0], x1y1z1.shape[0]
+    start_points = x0y0z0.unsqueeze(1).expand(-1, n, -1).reshape(-1, 3)
+    end_points = x1y1z1.unsqueeze(0).expand(m, -1, -1).reshape(-1, 3)
+
+    directions = end_points - start_points
+    norms = torch.norm(directions, p=2, dim=1, keepdim=True)
+    norms[norms == 0] = float('nan')
+
+    normalized_directions = directions / norms
+
+    rays = torch.zeros(m * n, 2, 3, device=x0y0z0.device)
+    rays[:, 0, :] = start_points
+    rays[:, 1, :] = normalized_directions
+
+    return rays
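A short sketch pairing every start point with every end point; the import path is assumed from the source location shown above:

import torch
from odak.learn.raytracing import create_ray_from_all_pairs  # assumed import path

starts = torch.tensor([[0., 0., 0.],
                       [1., 0., 0.]])
ends = torch.tensor([[0., 0., 10.],
                     [0., 1., 10.],
                     [1., 1., 10.]])
rays = create_ray_from_all_pairs(starts, ends)
print(rays.shape)  # torch.Size([6, 2, 3]), one ray per start-end pair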
create_ray_from_grid_w_luminous_angle(center, size, no, tilt, num_ray_per_light, angle_limit)

Generate a 2D array of lights, each emitting rays within a specified solid angle and tilt.

Source code in odak/learn/raytracing/ray.py:
def create_ray_from_grid_w_luminous_angle(center, size, no, tilt, num_ray_per_light, angle_limit):
+    """
+    Generate a 2D array of lights, each emitting rays within a specified solid angle and tilt.
+
+    Parameters:
+    ----------
+    center              : torch.tensor
+                          The center point of the light array, shape [3].
+    size                : list[int]
+                          The size of the light array [height, width]
+    no                  : list[int]
+                          The number of lights in the array [number of lights in height, number of lights in width].
+    tilt                : torch.tensor
+                          The tilt angles in degrees along x, y, z axes for the rays, shape [3].
+    angle_limit         : float
+                          The maximum angle in degrees from the initial direction vector within which to emit rays.
+    num_ray_per_light   : int
+                          The number of rays each light should emit.
+
+    Returns:
+    ----------
+    rays : torch.tensor
+           Array that contains starting points and cosines of a created ray(s). Size of [n x 2 x 3]
+    """
+
+    samples = torch.zeros((no[0], no[1], 3))
+
+    x = torch.linspace(-size[0] / 2., size[0] / 2., no[0])
+    y = torch.linspace(-size[1] / 2., size[1] / 2., no[1])
+    X, Y = torch.meshgrid(x, y, indexing='ij')
+
+    samples[:, :, 0] = X.detach().clone()
+    samples[:, :, 1] = Y.detach().clone()
+    samples = samples.reshape((no[0]*no[1], 3))
+
+    samples, *_ = rotate_points(samples, angles=tilt)
+
+    samples = samples + center
+    angle_limit = torch.as_tensor(angle_limit)
+    cos_alpha = torch.cos(angle_limit * torch.pi / 180)
+    tilt = tilt * torch.pi / 180
+
+    theta = torch.acos(1 - 2 * torch.rand(num_ray_per_light*samples.size(0)) * (1-cos_alpha))
+    phi = 2 * torch.pi * torch.rand(num_ray_per_light*samples.size(0))  
+
+    directions = torch.stack([
+        torch.sin(theta) * torch.cos(phi),  
+        torch.sin(theta) * torch.sin(phi),  
+        torch.cos(theta)                    
+    ], dim=1)
+
+    c, s = torch.cos(tilt), torch.sin(tilt)
+
+    Rx = torch.tensor([
+        [1, 0, 0],
+        [0, c[0], -s[0]],
+        [0, s[0], c[0]]
+    ])
+
+    Ry = torch.tensor([
+        [c[1], 0, s[1]],
+        [0, 1, 0],
+        [-s[1], 0, c[1]]
+    ])
+
+    Rz = torch.tensor([
+        [c[2], -s[2], 0],
+        [s[2], c[2], 0],
+        [0, 0, 1]
+    ])
+
+    origins = samples.repeat(num_ray_per_light, 1)
+
+    directions = torch.matmul(directions, (Rz@Ry@Rx).T)
+
+
+    rays = torch.zeros(num_ray_per_light*samples.size(0), 2, 3)
+    rays[:, 0, :] = origins
+    rays[:, 1, :] = directions
+
+    return rays
create_ray_from_point_w_luminous_angle(origin, num_ray, tilt, angle_limit)

Generate rays from a point, tilted by specific angles along x, y, z axes, within a specified solid angle.

Source code in odak/learn/raytracing/ray.py:
def create_ray_from_point_w_luminous_angle(origin, num_ray, tilt, angle_limit):
+    """
+    Generate rays from a point, tilted by specific angles along x, y, z axes, within a specified solid angle.
+
+    Parameters:
+    ----------
+    origin      : torch.tensor
+                  The origin point of the rays, shape [3].
+    num_ray     : int
+                  The total number of rays to generate.
+    tilt        : torch.tensor
+                  The tilt angles in degrees along x, y, z axes, shape [3].
+    angle_limit : float
+                  The maximum angle in degrees from the initial direction vector within which to emit rays.
+
+    Returns:
+    ----------
+    rays : torch.tensor
+           Array that contains starting points and cosines of a created ray(s). Size of [n x 2 x 3]
+    """
+    angle_limit = torch.as_tensor(angle_limit) 
+    cos_alpha = torch.cos(angle_limit * torch.pi / 180)
+    tilt = tilt * torch.pi / 180
+
+    theta = torch.acos(1 - 2 * torch.rand(num_ray) * (1-cos_alpha))
+    phi = 2 * torch.pi * torch.rand(num_ray)  
+
+
+    directions = torch.stack([
+        torch.sin(theta) * torch.cos(phi),  
+        torch.sin(theta) * torch.sin(phi),  
+        torch.cos(theta)                    
+    ], dim=1)
+
+    c, s = torch.cos(tilt), torch.sin(tilt)
+
+    Rx = torch.tensor([
+        [1, 0, 0],
+        [0, c[0], -s[0]],
+        [0, s[0], c[0]]
+    ])
+
+    Ry = torch.tensor([
+        [c[1], 0, s[1]],
+        [0, 1, 0],
+        [-s[1], 0, c[1]]
+    ])
+
+    Rz = torch.tensor([
+        [c[2], -s[2], 0],
+        [s[2], c[2], 0],
+        [0, 0, 1]
+    ])
+
+    origins = origin.repeat(num_ray, 1)
+    directions = torch.matmul(directions, (Rz@Ry@Rx).T)
+
+
+    rays = torch.zeros(num_ray, 2, 3)
+    rays[:, 0, :] = origins
+    rays[:, 1, :] = directions
+
+    return rays
create_ray_from_two_points(x0y0z0, x1y1z1)

Definition to create a ray from two given points. Note that both inputs must match in shape.

Source code in odak/learn/raytracing/ray.py:
def create_ray_from_two_points(x0y0z0, x1y1z1):
+    """
+    Definition to create a ray from two given points. Note that both inputs must match in shape.
+
+    Parameters
+    ----------
+    x0y0z0       : torch.tensor
+                   List that contains X,Y and Z start locations of a ray.
+                   Size could be [1 x 3], [3], [m x 3].
+    x1y1z1       : torch.tensor
+                   List that contains X,Y and Z ending locations of a ray or batch of rays.
+                   Size could be [1 x 3], [3], [m x 3].
+
+    Returns
+    ----------
+    ray          : torch.tensor
+                   Array that contains starting points and cosines of a created ray(s).
+    """
+    if len(x0y0z0.shape) == 1:
+        x0y0z0 = x0y0z0.unsqueeze(0)
+    if len(x1y1z1.shape) == 1:
+        x1y1z1 = x1y1z1.unsqueeze(0)
+    xdiff = x1y1z1[:, 0] - x0y0z0[:, 0]
+    ydiff = x1y1z1[:, 1] - x0y0z0[:, 1]
+    zdiff = x1y1z1[:, 2] - x0y0z0[:, 2]
+    s = (xdiff ** 2 + ydiff ** 2 + zdiff ** 2) ** 0.5
+    s[s == 0] = float('nan')
+    cosines = torch.zeros_like(x0y0z0 * x1y1z1)
+    cosines[:, 0] = xdiff / s
+    cosines[:, 1] = ydiff / s
+    cosines[:, 2] = zdiff / s
+    ray = torch.zeros(xdiff.shape[0], 2, 3, device = x0y0z0.device)
+    ray[:, 0] = x0y0z0
+    ray[:, 1] = cosines
+    return ray
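A minimal sketch of building a ray from a start and an end point; the import path is assumed from the source location above:

import torch
from odak.learn.raytracing import create_ray_from_two_points  # assumed import path

start = torch.tensor([0., 0., 0.])
end = torch.tensor([0., 0., 10.])
ray = create_ray_from_two_points(start, end)
print(ray)  # origin [0, 0, 0] in ray[:, 0] and direction cosines [0, 0, 1] in ray[:, 1]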
propagate_ray(ray, distance)

Definition to propagate a ray at a certain given distance.

Source code in odak/learn/raytracing/ray.py:
def propagate_ray(ray, distance):
+    """
+    Definition to propagate a ray at a certain given distance.
+
+    Parameters
+    ----------
+    ray        : torch.tensor
+                 A ray with a size of [2 x 3], [1 x 2 x 3] or a batch of rays with [m x 2 x 3].
+    distance   : torch.tensor
+                 Distance with a size of [1], [1, m] or distances with a size of [m], [1, m].
+
+    Returns
+    ----------
+    new_ray    : torch.tensor
+                 Propagated ray with a size of [1 x 2 x 3] or batch of rays with [m x 2 x 3].
+    """
+    if len(ray.shape) == 2:
+        ray = ray.unsqueeze(0)
+    if len(distance.shape) == 2:
+        distance = distance.squeeze(-1)
+    new_ray = torch.zeros_like(ray)
+    new_ray[:, 0, 0] = distance * ray[:, 1, 0] + ray[:, 0, 0]
+    new_ray[:, 0, 1] = distance * ray[:, 1, 1] + ray[:, 0, 1]
+    new_ray[:, 0, 2] = distance * ray[:, 1, 2] + ray[:, 0, 2]
+    return new_ray
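A short sketch combining ray creation and propagation; import paths are assumptions based on the source locations above:

import torch
from odak.learn.raytracing import create_ray_from_two_points, propagate_ray  # assumed import paths

ray = create_ray_from_two_points(torch.tensor([0., 0., 0.]), torch.tensor([0., 0., 1.]))
moved = propagate_ray(ray, torch.tensor([5.]))
print(moved[:, 0])  # new origin at tensor([[0., 0., 5.]])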
diff --git a/odak/learn_tools/index.html b/odak/learn_tools/index.html

odak.learn.tools

Provides necessary definitions for general tools used across the library.
blur_gaussian(field, kernel_length=[21, 21], nsigma=[3, 3], padding='same')

A definition to blur a field using a Gaussian kernel.

Source code in odak/learn/tools/matrix.py:
def blur_gaussian(field, kernel_length = [21, 21], nsigma = [3, 3], padding = 'same'):
+    """
+    A definition to blur a field using a Gaussian kernel.
+
+    Parameters
+    ----------
+    field         : torch.tensor
+                    MxN field.
+    kernel_length : list
+                    Length of the Gaussian kernel along X and Y axes.
+    nsigma        : list
+                    Sigma of the Gaussian kernel along X and Y axes.
+    padding       : int or string
+                    Padding value, see torch.nn.functional.conv2d() for more.
+
+    Returns
+    ----------
+    blurred_field : torch.tensor
+                    Blurred field.
+    """
+    kernel = generate_2d_gaussian(kernel_length, nsigma).to(field.device)
+    kernel = kernel.unsqueeze(0).unsqueeze(0)
+    if len(field.shape) == 2:
+        field = field.view(1, 1, field.shape[-2], field.shape[-1])
+    blurred_field = torch.nn.functional.conv2d(field, kernel, padding = padding)
+    if field.shape[1] == 1:
+        blurred_field = blurred_field.view(
+                                           blurred_field.shape[-2],
+                                           blurred_field.shape[-1]
+                                          )
+    return blurred_field
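A minimal blurring sketch; the import path odak.learn.tools is an assumption based on the source location above:

import torch
from odak.learn.tools import blur_gaussian  # assumed import path

field = torch.rand(100, 100)
blurred = blur_gaussian(field, kernel_length = [21, 21], nsigma = [3, 3])
print(blurred.shape)  # torch.Size([100, 100])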
circular_binary_mask(px, py, r)

Definition to generate a 2D circular binary mask.

Source code in odak/learn/tools/mask.py:
def circular_binary_mask(px, py, r):
+    """
+    Definition to generate a 2D circular binary mask.
+
+    Parameters
+    ----------
+    px           : int
+                   Pixel count in x.
+    py           : int
+                   Pixel count in y.
+    r            : int
+                   Radius of the circle.
+
+    Returns
+    -------
+    mask         : torch.tensor
+                   Mask [1 x 1 x m x n].
+    """
+    x = torch.linspace(-px / 2., px / 2., px)
+    y = torch.linspace(-py / 2., py / 2., py)
+    X, Y = torch.meshgrid(x, y, indexing='ij')
+    Z = (X ** 2 + Y ** 2) ** 0.5
+    mask = torch.zeros_like(Z)
+    mask[Z < r] = 1
+    return mask
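A short sketch generating a circular aperture mask; the import path is assumed from the source location shown above:

import torch
from odak.learn.tools import circular_binary_mask  # assumed import path

mask = circular_binary_mask(512, 512, 100)
print(mask.shape)  # torch.Size([512, 512]), ones inside the circle of radius 100, zeros outside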
convolve2d(field, kernel)

Definition to convolve a field with a kernel by multiplying in frequency space.

Source code in odak/learn/tools/matrix.py:
def convolve2d(field, kernel):
+    """
+    Definition to convolve a field with a kernel by multiplying in frequency space.
+
+    Parameters
+    ----------
+    field       : torch.tensor
+                  Input field with MxN shape.
+    kernel      : torch.tensor
+                  Input kernel with MxN shape.
+
+    Returns
+    ----------
+    new_field   : torch.tensor
+                  Convolved field.
+    """
+    fr = torch.fft.fft2(field)
+    fr2 = torch.fft.fft2(torch.flip(torch.flip(kernel, [1, 0]), [0, 1]))
+    m, n = fr.shape
+    new_field = torch.real(torch.fft.ifft2(fr*fr2))
+    new_field = torch.roll(new_field, shifts=(int(n/2+1), 0), dims=(1, 0))
+    new_field = torch.roll(new_field, shifts=(int(m/2+1), 0), dims=(0, 1))
+    return new_field
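A minimal FFT-based convolution sketch; import paths are assumptions based on the source locations in this page:

import torch
from odak.learn.tools import convolve2d, generate_2d_gaussian  # assumed import paths

field = torch.zeros(64, 64)
field[32, 32] = 1.  # impulse
kernel = generate_2d_gaussian(kernel_length = [64, 64], nsigma = [3, 3])
result = convolve2d(field, kernel)  # the impulse is spread into a Gaussian spot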
correlation_2d(first_tensor, second_tensor)

Definition to calculate the correlation between two tensors.

Source code in odak/learn/tools/matrix.py:
def correlation_2d(first_tensor, second_tensor):
+    """
+    Definition to calculate the correlation between two tensors.
+
+    Parameters
+    ----------
+    first_tensor  : torch.tensor
+                    First tensor.
+    second_tensor : torch.tensor
+                    Second tensor.
+
+    Returns
+    ----------
+    correlation   : torch.tensor
+                    Correlation between the two tensors.
+    """
+    fft_first_tensor = (torch.fft.fft2(first_tensor))
+    fft_second_tensor = (torch.fft.fft2(second_tensor))
+    conjugate_second_tensor = torch.conj(fft_second_tensor)
+    result = torch.fft.ifftshift(torch.fft.ifft2(fft_first_tensor * conjugate_second_tensor))
+    return result
crop_center(field, size=None)

Definition to crop the center of a field with 2Mx2N size. The outcome is a MxN array.

Source code in odak/learn/tools/matrix.py:
def crop_center(field, size = None):
+    """
+    Definition to crop the center of a field with 2Mx2N size. The outcome is a MxN array.
+
+    Parameters
+    ----------
+    field       : ndarray
+                  Input field 2M x 2N or K x L x 2M x 2N or K x 2M x 2N x L array.
+    size        : list
+                  Dimensions to crop with respect to center of the image (e.g., M x N or 1 x 1 x M x N).
+
+    Returns
+    ----------
+    cropped     : ndarray
+                  Cropped version of the input field.
+    """
+    orig_resolution = field.shape
+    if len(field.shape) < 3:
+        field = field.unsqueeze(0)
+    if len(field.shape) < 4:
+        field = field.unsqueeze(0)
+    permute_flag = False
+    if field.shape[-1] < 5:
+        permute_flag = True
+        field = field.permute(0, 3, 1, 2)
+    if type(size) == type(None):
+        qx = int(field.shape[-2] // 4)
+        qy = int(field.shape[-1] // 4)
+        cropped_padded = field[:, :, qx: qx + field.shape[-2] // 2, qy:qy + field.shape[-1] // 2]
+    else:
+        cx = int(field.shape[-2] // 2)
+        cy = int(field.shape[-1] // 2)
+        hx = int(size[-2] // 2)
+        hy = int(size[-1] // 2)
+        cropped_padded = field[:, :, cx-hx:cx+hx, cy-hy:cy+hy]
+    cropped = cropped_padded
+    if permute_flag:
+        cropped = cropped.permute(0, 2, 3, 1)
+    if len(orig_resolution) == 2:
+        cropped = cropped_padded.squeeze(0).squeeze(0)
+    if len(orig_resolution) == 3:
+        cropped = cropped_padded.squeeze(0)
+    return cropped
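A short center-cropping sketch; the import path is assumed from the source location shown above:

import torch
from odak.learn.tools import crop_center  # assumed import path

field = torch.rand(200, 200)
print(crop_center(field).shape)                   # torch.Size([100, 100]), default half-size crop
print(crop_center(field, size = [50, 50]).shape)  # torch.Size([50, 50])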
cross_product(vector1, vector2)

Definition to cross product two vectors and return the resultant vector. Used method described under: http://en.wikipedia.org/wiki/Cross_product

Source code in odak/learn/tools/vector.py:
def cross_product(vector1, vector2):
+    """
+    Definition to cross product two vectors and return the resultant vector. Used method described under: http://en.wikipedia.org/wiki/Cross_product
+
+    Parameters
+    ----------
+    vector1      : torch.tensor
+                   A vector/ray.
+    vector2      : torch.tensor
+                   A vector/ray.
+
+    Returns
+    ----------
+    ray          : torch.tensor
+                   Array that contains starting points and cosines of a created ray.
+    """
+    angle = torch.cross(vector1[1].T, vector2[1].T)
+    angle = torch.tensor(angle)
+    ray = torch.tensor([vector1[0], angle], dtype=torch.float32)
+    return ray
distance_between_two_points(point1, point2)

Definition to calculate distance between two given points.

Source code in odak/learn/tools/vector.py:
def distance_between_two_points(point1, point2):
+    """
+    Definition to calculate distance between two given points.
+
+    Parameters
+    ----------
+    point1      : torch.Tensor
+                  First point in X,Y,Z.
+    point2      : torch.Tensor
+                  Second point in X,Y,Z.
+
+    Returns
+    ----------
+    distance    : torch.Tensor
+                  Distance in between given two points.
+    """
+    point1 = torch.tensor(point1) if not isinstance(point1, torch.Tensor) else point1
+    point2 = torch.tensor(point2) if not isinstance(point2, torch.Tensor) else point2
+
+    if len(point1.shape) == 1 and len(point2.shape) == 1:
+        distance = torch.sqrt(torch.sum((point1 - point2) ** 2))
+    elif len(point1.shape) == 2 or len(point2.shape) == 2:
+        distance = torch.sqrt(torch.sum((point1 - point2) ** 2, dim=-1))
+
+    return distance
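A minimal usage sketch; the import path is assumed from the source location shown above:

import torch
from odak.learn.tools import distance_between_two_points  # assumed import path

d = distance_between_two_points(torch.tensor([0., 0., 0.]), torch.tensor([3., 4., 0.]))
print(d)  # tensor(5.)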
expanduser(filename)

Definition to decode filename using namespaces and shortcuts.

Source code in odak/tools/file.py:
def expanduser(filename):
+    """
+    Definition to decode filename using namespaces and shortcuts.
+
+
+    Parameters
+    ----------
+    filename      : str
+                    Filename.
+
+
+    Returns
+    -------
+    new_filename  : str
+                    Filename.
+    """
+    new_filename = os.path.expanduser(filename)
+    return new_filename
generate_2d_dirac_delta(kernel_length=[21, 21], a=[3, 3], mu=[0, 0], theta=0, normalize=False)

Generate 2D Dirac delta function by using Gaussian distribution. Inspired from https://en.wikipedia.org/wiki/Dirac_delta_function

Source code in odak/learn/tools/matrix.py:
def generate_2d_dirac_delta(
+                            kernel_length = [21, 21],
+                            a = [3, 3],
+                            mu = [0, 0],
+                            theta = 0,
+                            normalize = False
+                           ):
+    """
+    Generate 2D Dirac delta function by using Gaussian distribution.
+    Inspired from https://en.wikipedia.org/wiki/Dirac_delta_function
+
+    Parameters
+    ----------
+    kernel_length : list
+                    Length of the Dirac delta function along X and Y axes.
+    a             : list
+                    The scale factor in Gaussian distribution to approximate the Dirac delta function. 
+                    As a approaches zero, the Gaussian distribution becomes infinitely narrow and tall at the center (x=0), approaching the Dirac delta function.
+    mu            : list
+                    Mu of the Gaussian kernel along X and Y axes.
+    theta         : float
+                    The rotation angle of the 2D Dirac delta function.
+    normalize     : bool
+                    If set True, normalize the output.
+
+    Returns
+    ----------
+    kernel_2d     : torch.tensor
+                    Generated 2D Dirac delta function.
+    """
+    x = torch.linspace(-kernel_length[0] / 2., kernel_length[0] / 2., kernel_length[0])
+    y = torch.linspace(-kernel_length[1] / 2., kernel_length[1] / 2., kernel_length[1])
+    X, Y = torch.meshgrid(x, y, indexing='ij')
+    X = X - mu[0]
+    Y = Y - mu[1]
+    theta = torch.as_tensor(theta)
+    X_rot = X * torch.cos(theta) - Y * torch.sin(theta)
+    Y_rot = X * torch.sin(theta) + Y * torch.cos(theta)
+    kernel_2d = (1 / (abs(a[0] * a[1]) * torch.pi)) * torch.exp(-((X_rot / a[0]) ** 2 + (Y_rot / a[1]) ** 2))
+    if normalize:
+        kernel_2d = kernel_2d / kernel_2d.max()
+    return kernel_2d
generate_2d_gaussian(kernel_length=[21, 21], nsigma=[3, 3], mu=[0, 0], normalize=False)

Generate 2D Gaussian kernel. Inspired from https://stackoverflow.com/questions/29731726/how-to-calculate-a-gaussian-kernel-matrix-efficiently-in-numpy

Source code in odak/learn/tools/matrix.py:
def generate_2d_gaussian(kernel_length = [21, 21], nsigma = [3, 3], mu = [0, 0], normalize = False):
+    """
+    Generate 2D Gaussian kernel. Inspired from https://stackoverflow.com/questions/29731726/how-to-calculate-a-gaussian-kernel-matrix-efficiently-in-numpy
+
+    Parameters
+    ----------
+    kernel_length : list
+                    Length of the Gaussian kernel along X and Y axes.
+    nsigma        : list
+                    Sigma of the Gaussian kernel along X and Y axes.
+    mu            : list
+                    Mu of the Gaussian kernel along X and Y axes.
+    normalize     : bool
+                    If set True, normalize the output.
+
+    Returns
+    ----------
+    kernel_2d     : torch.tensor
+                    Generated Gaussian kernel.
+    """
+    x = torch.linspace(-kernel_length[0]/2., kernel_length[0]/2., kernel_length[0])
+    y = torch.linspace(-kernel_length[1]/2., kernel_length[1]/2., kernel_length[1])
+    X, Y = torch.meshgrid(x, y, indexing='ij')
+    if nsigma[0] == 0:
+        nsigma[0] = 1e-5
+    if nsigma[1] == 0:
+        nsigma[1] = 1e-5
+    kernel_2d = 1. / (2. * torch.pi * nsigma[0] * nsigma[1]) * torch.exp(-((X - mu[0])**2. / (2. * nsigma[0]**2.) + (Y - mu[1])**2. / (2. * nsigma[1]**2.)))
+    if normalize:
+        kernel_2d = kernel_2d / kernel_2d.max()
+    return kernel_2d
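A short kernel-generation sketch; the import path is assumed from the source location shown above:

import torch
from odak.learn.tools import generate_2d_gaussian  # assumed import path

kernel = generate_2d_gaussian(kernel_length = [21, 21], nsigma = [3, 3], normalize = True)
print(kernel.shape, kernel.max())  # torch.Size([21, 21]) with a peak value of 1 at the center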
get_rotation_matrix(tilt_angles=[0.0, 0.0, 0.0], tilt_order='XYZ')

Function to generate rotation matrix for given tilt angles and tilt order.

Source code in odak/learn/tools/transformation.py:
def get_rotation_matrix(tilt_angles = [0., 0., 0.], tilt_order = 'XYZ'):
+    """
+    Function to generate rotation matrix for given tilt angles and tilt order.
+
+
+    Parameters
+    ----------
+    tilt_angles        : list
+                         Tilt angles in degrees along XYZ axes.
+    tilt_order         : str
+                         Rotation order (e.g., XYZ, XZY, ZXY, YXZ, ZYX).
+
+    Returns
+    -------
+    rotmat             : torch.tensor
+                         Rotation matrix.
+    """
+    rotx = rotmatx(tilt_angles[0])
+    roty = rotmaty(tilt_angles[1])
+    rotz = rotmatz(tilt_angles[2])
+    if tilt_order =='XYZ':
+        rotmat = torch.mm(rotz,torch.mm(roty, rotx))
+    elif tilt_order == 'XZY':
+        rotmat = torch.mm(roty,torch.mm(rotz, rotx))
+    elif tilt_order == 'ZXY':
+        rotmat = torch.mm(roty,torch.mm(rotx, rotz))
+    elif tilt_order == 'YXZ':
+        rotmat = torch.mm(rotz,torch.mm(rotx, roty))
+    elif tilt_order == 'ZYX':
+         rotmat = torch.mm(rotx,torch.mm(roty, rotz))
+    return rotmat
grid_sample(no=[10, 10], size=[100.0, 100.0], center=[0.0, 0.0, 0.0], angles=[0.0, 0.0, 0.0])

Definition to generate samples over a surface.

Source code in odak/learn/tools/sample.py:
def grid_sample(
+                no = [10, 10],
+                size = [100., 100.], 
+                center = [0., 0., 0.], 
+                angles = [0., 0., 0.]):
+    """
+    Definition to generate samples over a surface.
+
+    Parameters
+    ----------
+    no          : list
+                  Number of samples.
+    size        : list
+                  Physical size of the surface.
+    center      : list
+                  Center location of the surface.
+    angles      : list
+                  Tilt of the surface.
+
+    Returns
+    -------
+    samples     : torch.tensor
+                  Samples generated.
+    rotx        : torch.tensor
+                  Rotation matrix at X axis.
+    roty        : torch.tensor
+                  Rotation matrix at Y axis.
+    rotz        : torch.tensor
+                  Rotation matrix at Z axis.
+    """
+    center = torch.tensor(center)
+    angles = torch.tensor(angles)
+    size = torch.tensor(size)
+    samples = torch.zeros((no[0], no[1], 3))
+    x = torch.linspace(-size[0] / 2., size[0] / 2., no[0])
+    y = torch.linspace(-size[1] / 2., size[1] / 2., no[1])
+    X, Y = torch.meshgrid(x, y, indexing='ij')
+    samples[:, :, 0] = X.detach().clone()
+    samples[:, :, 1] = Y.detach().clone()
+    samples = samples.reshape((samples.shape[0] * samples.shape[1], samples.shape[2]))
+    samples, rotx, roty, rotz = rotate_points(samples, angles = angles, offset = center)
+    return samples, rotx, roty, rotz
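A minimal surface-sampling sketch; the import path is assumed from the source location shown above:

import torch
from odak.learn.tools import grid_sample  # assumed import path

samples, rotx, roty, rotz = grid_sample(
                                         no = [5, 5],
                                         size = [10., 10.],
                                         center = [0., 0., 10.],
                                         angles = [0., 0., 0.]
                                        )
print(samples.shape)  # torch.Size([25, 3]), one XYZ location per sample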
histogram_loss(frame, ground_truth, bins=32, limits=[0.0, 1.0])

Function for evaluating a frame against a target using histogram.

Source code in odak/learn/tools/loss.py:
def histogram_loss(frame, ground_truth, bins = 32, limits = [0., 1.]):
+    """
+    Function for evaluating a frame against a target using histogram.
+
+    Parameters
+    ----------
+    frame            : torch.tensor
+                       Input frame [1 x 3 x m x n]  or [3 x m x n] or [1 x m x n] or [m x n].
+    ground_truth     : torch.tensor
+                       Ground truth [1 x 3 x m x n] or  [3 x m x n] or [1 x m x n] or  [m x n].
+    bins             : int
+                       Number of bins.
+    limits           : list
+                       Limits.
+
+    Returns
+    -------
+    loss             : float
+                       Loss from evaluation.
+    """
+    if len(frame.shape) == 2:
+        frame = frame.unsqueeze(0).unsqueeze(0)
+    elif len(frame.shape) == 3:
+        frame = frame.unsqueeze(0)
+
+    if len(ground_truth.shape) == 2:
+        ground_truth = ground_truth.unsqueeze(0).unsqueeze(0)
+    elif len(ground_truth.shape) == 3:
+        ground_truth = ground_truth.unsqueeze(0)
+
+    histogram_frame = torch.zeros(frame.shape[1], bins).to(frame.device)
+    histogram_ground_truth = torch.zeros(ground_truth.shape[1], bins).to(frame.device)
+
+    l2 = torch.nn.MSELoss()
+
+    for i in range(frame.shape[1]):
+        histogram_frame[i] = torch.histc(frame[:, i].flatten(), bins=bins, min=limits[0], max=limits[1])
+        histogram_ground_truth[i] = torch.histc(ground_truth[:, i].flatten(), bins=bins, min=limits[0], max=limits[1])
+
+    loss = l2(histogram_frame, histogram_ground_truth)
+
+    return loss
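A short sketch comparing the per-channel histograms of two images; the import path is assumed from the source location above:

import torch
from odak.learn.tools import histogram_loss  # assumed import path

estimate = torch.rand(1, 3, 64, 64)
target = torch.rand(1, 3, 64, 64)
loss = histogram_loss(estimate, target, bins = 32, limits = [0., 1.])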
load_image(fn, normalizeby=0.0, torch_style=False)

Definition to load an image from a given location as a torch tensor.

Source code in odak/learn/tools/file.py:
def load_image(fn, normalizeby = 0., torch_style = False):
+    """
+    Definition to load an image from a given location as a torch tensor.
+
+    Parameters
+    ----------
+    fn           : str
+                   Filename.
+    normalizeby  : float or optional
+                   Value to normalize images with. Default value of zero will lead to no normalization.
+    torch_style  : bool or optional
+                   If set True, it will load an image mxnx3 as 3xmxn.
+
+    Returns
+    -------
+    image        : torch.tensor
+                    Image loaded as a torch tensor.
+
+    """
+    image = odak.tools.load_image(fn, normalizeby = normalizeby, torch_style = torch_style)
+    image = torch.from_numpy(image).float()
+    return image
michelson_contrast(image, roi_high, roi_low)

A function to calculate the Michelson contrast ratio of given regions of interest of the image.

Source code in odak/learn/tools/loss.py:
def michelson_contrast(image, roi_high, roi_low):
+    """
+    A function to calculate the Michelson contrast ratio of given regions of interest of the image.
+
+    Parameters
+    ----------
+    image         : torch.tensor
+                    Image to be tested [1 x 3 x m x n] or [3 x m x n] or [m x n].
+    roi_high      : torch.tensor
+                    Corner locations of the roi for high intensity area [m_start, m_end, n_start, n_end].
+    roi_low       : torch.tensor
+                    Corner locations of the roi for low intensity area [m_start, m_end, n_start, n_end].
+
+    Returns
+    -------
+    result        : torch.tensor
+                    Michelson contrast for the given regions. [1] or [3] depending on input image.
+    """
+    if len(image.shape) == 2:
+        image = image.unsqueeze(0)
+    if len(image.shape) == 3:
+        image = image.unsqueeze(0)
+    region_low = image[:, :, roi_low[0]:roi_low[1], roi_low[2]:roi_low[3]]
+    region_high = image[:, :, roi_high[0]:roi_high[1], roi_high[2]:roi_high[3]]
+    high = torch.mean(region_high, dim = (2, 3))
+    low = torch.mean(region_low, dim = (2, 3))
+    result = (high - low) / (high + low)
+    return result.squeeze(0)
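A minimal contrast-measurement sketch; the import path is assumed from the source location shown above:

import torch
from odak.learn.tools import michelson_contrast  # assumed import path

image = torch.zeros(1, 3, 100, 100)
image[..., 0:10, 0:10] = 1.0    # bright patch
image[..., 50:60, 50:60] = 0.2  # dim patch
contrast = michelson_contrast(image, roi_high = [0, 10, 0, 10], roi_low = [50, 60, 50, 60])
print(contrast)  # tensor([0.6667, 0.6667, 0.6667]), i.e. (1.0 - 0.2) / (1.0 + 0.2) per channel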
multi_scale_total_variation_loss(frame, levels=3)

Function for evaluating a frame against a target using multi scale total variation approach. Here, multi scale refers to image pyramid of an input frame, where at each level image resolution is half of the previous level.

Source code in odak/learn/tools/loss.py:
def multi_scale_total_variation_loss(frame, levels = 3):
+    """
+    Function for evaluating a frame against a target using multi scale total variation approach. Here, multi scale refers to image pyramid of an input frame, where at each level image resolution is half of the previous level.
+
+    Parameters
+    ----------
+    frame         : torch.tensor
+                    Input frame [1 x 3 x m x n] or [3 x m x n] or [m x n].
+    levels        : int
+                    Number of levels to go down in the image pyramid.
+
+    Returns
+    -------
+    loss          : float
+                    Loss from evaluation.
+    """
+    if len(frame.shape) == 2:
+        frame = frame.unsqueeze(0)
+    if len(frame.shape) == 3:
+        frame = frame.unsqueeze(0)
+    scale = torch.nn.Upsample(scale_factor = 0.5, mode = 'nearest')
+    level = frame
+    loss = 0
+    for i in range(levels):
+        if i != 0:
+           level = scale(level)
+        loss += total_variation_loss(level) 
+    return loss
quantize(image_field, bits=8, limits=[0.0, 1.0])

Definition to quantize an image field (0-255, 8 bit) to a certain bit level.

Source code in odak/learn/tools/matrix.py:
def quantize(image_field, bits = 8, limits = [0., 1.]):
+    """ 
+    Definition to quantize an image field (0-255, 8 bit) to a certain bit level.
+
+    Parameters
+    ----------
+    image_field : torch.tensor
+                  Input image field between any range.
+    bits        : int
+                  A value between one and eight.
+    limits      : list
+                  The minimum and maximum of the image_field variable.
+
+    Returns
+    ----------
+    new_field   : torch.tensor
+                  Quantized image field.
+    """
+    normalized_field = (image_field - limits[0]) / (limits[1] - limits[0])
+    divider = 2 ** bits
+    new_field = normalized_field * divider
+    new_field = new_field.int()
+    return new_field
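A short quantization sketch; the import path is assumed from the source location shown above:

import torch
from odak.learn.tools import quantize  # assumed import path

image = torch.rand(256, 256)
levels = quantize(image, bits = 4, limits = [0., 1.])  # integer tensor with levels in [0, 2 ** bits]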
radial_basis_function(value, epsilon=0.5)

Function to pass a value into radial basis function with Gaussian description.

Source code in odak/learn/tools/loss.py:
def radial_basis_function(value, epsilon = 0.5):
+    """
+    Function to pass a value into radial basis function with Gaussian description.
+
+    Parameters
+    ----------
+    value            : torch.tensor
+                       Value(s) to pass to the radial basis function. 
+    epsilon          : float
+                       Epsilon used in the Gaussian radial basis function (e.g., y=e^(-(epsilon x value)^2).
+
+    Returns
+    -------
+    output           : torch.tensor
+                       Output values.
+    """
+    output = torch.exp((-(epsilon * value)**2))
+    return output
+
+
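A short illustrative call (import path assumed to be odak.learn.tools):

import torch
from odak.learn.tools import radial_basis_function  # assumed import path

distances = torch.linspace(-3., 3., 7)
weights = radial_basis_function(distances, epsilon = 0.5)   # equals 1 at zero distance, decays with |distance|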
+
+ +
+ +
+ + +

+ resize(image, multiplier=0.5, mode='nearest') + +

+ + +
+ +

Definition to resize an image.

+ + +

Parameters:

+
    +
  • + image + – +
    +
            Image with MxNx3 resolution.
    +
    +
    +
  • +
  • + multiplier + – +
    +
            Multiplier used in resizing operation (e.g., 0.5 is half size in one axis).
    +
    +
    +
  • +
  • + mode + – +
    +
            Mode to be used in scaling, nearest, bilinear, etc.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +new_image ( tensor +) – +
    +

    Resized image.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/file.py +
def resize(image, multiplier = 0.5, mode = 'nearest'):
+    """
+    Definition to resize an image.
+
+    Parameters
+    ----------
+    image         : torch.tensor
+                    Image with MxNx3 resolution.
+    multiplier    : float
+                    Multiplier used in resizing operation (e.g., 0.5 is half size in one axis).
+    mode          : str
+                    Mode to be used in scaling, nearest, bilinear, etc.
+
+    Returns
+    -------
+    new_image     : torch.tensor
+                    Resized image.
+
+    """
+    scale = torch.nn.Upsample(scale_factor = multiplier, mode = mode)
+    new_image = torch.zeros((int(image.shape[0] * multiplier), int(image.shape[1] * multiplier), 3)).to(image.device)
+    for i in range(3):
+        cache = image[:,:,i].unsqueeze(0)
+        cache = cache.unsqueeze(0)
+        new_cache = scale(cache).unsqueeze(0)
+        new_image[:,:,i] = new_cache.unsqueeze(0)
+    return new_image
+
+
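A possible usage sketch (odak.learn.tools import path assumed):

import torch
from odak.learn.tools import resize  # assumed import path

image = torch.rand(256, 256, 3)                                # M x N x 3 image
half_size = resize(image, multiplier = 0.5, mode = 'nearest')  # 128 x 128 x 3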
+
+ +
+ +
+ + +

+ rotate_points(point, angles=torch.tensor([[0, 0, 0]]), mode='XYZ', origin=torch.tensor([[0, 0, 0]]), offset=torch.tensor([[0, 0, 0]])) + +

+ + +
+ +

Definition to rotate a given point. Note that rotation is always with respect to 0,0,0.

+ + +

Parameters:

+
    +
  • + point + – +
    +
           A point with size of [3] or [1, 3] or [m, 3].
    +
    +
    +
  • +
  • + angles + – +
    +
           Rotation angles in degrees.
    +
    +
    +
  • +
  • + mode + – +
    +
           Rotation mode determines ordering of the rotations at each axis.
    +       There are XYZ, XZY, YXZ, ZXY and ZYX modes.
    +
    +
    +
  • +
  • + origin + – +
    +
           Reference point for a rotation.
    +       Expected size is [3] or [1, 3].
    +
    +
    +
  • +
  • + offset + – +
    +
           Shift with the given offset.
    +       Expected size is [3] or [1, 3] or [m, 3].
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( tensor +) – +
    +

    Result of the rotation [1 x 3] or [m x 3].

    +
    +
  • +
  • +rotx ( tensor +) – +
    +

    Rotation matrix along X axis [3 x 3].

    +
    +
  • +
  • +roty ( tensor +) – +
    +

    Rotation matrix along Y axis [3 x 3].

    +
    +
  • +
  • +rotz ( tensor +) – +
    +

    Rotation matrix along Z axis [3 x 3].

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/transformation.py +
def rotate_points(
+                 point,
+                 angles = torch.tensor([[0, 0, 0]]), 
+                 mode='XYZ', 
+                 origin = torch.tensor([[0, 0, 0]]), 
+                 offset = torch.tensor([[0, 0, 0]])
+                ):
+    """
+    Definition to rotate a given point. Note that rotation is always with respect to 0,0,0.
+
+    Parameters
+    ----------
+    point        : torch.tensor
+                   A point with size of [3] or [1, 3] or [m, 3].
+    angles       : torch.tensor
+                   Rotation angles in degrees. 
+    mode         : str
+                   Rotation mode determines ordering of the rotations at each axis.
+                   There are XYZ, XZY, YXZ, ZXY and ZYX modes.
+    origin       : torch.tensor
+                   Reference point for a rotation.
+                   Expected size is [3] or [1, 3].
+    offset       : torch.tensor
+                   Shift with the given offset.
+                   Expected size is [3] or [1, 3] or [m, 3].
+
+    Returns
+    ----------
+    result       : torch.tensor
+                   Result of the rotation [1 x 3] or [m x 3].
+    rotx         : torch.tensor
+                   Rotation matrix along X axis [3 x 3].
+    roty         : torch.tensor
+                   Rotation matrix along Y axis [3 x 3].
+    rotz         : torch.tensor
+                   Rotation matrix along Z axis [3 x 3].
+    """
+    origin = origin.to(point.device)
+    offset = offset.to(point.device)
+    if len(point.shape) == 1:
+        point = point.unsqueeze(0)
+    if len(angles.shape) == 1:
+        angles = angles.unsqueeze(0)
+    rotx = rotmatx(angles[:, 0])
+    roty = rotmaty(angles[:, 1])
+    rotz = rotmatz(angles[:, 2])
+    new_point = (point - origin).T
+    if mode == 'XYZ':
+        result = torch.mm(rotz, torch.mm(roty, torch.mm(rotx, new_point))).T
+    elif mode == 'XZY':
+        result = torch.mm(roty, torch.mm(rotz, torch.mm(rotx, new_point))).T
+    elif mode == 'YXZ':
+        result = torch.mm(rotz, torch.mm(rotx, torch.mm(roty, new_point))).T
+    elif mode == 'ZXY':
+        result = torch.mm(roty, torch.mm(rotx, torch.mm(rotz, new_point))).T
+    elif mode == 'ZYX':
+        result = torch.mm(rotx, torch.mm(roty, torch.mm(rotz, new_point))).T
+    result += origin
+    result += offset
+    return result, rotx, roty, rotz
+
+
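An illustrative call (import path assumed to be odak.learn.tools; angles are in degrees as documented):

import torch
from odak.learn.tools import rotate_points  # assumed import path

points = torch.tensor([[1., 0., 0.], [0., 1., 0.]])
angles = torch.tensor([[0., 0., 90.]])                         # rotate 90 degrees about Z
rotated, rotx, roty, rotz = rotate_points(points, angles = angles, mode = 'XYZ')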
+
+ +
+ +
+ + +

+ rotmatx(angle) + +

+ + +
+ +

Definition to generate a rotation matrix along X axis.

+ + +

Parameters:

+
    +
  • + angle + – +
    +
           Rotation angles in degrees.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +rotx ( tensor +) – +
    +

    Rotation matrix along X axis.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/transformation.py +
def rotmatx(angle):
+    """
+    Definition to generate a rotation matrix along X axis.
+
+    Parameters
+    ----------
+    angle        : torch.tensor
+                   Rotation angles in degrees.
+
+    Returns
+    ----------
+    rotx         : torch.tensor
+                   Rotation matrix along X axis.
+    """
+    angle = torch.deg2rad(angle)
+    one = torch.ones(1, device = angle.device)
+    zero = torch.zeros(1, device = angle.device)
+    rotx = torch.stack([
+                        torch.stack([ one,              zero,              zero]),
+                        torch.stack([zero,  torch.cos(angle), -torch.sin(angle)]),
+                        torch.stack([zero,  torch.sin(angle),  torch.cos(angle)])
+                       ]).reshape(3, 3)
+    return rotx
+
+
+
+ +
+ +
+ + +

+ rotmaty(angle) + +

+ + +
+ +

Definition to generate a rotation matrix along Y axis.

+ + +

Parameters:

+
    +
  • + angle + – +
    +
           Rotation angles in degrees.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +roty ( tensor +) – +
    +

    Rotation matrix along Y axis.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/transformation.py +
def rotmaty(angle):
+    """
+    Definition to generate a rotation matrix along Y axis.
+
+    Parameters
+    ----------
+    angle        : torch.tensor
+                   Rotation angles in degrees.
+
+    Returns
+    ----------
+    roty         : torch.tensor
+                   Rotation matrix along Y axis.
+    """
+    angle = torch.deg2rad(angle)
+    one = torch.ones(1, device = angle.device)
+    zero = torch.zeros(1, device = angle.device)
+    roty = torch.stack([
+                        torch.stack([ torch.cos(angle), zero, torch.sin(angle)]),
+                        torch.stack([             zero,  one,             zero]),
+                        torch.stack([-torch.sin(angle), zero, torch.cos(angle)])
+                       ]).reshape(3, 3)
+    return roty
+
+
+
+ +
+ +
+ + +

+ rotmatz(angle) + +

+ + +
+ +

Definition to generate a rotation matrix along Z axis.

+ + +

Parameters:

+
    +
  • + angle + – +
    +
           Rotation angles in degrees.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +rotz ( tensor +) – +
    +

    Rotation matrix along Z axis.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/transformation.py +
def rotmatz(angle):
+    """
+    Definition to generate a rotation matrix along Z axis.
+
+    Parameters
+    ----------
+    angle        : torch.tensor
+                   Rotation angles in degrees.
+
+    Returns
+    ----------
+    rotz         : torch.tensor
+                   Rotation matrix along Z axis.
+    """
+    angle = torch.deg2rad(angle)
+    one = torch.ones(1, device = angle.device)
+    zero = torch.zeros(1, device = angle.device)
+    rotz = torch.stack([
+                        torch.stack([torch.cos(angle), -torch.sin(angle), zero]),
+                        torch.stack([torch.sin(angle),  torch.cos(angle), zero]),
+                        torch.stack([            zero,              zero,  one])
+                       ]).reshape(3,3)
+    return rotz
+
+
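A small sketch combining the three helpers above, rotmatx, rotmaty and rotmatz (import path assumed to be odak.learn.tools):

import torch
from odak.learn.tools import rotmatx, rotmaty, rotmatz  # assumed import paths

angle = torch.tensor([30.])                                    # degrees
rotation = rotmatz(angle) @ rotmaty(angle) @ rotmatx(angle)    # composed 3 x 3 rotation matrix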
+
+ +
+ +
+ + +

+ same_side(p1, p2, a, b) + +

+ + +
+ +

Definition to figure out which side a point is on with respect to a line and a point. See http://www.blackpawn.com/texts/pointinpoly/ for more. If p1 and p2 are on the same side, this definition returns True.

+ + +

Parameters:

+
    +
  • + p1 + – +
    +
          Point(s) to check.
    +
    +
    +
  • +
  • + p2 + – +
    +
          This is the point to check against.
    +
    +
    +
  • +
  • + a + – +
    +
          First point that forms the line.
    +
    +
    +
  • +
  • + b + – +
    +
          Second point that forms the line.
    +
    +
    +
  • +
+ +
+ Source code in odak/learn/tools/vector.py +
def same_side(p1, p2, a, b):
+    """
+    Definition to figure out which side a point is on with respect to a line and a point. See http://www.blackpawn.com/texts/pointinpoly/ for more. If p1 and p2 are on the same side, this definition returns True.
+
+    Parameters
+    ----------
+    p1          : list
+                  Point(s) to check.
+    p2          : list
+                  This is the point to check against.
+    a           : list
+                  First point that forms the line.
+    b           : list
+                  Second point that forms the line.
+    """
+    ba = torch.subtract(b, a)
+    p1a = torch.subtract(p1, a)
+    p2a = torch.subtract(p2, a)
+    cp1 = torch.cross(ba, p1a)
+    cp2 = torch.cross(ba, p2a)
+    test = torch.dot(cp1, cp2)
+    if len(p1.shape) > 1:
+        return test >= 0
+    if test >= 0:
+        return True
+    return False
+
+
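A hedged usage example (odak.learn.tools import path assumed):

import torch
from odak.learn.tools import same_side  # assumed import path

a = torch.tensor([0., 0., 0.])
b = torch.tensor([1., 0., 0.])
p1 = torch.tensor([0.5, 1., 0.])
p2 = torch.tensor([0.5, 2., 0.])
print(same_side(p1, p2, a, b))   # True, both points lie on the same side of the line through a and b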
+
+ +
+ +
+ + +

+ save_image(fn, img, cmin=0, cmax=255, color_depth=8) + +

+ + +
+ +

Definition to save a torch tensor as an image.

+ + +

Parameters:

+
    +
  • + fn + – +
    +
           Filename.
    +
    +
    +
  • +
  • + img + – +
    +
           A torch tensor with NxMx3 or NxMx1 shape.
    +
    +
    +
  • +
  • + cmin + – +
    +
           Minimum value that will be interpreted as 0 level in the final image.
    +
    +
    +
  • +
  • + cmax + – +
    +
           Maximum value that will be interpreted as 255 level in the final image.
    +
    +
    +
  • +
  • + color_depth + – +
    +
           Color depth of an image. Default is eight.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +bool ( bool +) – +
    +

    True if successful.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/file.py +
def save_image(fn, img, cmin = 0, cmax = 255, color_depth = 8):
+    """
+    Definition to save a torch tensor as an image.
+
+    Parameters
+    ----------
+    fn           : str
+                   Filename.
+    img          : torch.tensor
+                   A torch tensor with NxMx3 or NxMx1 shape.
+    cmin         : int
+                   Minimum value that will be interpreted as 0 level in the final image.
+    cmax         : int
+                   Maximum value that will be interpreted as 255 level in the final image.
+    color_depth  : int
+                   Color depth of an image. Default is eight.
+
+
+    Returns
+    ----------
+    bool         :  bool
+                    True if successful.
+
+    """
+    if len(img.shape) ==  4:
+        img = img.squeeze(0)
+    if len(img.shape) > 2 and torch.argmin(torch.tensor(img.shape)) == 0:
+        new_img = torch.zeros(img.shape[1], img.shape[2], img.shape[0]).to(img.device)
+        for i in range(img.shape[0]):
+            new_img[:, :, i] = img[i].detach().clone()
+        img = new_img.detach().clone()
+    img = img.cpu().detach().numpy()
+    return odak.tools.save_image(fn, img, cmin = cmin, cmax = cmax, color_depth = color_depth)
+
+
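A minimal sketch of saving a tensor as an image (import path assumed; the file name is only illustrative):

import torch
from odak.learn.tools import save_image  # assumed import path

image = torch.rand(3, 256, 256)                # channels-first tensor in [0., 1.]
save_image('output_image.png', image * 255.)   # scaled to the default cmin = 0, cmax = 255 range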
+
+ +
+ +
+ + +

+ save_torch_tensor(fn, tensor) + +

+ + +
+ +

Definition to save a torch tensor.

+ + +

Parameters:

+
    +
  • + fn + – +
    +
           Filename.
    +
    +
    +
  • +
  • + tensor + – +
    +
           Torch tensor to be saved.
    +
    +
    +
  • +
+ +
+ Source code in odak/learn/tools/file.py +
def save_torch_tensor(fn, tensor):
+    """
+    Definition to save a torch tensor.
+
+
+    Parameters
+    ----------
+    fn           : str
+                   Filename.
+    tensor       : torch.tensor
+                   Torch tensor to be saved.
+    """ 
+    torch.save(tensor, expanduser(fn))
+
+
+
+ +
+ +
+ + +

+ tilt_towards(location, lookat) + +

+ + +
+ +

Definition to tilt surface normal of a plane towards a point.

+ + +

Parameters:

+
    +
  • + location + – +
    +
           Center of the plane to be tilted.
    +
    +
    +
  • +
  • + lookat + – +
    +
           Tilt towards this point.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +angles ( list +) – +
    +

    Rotation angles in degrees.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/transformation.py +
def tilt_towards(location, lookat):
+    """
+    Definition to tilt surface normal of a plane towards a point.
+
+    Parameters
+    ----------
+    location     : list
+                   Center of the plane to be tilted.
+    lookat       : list
+                   Tilt towards this point.
+
+    Returns
+    ----------
+    angles       : list
+                   Rotation angles in degrees.
+    """
+    dx = location[0] - lookat[0]
+    dy = location[1] - lookat[1]
+    dz = location[2] - lookat[2]
+    dist = torch.sqrt(torch.tensor(dx ** 2 + dy ** 2 + dz ** 2))
+    phi = torch.atan2(torch.tensor(dy), torch.tensor(dx))
+    theta = torch.arccos(dz / dist)
+    angles = [0, float(torch.rad2deg(theta)), float(torch.rad2deg(phi))]
+    return angles
+
+
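A brief illustrative call (import path assumed to be odak.learn.tools):

import torch
from odak.learn.tools import tilt_towards  # assumed import path

surface_center = [0., 0., 0.]
target_point = [0., 0., 10.]
angles = tilt_towards(surface_center, target_point)   # rotation angles in degrees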
+
+ +
+ +
+ + +

+ torch_load(fn) + +

+ + +
+ +

Definition to load a torch file (*.pt).

+ + +

Parameters:

+
    +
  • + fn + – +
    +
           Filename.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +data ( any +) – +
    +

    See torch.load() for more.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/file.py +
def torch_load(fn):
+    """
+    Definition to load a torch file (*.pt).
+
+    Parameters
+    ----------
+    fn           : str
+                   Filename.
+
+    Returns
+    -------
+    data         : any
+                   See torch.load() for more.
+    """  
+    data = torch.load(expanduser(fn))
+    return data
+
+
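A round-trip sketch pairing torch_load with save_torch_tensor documented above (import paths assumed; the file name is illustrative):

import torch
from odak.learn.tools import save_torch_tensor, torch_load  # assumed import paths

weights = torch.rand(10, 10)
save_torch_tensor('weights.pt', weights)
restored = torch_load('weights.pt')          # same values as the saved tensor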
+
+ +
+ +
+ + +

+ total_variation_loss(frame) + +

+ + +
+ +

Function for evaluating a frame against a target using total variation approach.

+ + +

Parameters:

+
    +
  • + frame + – +
    +
            Input frame [1 x 3 x m x n] or [3 x m x n] or [m x n].
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +loss ( float +) – +
    +

    Loss from evaluation.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/loss.py +
def total_variation_loss(frame):
+    """
+    Function for evaluating a frame against a target using total variation approach.
+
+    Parameters
+    ----------
+    frame         : torch.tensor
+                    Input frame [1 x 3 x m x n] or [3 x m x n] or [m x n].
+
+    Returns
+    -------
+    loss          : float
+                    Loss from evaluation.
+    """
+    if len(frame.shape) == 2:
+        frame = frame.unsqueeze(0)
+    if len(frame.shape) == 3:
+        frame = frame.unsqueeze(0)
+    diff_x = frame[:, :, :, 1:] - frame[:, :, :, :-1]
+    diff_y = frame[:, :, 1:, :] - frame[:, :, :-1, :]
+    pixel_count = frame.shape[0] * frame.shape[1] * frame.shape[2] * frame.shape[3]
+    loss = ((diff_x ** 2).sum() + (diff_y ** 2).sum()) / pixel_count
+    return loss
+
+
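A small comparison sketch (odak.learn.tools import path assumed):

import torch
from odak.learn.tools import total_variation_loss  # assumed import path

flat_frame = torch.ones(1, 3, 64, 64) * 0.5
noisy_frame = torch.rand(1, 3, 64, 64)
print(total_variation_loss(flat_frame), total_variation_loss(noisy_frame))   # zero versus a positive value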
+
+ +
+ +
+ + +

+ weber_contrast(image, roi_high, roi_low) + +

+ + +
+ +

A function to calculate the Weber contrast ratio of given regions of interest of the image.

+ + +

Parameters:

+
    +
  • + image + – +
    +
            Image to be tested [1 x 3 x m x n] or [3 x m x n] or [1 x m x n] or [m x n].
    +
    +
    +
  • +
  • + roi_high + – +
    +
            Corner locations of the roi for high intensity area [m_start, m_end, n_start, n_end].
    +
    +
    +
  • +
  • + roi_low + – +
    +
            Corner locations of the roi for low intensity area [m_start, m_end, n_start, n_end].
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( tensor +) – +
    +

    Weber contrast for given regions. [1] or [3] depending on input image.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/loss.py +
def weber_contrast(image, roi_high, roi_low):
+    """
+    A function to calculate the Weber contrast ratio of given regions of interest of the image.
+
+    Parameters
+    ----------
+    image         : torch.tensor
+                    Image to be tested [1 x 3 x m x n] or [3 x m x n] or [1 x m x n] or [m x n].
+    roi_high      : torch.tensor
+                    Corner locations of the roi for high intensity area [m_start, m_end, n_start, n_end].
+    roi_low       : torch.tensor
+                    Corner locations of the roi for low intensity area [m_start, m_end, n_start, n_end].
+
+    Returns
+    -------
+    result        : torch.tensor
+                    Weber contrast for given regions. [1] or [3] depending on input image.
+    """
+    if len(image.shape) == 2:
+        image = image.unsqueeze(0)
+    if len(image.shape) == 3:
+        image = image.unsqueeze(0)
+    region_low = image[:, :, roi_low[0]:roi_low[1], roi_low[2]:roi_low[3]]
+    region_high = image[:, :, roi_high[0]:roi_high[1], roi_high[2]:roi_high[3]]
+    high = torch.mean(region_high, dim = (2, 3))
+    low = torch.mean(region_low, dim = (2, 3))
+    result = (high - low) / low
+    return result.squeeze(0)
+
+
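An illustrative call with a synthetic test pattern (import path assumed):

import torch
from odak.learn.tools import weber_contrast  # assumed import path

image = torch.ones(1, 3, 100, 100) * 0.1                             # dim background
image[:, :, 40:60, 40:60] = 1.                                       # bright patch
contrast = weber_contrast(image, [40, 60, 40, 60], [0, 20, 0, 20])   # about 9 per color channel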
+
+ +
+ +
+ + +

+ wrapped_mean_squared_error(image, ground_truth, reduction='mean') + +

+ + +
+ +

A function to calculate the wrapped mean squared error between predicted and target angles.

+ + +

Parameters:

+
    +
  • + image + – +
    +
            Image to be tested [1 x 3 x m x n]  or [3 x m x n] or [1 x m x n] or [m x n].
    +
    +
    +
  • +
  • + ground_truth + – +
    +
            Ground truth to be tested [1 x 3 x m x n]  or [3 x m x n] or [1 x m x n] or [m x n].
    +
    +
    +
  • +
  • + reduction + – +
    +
            Specifies the reduction to apply to the output: 'mean' (default) or 'sum'.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +wmse ( tensor +) – +
    +

    The calculated wrapped mean squared error.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/loss.py +
def wrapped_mean_squared_error(image, ground_truth, reduction = 'mean'):
+    """
+    A function to calculate the wrapped mean squared error between predicted and target angles.
+
+    Parameters
+    ----------
+    image         : torch.tensor
+                    Image to be tested [1 x 3 x m x n]  or [3 x m x n] or [1 x m x n] or [m x n].
+    ground_truth  : torch.tensor
+                    Ground truth to be tested [1 x 3 x m x n]  or [3 x m x n] or [1 x m x n] or [m x n].
+    reduction     : str
+                    Specifies the reduction to apply to the output: 'mean' (default) or 'sum'.
+
+    Returns
+    -------
+    wmse        : torch.tensor
+                  The calculated wrapped mean squared error. 
+    """
+    sin_diff = torch.sin(image) - torch.sin(ground_truth)
+    cos_diff = torch.cos(image) - torch.cos(ground_truth)
+    loss = (sin_diff**2 + cos_diff**2)
+
+    if reduction == 'mean':
+        return loss.mean()
+    elif reduction == 'sum':
+        return loss.sum()
+    else:
+        raise ValueError("Invalid reduction type. Choose 'mean' or 'sum'.")
+
+
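A quick sanity-check sketch (import path assumed); two phase maps that differ by a full turn should give a near-zero loss:

import torch
from odak.learn.tools import wrapped_mean_squared_error  # assumed import path

predicted_phase = torch.rand(1, 1, 64, 64) * 2. * torch.pi
target_phase = predicted_phase + 2. * torch.pi
loss = wrapped_mean_squared_error(predicted_phase, target_phase)   # close to zero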
+
+ +
+ +
+ + +

+ zero_pad(field, size=None, method='center') + +

+ + +
+ +

Definition to zero pad an MxN array to a 2Mx2N array.

+ + +

Parameters:

+
    +
  • + field + – +
    +
                Input field MxN or KxJxMxN or KxMxNxJ array.
    +
    +
    +
  • +
  • + size + – +
    +
                Size to be zeropadded (e.g., [m, n], last two dimensions only).
    +
    +
    +
  • +
  • + method + – +
    +
             Zero pad either by placing the content at the center or to the left.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +field_zero_padded ( ndarray +) – +
    +

    Zeropadded version of the input field.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/matrix.py +
def zero_pad(field, size = None, method = 'center'):
+    """
+    Definition to zero pad an MxN array to a 2Mx2N array.
+
+    Parameters
+    ----------
+    field             : ndarray
+                        Input field MxN or KxJxMxN or KxMxNxJ array.
+    size              : list
+                        Size to be zeropadded (e.g., [m, n], last two dimensions only).
+    method            : str
+                        Zero pad either by placing the content at the center or to the left.
+
+    Returns
+    ----------
+    field_zero_padded : ndarray
+                        Zeropadded version of the input field.
+    """
+    orig_resolution = field.shape
+    if len(field.shape) < 3:
+        field = field.unsqueeze(0)
+    if len(field.shape) < 4:
+        field = field.unsqueeze(0)
+    permute_flag = False
+    if field.shape[-1] < 5:
+        permute_flag = True
+        field = field.permute(0, 3, 1, 2)
+    if type(size) == type(None):
+        resolution = [field.shape[0], field.shape[1], 2 * field.shape[-2], 2 * field.shape[-1]]
+    else:
+        resolution = [field.shape[0], field.shape[1], size[0], size[1]]
+    field_zero_padded = torch.zeros(resolution, device = field.device, dtype = field.dtype)
+    if method == 'center':
+       start = [
+                resolution[-2] // 2 - field.shape[-2] // 2,
+                resolution[-1] // 2 - field.shape[-1] // 2
+               ]
+       field_zero_padded[
+                         :, :,
+                         start[0] : start[0] + field.shape[-2],
+                         start[1] : start[1] + field.shape[-1]
+                         ] = field
+    elif method == 'left':
+       field_zero_padded[
+                         :, :,
+                         0: field.shape[-2],
+                         0: field.shape[-1]
+                        ] = field
+    if permute_flag == True:
+        field_zero_padded = field_zero_padded.permute(0, 2, 3, 1)
+    if len(orig_resolution) == 2:
+        field_zero_padded = field_zero_padded.squeeze(0).squeeze(0)
+    if len(orig_resolution) == 3:
+        field_zero_padded = field_zero_padded.squeeze(0)
+    return field_zero_padded
+
+
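A minimal usage sketch (import path assumed to be odak.learn.tools):

import torch
from odak.learn.tools import zero_pad  # assumed import path

field = torch.rand(100, 100)
padded = zero_pad(field)                              # 200 x 200, content centered
padded_custom = zero_pad(field, size = [256, 256])    # pad to an explicit size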
+
+ +
+ + + +
+ +
+ +
+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ load_image(fn, normalizeby=0.0, torch_style=False) + +

+ + +
+ +

Definition to load an image from a given location as a torch tensor.

+ + +

Parameters:

+
    +
  • + fn + – +
    +
           Filename.
    +
    +
    +
  • +
  • + normalizeby + – +
    +
           Value to normalize images with. The default value of zero leads to no normalization.
    +
    +
    +
  • +
  • + torch_style + – +
    +
           If set True, it will load an mxnx3 image as 3xmxn.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +image ( ndarray +) – +
    +

    Image loaded as a Numpy array.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/file.py +
def load_image(fn, normalizeby = 0., torch_style = False):
+    """
+    Definition to load an image from a given location as a torch tensor.
+
+    Parameters
+    ----------
+    fn           : str
+                   Filename.
+    normalizeby  : float, optional
+                   Value to normalize images with. The default value of zero leads to no normalization.
+    torch_style  : bool, optional
+                   If set True, it will load an mxnx3 image as 3xmxn.
+
+    Returns
+    -------
+    image        :  ndarray
+                    Image loaded as a Numpy array.
+
+    """
+    image = odak.tools.load_image(fn, normalizeby = normalizeby, torch_style = torch_style)
+    image = torch.from_numpy(image).float()
+    return image
+
+
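A hedged round-trip example with save_image (import paths and the file name are illustrative):

import torch
from odak.learn.tools import load_image, save_image  # assumed import paths

save_image('sample.png', torch.rand(3, 64, 64) * 255.)                     # write a test image first
image = load_image('sample.png', normalizeby = 255., torch_style = True)   # 3 x 64 x 64 tensor, roughly in [0., 1.]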
+
+ +
+ +
+ + +

+ resize(image, multiplier=0.5, mode='nearest') + +

+ + +
+ +

Definition to resize an image.

+ + +

Parameters:

+
    +
  • + image + – +
    +
            Image with MxNx3 resolution.
    +
    +
    +
  • +
  • + multiplier + – +
    +
            Multiplier used in resizing operation (e.g., 0.5 is half size in one axis).
    +
    +
    +
  • +
  • + mode + – +
    +
            Mode to be used in scaling, nearest, bilinear, etc.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +new_image ( tensor +) – +
    +

    Resized image.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/file.py +
def resize(image, multiplier = 0.5, mode = 'nearest'):
+    """
+    Definition to resize an image.
+
+    Parameters
+    ----------
+    image         : torch.tensor
+                    Image with MxNx3 resolution.
+    multiplier    : float
+                    Multiplier used in resizing operation (e.g., 0.5 is half size in one axis).
+    mode          : str
+                    Mode to be used in scaling, nearest, bilinear, etc.
+
+    Returns
+    -------
+    new_image     : torch.tensor
+                    Resized image.
+
+    """
+    scale = torch.nn.Upsample(scale_factor = multiplier, mode = mode)
+    new_image = torch.zeros((int(image.shape[0] * multiplier), int(image.shape[1] * multiplier), 3)).to(image.device)
+    for i in range(3):
+        cache = image[:,:,i].unsqueeze(0)
+        cache = cache.unsqueeze(0)
+        new_cache = scale(cache).unsqueeze(0)
+        new_image[:,:,i] = new_cache.unsqueeze(0)
+    return new_image
+
+
+
+ +
+ +
+ + +

+ save_image(fn, img, cmin=0, cmax=255, color_depth=8) + +

+ + +
+ +

Definition to save a torch tensor as an image.

+ + +

Parameters:

+
    +
  • + fn + – +
    +
           Filename.
    +
    +
    +
  • +
  • + img + – +
    +
           A torch tensor with NxMx3 or NxMx1 shape.
    +
    +
    +
  • +
  • + cmin + – +
    +
           Minimum value that will be interpreted as 0 level in the final image.
    +
    +
    +
  • +
  • + cmax + – +
    +
           Maximum value that will be interpreted as 255 level in the final image.
    +
    +
    +
  • +
  • + color_depth + – +
    +
           Color depth of an image. Default is eight.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +bool ( bool +) – +
    +

    True if successful.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/file.py +
def save_image(fn, img, cmin = 0, cmax = 255, color_depth = 8):
+    """
+    Definition to save a torch tensor as an image.
+
+    Parameters
+    ----------
+    fn           : str
+                   Filename.
+    img          : torch.tensor
+                   A torch tensor with NxMx3 or NxMx1 shape.
+    cmin         : int
+                   Minimum value that will be interpreted as 0 level in the final image.
+    cmax         : int
+                   Maximum value that will be interpreted as 255 level in the final image.
+    color_depth  : int
+                   Color depth of an image. Default is eight.
+
+
+    Returns
+    ----------
+    bool         :  bool
+                    True if successful.
+
+    """
+    if len(img.shape) ==  4:
+        img = img.squeeze(0)
+    if len(img.shape) > 2 and torch.argmin(torch.tensor(img.shape)) == 0:
+        new_img = torch.zeros(img.shape[1], img.shape[2], img.shape[0]).to(img.device)
+        for i in range(img.shape[0]):
+            new_img[:, :, i] = img[i].detach().clone()
+        img = new_img.detach().clone()
+    img = img.cpu().detach().numpy()
+    return odak.tools.save_image(fn, img, cmin = cmin, cmax = cmax, color_depth = color_depth)
+
+
+
+ +
+ +
+ + +

+ save_torch_tensor(fn, tensor) + +

+ + +
+ +

Definition to save a torch tensor.

+ + +

Parameters:

+
    +
  • + fn + – +
    +
           Filename.
    +
    +
    +
  • +
  • + tensor + – +
    +
           Torch tensor to be saved.
    +
    +
    +
  • +
+ +
+ Source code in odak/learn/tools/file.py +
def save_torch_tensor(fn, tensor):
+    """
+    Definition to save a torch tensor.
+
+
+    Parameters
+    ----------
+    fn           : str
+                   Filename.
+    tensor       : torch.tensor
+                   Torch tensor to be saved.
+    """ 
+    torch.save(tensor, expanduser(fn))
+
+
+
+ +
+ +
+ + +

+ torch_load(fn) + +

+ + +
+ +

Definition to load a torch file (*.pt).

+ + +

Parameters:

+
    +
  • + fn + – +
    +
           Filename.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +data ( any +) – +
    +

    See torch.load() for more.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/file.py +
def torch_load(fn):
+    """
+    Definition to load a torch file (*.pt).
+
+    Parameters
+    ----------
+    fn           : str
+                   Filename.
+
+    Returns
+    -------
+    data         : any
+                   See torch.load() for more.
+    """  
+    data = torch.load(expanduser(fn))
+    return data
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ histogram_loss(frame, ground_truth, bins=32, limits=[0.0, 1.0]) + +

+ + +
+ +

Function for evaluating a frame against a target using histogram.

+ + +

Parameters:

+
    +
  • + frame + – +
    +
               Input frame [1 x 3 x m x n]  or [3 x m x n] or [1 x m x n] or [m x n].
    +
    +
    +
  • +
  • + ground_truth + – +
    +
               Ground truth [1 x 3 x m x n] or  [3 x m x n] or [1 x m x n] or  [m x n].
    +
    +
    +
  • +
  • + bins + – +
    +
               Number of bins.
    +
    +
    +
  • +
  • + limits + – +
    +
           Minimum and maximum values used for the histogram range.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +loss ( float +) – +
    +

    Loss from evaluation.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/loss.py +
def histogram_loss(frame, ground_truth, bins = 32, limits = [0., 1.]):
+    """
+    Function for evaluating a frame against a target using histogram.
+
+    Parameters
+    ----------
+    frame            : torch.tensor
+                       Input frame [1 x 3 x m x n]  or [3 x m x n] or [1 x m x n] or [m x n].
+    ground_truth     : torch.tensor
+                       Ground truth [1 x 3 x m x n] or  [3 x m x n] or [1 x m x n] or  [m x n].
+    bins             : int
+                       Number of bins.
+    limits           : list
+                       Minimum and maximum values used for the histogram range.
+
+    Returns
+    -------
+    loss             : float
+                       Loss from evaluation.
+    """
+    if len(frame.shape) == 2:
+        frame = frame.unsqueeze(0).unsqueeze(0)
+    elif len(frame.shape) == 3:
+        frame = frame.unsqueeze(0)
+
+    if len(ground_truth.shape) == 2:
+        ground_truth = ground_truth.unsqueeze(0).unsqueeze(0)
+    elif len(ground_truth.shape) == 3:
+        ground_truth = ground_truth.unsqueeze(0)
+
+    histogram_frame = torch.zeros(frame.shape[1], bins).to(frame.device)
+    histogram_ground_truth = torch.zeros(ground_truth.shape[1], bins).to(frame.device)
+
+    l2 = torch.nn.MSELoss()
+
+    for i in range(frame.shape[1]):
+        histogram_frame[i] = torch.histc(frame[:, i].flatten(), bins=bins, min=limits[0], max=limits[1])
+        histogram_ground_truth[i] = torch.histc(ground_truth[:, i].flatten(), bins=bins, min=limits[0], max=limits[1])
+
+    loss = l2(histogram_frame, histogram_ground_truth)
+
+    return loss
+
+
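An illustrative call (import path assumed to be odak.learn.tools):

import torch
from odak.learn.tools import histogram_loss  # assumed import path

frame = torch.rand(1, 3, 64, 64)
ground_truth = torch.rand(1, 3, 64, 64)
loss = histogram_loss(frame, ground_truth, bins = 32, limits = [0., 1.])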
+
+ +
+ +
+ + +

+ michelson_contrast(image, roi_high, roi_low) + +

+ + +
+ +

A function to calculate the Michelson contrast ratio of given regions of interest of the image.

+ + +

Parameters:

+
    +
  • + image + – +
    +
            Image to be tested [1 x 3 x m x n] or [3 x m x n] or [m x n].
    +
    +
    +
  • +
  • + roi_high + – +
    +
            Corner locations of the roi for high intensity area [m_start, m_end, n_start, n_end].
    +
    +
    +
  • +
  • + roi_low + – +
    +
            Corner locations of the roi for low intensity area [m_start, m_end, n_start, n_end].
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( tensor +) – +
    +

    Michelson contrast for the given regions. [1] or [3] depending on input image.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/loss.py +
def michelson_contrast(image, roi_high, roi_low):
+    """
+    A function to calculate the Michelson contrast ratio of given regions of interest of the image.
+
+    Parameters
+    ----------
+    image         : torch.tensor
+                    Image to be tested [1 x 3 x m x n] or [3 x m x n] or [m x n].
+    roi_high      : torch.tensor
+                    Corner locations of the roi for high intensity area [m_start, m_end, n_start, n_end].
+    roi_low       : torch.tensor
+                    Corner locations of the roi for low intensity area [m_start, m_end, n_start, n_end].
+
+    Returns
+    -------
+    result        : torch.tensor
+                    Michelson contrast for the given regions. [1] or [3] depending on input image.
+    """
+    if len(image.shape) == 2:
+        image = image.unsqueeze(0)
+    if len(image.shape) == 3:
+        image = image.unsqueeze(0)
+    region_low = image[:, :, roi_low[0]:roi_low[1], roi_low[2]:roi_low[3]]
+    region_high = image[:, :, roi_high[0]:roi_high[1], roi_high[2]:roi_high[3]]
+    high = torch.mean(region_high, dim = (2, 3))
+    low = torch.mean(region_low, dim = (2, 3))
+    result = (high - low) / (high + low)
+    return result.squeeze(0)
+
+
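A synthetic-pattern sketch mirroring the Weber contrast example (import path assumed):

import torch
from odak.learn.tools import michelson_contrast  # assumed import path

image = torch.ones(1, 3, 100, 100) * 0.1
image[:, :, 40:60, 40:60] = 0.9
contrast = michelson_contrast(image, [40, 60, 40, 60], [0, 20, 0, 20])   # (0.9 - 0.1) / (0.9 + 0.1) = 0.8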
+
+ +
+ +
+ + +

+ multi_scale_total_variation_loss(frame, levels=3) + +

+ + +
+ +

Function for evaluating a frame against a target using a multi scale total variation approach. Here, multi scale refers to an image pyramid of the input frame, where at each level the image resolution is half of the previous level's.

+ + +

Parameters:

+
    +
  • + frame + – +
    +
            Input frame [1 x 3 x m x n] or [3 x m x n] or [m x n].
    +
    +
    +
  • +
  • + levels + – +
    +
            Number of levels to go in the image pyramid.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +loss ( float +) – +
    +

    Loss from evaluation.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/loss.py +
def multi_scale_total_variation_loss(frame, levels = 3):
+    """
+    Function for evaluating a frame against a target using a multi scale total variation approach. Here, multi scale refers to an image pyramid of the input frame, where at each level the image resolution is half of the previous level's.
+
+    Parameters
+    ----------
+    frame         : torch.tensor
+                    Input frame [1 x 3 x m x n] or [3 x m x n] or [m x n].
+    levels        : int
+                    Number of levels to go in the image pyramid.
+
+    Returns
+    -------
+    loss          : float
+                    Loss from evaluation.
+    """
+    if len(frame.shape) == 2:
+        frame = frame.unsqueeze(0)
+    if len(frame.shape) == 3:
+        frame = frame.unsqueeze(0)
+    scale = torch.nn.Upsample(scale_factor = 0.5, mode = 'nearest')
+    level = frame
+    loss = 0
+    for i in range(levels):
+        if i != 0:
+           level = scale(level)
+        loss += total_variation_loss(level) 
+    return loss
+
+
+
+ +
+ +
+ + +

+ radial_basis_function(value, epsilon=0.5) + +

+ + +
+ +

Function to pass a value into radial basis function with Gaussian description.

+ + +

Parameters:

+
    +
  • + value + – +
    +
               Value(s) to pass to the radial basis function.
    +
    +
    +
  • +
  • + epsilon + – +
    +
             Epsilon used in the Gaussian radial basis function (e.g., y = e^(-(epsilon * value)^2)).
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +output ( tensor +) – +
    +

    Output values.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/loss.py +
def radial_basis_function(value, epsilon = 0.5):
+    """
+    Function to pass a value into radial basis function with Gaussian description.
+
+    Parameters
+    ----------
+    value            : torch.tensor
+                       Value(s) to pass to the radial basis function. 
+    epsilon          : float
+                       Epsilon used in the Gaussian radial basis function (e.g., y = e^(-(epsilon * value)^2)).
+
+    Returns
+    -------
+    output           : torch.tensor
+                       Output values.
+    """
+    output = torch.exp((-(epsilon * value)**2))
+    return output
+
+
+
+ +
+ +
+ + +

+ total_variation_loss(frame) + +

+ + +
+ +

Function for evaluating a frame against a target using total variation approach.

+ + +

Parameters:

+
    +
  • + frame + – +
    +
            Input frame [1 x 3 x m x n] or [3 x m x n] or [m x n].
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +loss ( float +) – +
    +

    Loss from evaluation.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/loss.py +
def total_variation_loss(frame):
+    """
+    Function for evaluating a frame against a target using total variation approach.
+
+    Parameters
+    ----------
+    frame         : torch.tensor
+                    Input frame [1 x 3 x m x n] or [3 x m x n] or [m x n].
+
+    Returns
+    -------
+    loss          : float
+                    Loss from evaluation.
+    """
+    if len(frame.shape) == 2:
+        frame = frame.unsqueeze(0)
+    if len(frame.shape) == 3:
+        frame = frame.unsqueeze(0)
+    diff_x = frame[:, :, :, 1:] - frame[:, :, :, :-1]
+    diff_y = frame[:, :, 1:, :] - frame[:, :, :-1, :]
+    pixel_count = frame.shape[0] * frame.shape[1] * frame.shape[2] * frame.shape[3]
+    loss = ((diff_x ** 2).sum() + (diff_y ** 2).sum()) / pixel_count
+    return loss
+
+
+
+ +
+ +
+ + +

+ weber_contrast(image, roi_high, roi_low) + +

+ + +
+ +

A function to calculate the Weber contrast ratio of given regions of interest of the image.

+ + +

Parameters:

+
    +
  • + image + – +
    +
            Image to be tested [1 x 3 x m x n] or [3 x m x n] or [1 x m x n] or [m x n].
    +
    +
    +
  • +
  • + roi_high + – +
    +
            Corner locations of the roi for high intensity area [m_start, m_end, n_start, n_end].
    +
    +
    +
  • +
  • + roi_low + – +
    +
            Corner locations of the roi for low intensity area [m_start, m_end, n_start, n_end].
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( tensor +) – +
    +

    Weber contrast for given regions. [1] or [3] depending on input image.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/loss.py +
def weber_contrast(image, roi_high, roi_low):
+    """
+    A function to calculate the Weber contrast ratio of given regions of interest of the image.
+
+    Parameters
+    ----------
+    image         : torch.tensor
+                    Image to be tested [1 x 3 x m x n] or [3 x m x n] or [1 x m x n] or [m x n].
+    roi_high      : torch.tensor
+                    Corner locations of the roi for high intensity area [m_start, m_end, n_start, n_end].
+    roi_low       : torch.tensor
+                    Corner locations of the roi for low intensity area [m_start, m_end, n_start, n_end].
+
+    Returns
+    -------
+    result        : torch.tensor
+                    Weber contrast for given regions. [1] or [3] depending on input image.
+    """
+    if len(image.shape) == 2:
+        image = image.unsqueeze(0)
+    if len(image.shape) == 3:
+        image = image.unsqueeze(0)
+    region_low = image[:, :, roi_low[0]:roi_low[1], roi_low[2]:roi_low[3]]
+    region_high = image[:, :, roi_high[0]:roi_high[1], roi_high[2]:roi_high[3]]
+    high = torch.mean(region_high, dim = (2, 3))
+    low = torch.mean(region_low, dim = (2, 3))
+    result = (high - low) / low
+    return result.squeeze(0)
+
+
+
+ +
+ +
+ + +

+ wrapped_mean_squared_error(image, ground_truth, reduction='mean') + +

+ + +
+ +

A function to calculate the wrapped mean squared error between predicted and target angles.

+ + +

Parameters:

+
    +
  • + image + – +
    +
            Image to be tested [1 x 3 x m x n]  or [3 x m x n] or [1 x m x n] or [m x n].
    +
    +
    +
  • +
  • + ground_truth + – +
    +
            Ground truth to be tested [1 x 3 x m x n]  or [3 x m x n] or [1 x m x n] or [m x n].
    +
    +
    +
  • +
  • + reduction + – +
    +
            Specifies the reduction to apply to the output: 'mean' (default) or 'sum'.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +wmse ( tensor +) – +
    +

    The calculated wrapped mean squared error.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/loss.py +
def wrapped_mean_squared_error(image, ground_truth, reduction = 'mean'):
+    """
+    A function to calculate the wrapped mean squared error between predicted and target angles.
+
+    Parameters
+    ----------
+    image         : torch.tensor
+                    Image to be tested [1 x 3 x m x n]  or [3 x m x n] or [1 x m x n] or [m x n].
+    ground_truth  : torch.tensor
+                    Ground truth to be tested [1 x 3 x m x n]  or [3 x m x n] or [1 x m x n] or [m x n].
+    reduction     : str
+                    Specifies the reduction to apply to the output: 'mean' (default) or 'sum'.
+
+    Returns
+    -------
+    wmse        : torch.tensor
+                  The calculated wrapped mean squared error. 
+    """
+    sin_diff = torch.sin(image) - torch.sin(ground_truth)
+    cos_diff = torch.cos(image) - torch.cos(ground_truth)
+    loss = (sin_diff**2 + cos_diff**2)
+
+    if reduction == 'mean':
+        return loss.mean()
+    elif reduction == 'sum':
+        return loss.sum()
+    else:
+        raise ValueError("Invalid reduction type. Choose 'mean' or 'sum'.")
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ blur_gaussian(field, kernel_length=[21, 21], nsigma=[3, 3], padding='same') + +

+ + +
+ +

A definition to blur a field using a Gaussian kernel.

+ + +

Parameters:

+
    +
  • + field + – +
    +
            MxN field.
    +
    +
    +
  • +
  • + kernel_length + (list, default: + [21, 21] +) + – +
    +
            Length of the Gaussian kernel along X and Y axes.
    +
    +
    +
  • +
  • + nsigma + – +
    +
            Sigma of the Gaussian kernel along X and Y axes.
    +
    +
    +
  • +
  • + padding + – +
    +
            Padding value, see torch.nn.functional.conv2d() for more.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +blurred_field ( tensor +) – +
    +

    Blurred field.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/matrix.py +
def blur_gaussian(field, kernel_length = [21, 21], nsigma = [3, 3], padding = 'same'):
+    """
+    A definition to blur a field using a Gaussian kernel.
+
+    Parameters
+    ----------
+    field         : torch.tensor
+                    MxN field.
+    kernel_length : list
+                    Length of the Gaussian kernel along X and Y axes.
+    nsigma        : list
+                    Sigma of the Gaussian kernel along X and Y axes.
+    padding       : int or string
+                    Padding value, see torch.nn.functional.conv2d() for more.
+
+    Returns
+    ----------
+    blurred_field : torch.tensor
+                    Blurred field.
+    """
+    kernel = generate_2d_gaussian(kernel_length, nsigma).to(field.device)
+    kernel = kernel.unsqueeze(0).unsqueeze(0)
+    if len(field.shape) == 2:
+        field = field.view(1, 1, field.shape[-2], field.shape[-1])
+    blurred_field = torch.nn.functional.conv2d(field, kernel, padding='same')
+    if field.shape[1] == 1:
+        blurred_field = blurred_field.view(
+                                           blurred_field.shape[-2],
+                                           blurred_field.shape[-1]
+                                          )
+    return blurred_field
+
+
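A minimal blur sketch (import path assumed to be odak.learn.tools):

import torch
from odak.learn.tools import blur_gaussian  # assumed import path

field = torch.rand(128, 128)
blurred = blur_gaussian(field, kernel_length = [21, 21], nsigma = [3, 3])   # same 128 x 128 resolution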
+
+ +
+ +
+ + +

+ convolve2d(field, kernel) + +

+ + +
+ +

Definition to convolve a field with a kernel by multiplying in frequency space.

+ + +

Parameters:

+
    +
  • + field + – +
    +
          Input field with MxN shape.
    +
    +
    +
  • +
  • + kernel + – +
    +
          Input kernel with MxN shape.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +new_field ( tensor +) – +
    +

    Convolved field.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/matrix.py +
def convolve2d(field, kernel):
+    """
+    Definition to convolve a field with a kernel by multiplying in frequency space.
+
+    Parameters
+    ----------
+    field       : torch.tensor
+                  Input field with MxN shape.
+    kernel      : torch.tensor
+                  Input kernel with MxN shape.
+
+    Returns
+    ----------
+    new_field   : torch.tensor
+                  Convolved field.
+    """
+    fr = torch.fft.fft2(field)
+    fr2 = torch.fft.fft2(torch.flip(torch.flip(kernel, [1, 0]), [0, 1]))
+    m, n = fr.shape
+    new_field = torch.real(torch.fft.ifft2(fr*fr2))
+    new_field = torch.roll(new_field, shifts=(int(n/2+1), 0), dims=(1, 0))
+    new_field = torch.roll(new_field, shifts=(int(m/2+1), 0), dims=(0, 1))
+    return new_field
+
+
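A hedged usage sketch; since the multiplication happens in frequency space, the kernel is generated at the same resolution as the field (import paths assumed):

import torch
from odak.learn.tools import convolve2d, generate_2d_gaussian  # assumed import paths

field = torch.rand(128, 128)
kernel = generate_2d_gaussian(kernel_length = [128, 128], nsigma = [3, 3])
result = convolve2d(field, kernel)          # smoothed version of the field, 128 x 128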
+
+ +
+ +
+ + +

+ correlation_2d(first_tensor, second_tensor) + +

+ + +
+ +

Definition to calculate the correlation between two tensors.

+ + +

Parameters:

+
    +
  • + first_tensor + – +
    +
            First tensor.
    +
    +
    +
  • +
  • + second_tensor + (tensor) + – +
    +
            Second tensor.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +correlation ( tensor +) – +
    +

    Correlation between the two tensors.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/matrix.py +
def correlation_2d(first_tensor, second_tensor):
+    """
+    Definition to calculate the correlation between two tensors.
+
+    Parameters
+    ----------
+    first_tensor  : torch.tensor
+                    First tensor.
+    second_tensor : torch.tensor
+                    Second tensor.
+
+    Returns
+    ----------
+    correlation   : torch.tensor
+                    Correlation between the two tensors.
+    """
+    fft_first_tensor = (torch.fft.fft2(first_tensor))
+    fft_second_tensor = (torch.fft.fft2(second_tensor))
+    conjugate_second_tensor = torch.conj(fft_second_tensor)
+    result = torch.fft.ifftshift(torch.fft.ifft2(fft_first_tensor * conjugate_second_tensor))
+    return result
+
+
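An illustrative cross-correlation sketch (import path assumed); the peak of the result encodes the shift between the two tensors:

import torch
from odak.learn.tools import correlation_2d  # assumed import path

first = torch.rand(64, 64)
second = torch.roll(first, shifts = (5, 3), dims = (0, 1))   # shifted copy of the first tensor
correlation = correlation_2d(first, second)
peak_index = torch.argmax(torch.abs(correlation))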
+
+ +
+ +
+ + +

+ crop_center(field, size=None) + +

+ + +
+ +

Definition to crop the center of a field with 2Mx2N size. The outcome is an MxN array.

+ + +

Parameters:

+
    +
  • + field + – +
    +
          Input field 2M x 2N or K x L x 2M x 2N or K x 2M x 2N x L array.
    +
    +
    +
  • +
  • + size + – +
    +
          Dimensions to crop with respect to center of the image (e.g., M x N or 1 x 1 x M x N).
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +cropped ( ndarray +) – +
    +

    Cropped version of the input field.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/matrix.py +
def crop_center(field, size = None):
+    """
+    Definition to crop the center of a field with 2Mx2N size. The outcome is an MxN array.
+
+    Parameters
+    ----------
+    field       : ndarray
+                  Input field 2M x 2N or K x L x 2M x 2N or K x 2M x 2N x L array.
+    size        : list
+                  Dimensions to crop with respect to center of the image (e.g., M x N or 1 x 1 x M x N).
+
+    Returns
+    ----------
+    cropped     : ndarray
+                  Cropped version of the input field.
+    """
+    orig_resolution = field.shape
+    if len(field.shape) < 3:
+        field = field.unsqueeze(0)
+    if len(field.shape) < 4:
+        field = field.unsqueeze(0)
+    permute_flag = False
+    if field.shape[-1] < 5:
+        permute_flag = True
+        field = field.permute(0, 3, 1, 2)
+    if type(size) == type(None):
+        qx = int(field.shape[-2] // 4)
+        qy = int(field.shape[-1] // 4)
+        cropped_padded = field[:, :, qx: qx + field.shape[-2] // 2, qy:qy + field.shape[-1] // 2]
+    else:
+        cx = int(field.shape[-2] // 2)
+        cy = int(field.shape[-1] // 2)
+        hx = int(size[-2] // 2)
+        hy = int(size[-1] // 2)
+        cropped_padded = field[:, :, cx-hx:cx+hx, cy-hy:cy+hy]
+    cropped = cropped_padded
+    if permute_flag:
+        cropped = cropped.permute(0, 2, 3, 1)
+    if len(orig_resolution) == 2:
+        cropped = cropped_padded.squeeze(0).squeeze(0)
+    if len(orig_resolution) == 3:
+        cropped = cropped_padded.squeeze(0)
+    return cropped
+
+
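A minimal usage sketch (import path assumed to be odak.learn.tools):

import torch
from odak.learn.tools import crop_center  # assumed import path

field = torch.rand(200, 200)
central = crop_center(field)                           # 100 x 100 center region
central_custom = crop_center(field, size = [64, 64])   # crop to an explicit size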
+
+ +
+ +
+ + +

+ generate_2d_dirac_delta(kernel_length=[21, 21], a=[3, 3], mu=[0, 0], theta=0, normalize=False) + +

+ + +
+ +

Generate 2D Dirac delta function by using Gaussian distribution. +Inspired from https://en.wikipedia.org/wiki/Dirac_delta_function

+ + +

Parameters:

+
    +
  • + kernel_length + (list, default: + [21, 21] +) + – +
    +
            Length of the Dirac delta function along X and Y axes.
    +
    +
    +
  • +
  • + a + – +
    +
            The scale factor in Gaussian distribution to approximate the Dirac delta function. 
    +        As a approaches zero, the Gaussian distribution becomes infinitely narrow and tall at the center (x=0), approaching the Dirac delta function.
    +
    +
    +
  • +
  • + mu + – +
    +
            Mu of the Gaussian kernel along X and Y axes.
    +
    +
    +
  • +
  • + theta + – +
    +
            The rotation angle of the 2D Dirac delta function.
    +
    +
    +
  • +
  • + normalize + – +
    +
            If set True, normalize the output.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +kernel_2d ( tensor +) – +
    +

    Generated 2D Dirac delta function.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/matrix.py +
def generate_2d_dirac_delta(
+                            kernel_length = [21, 21],
+                            a = [3, 3],
+                            mu = [0, 0],
+                            theta = 0,
+                            normalize = False
+                           ):
+    """
+    Generate 2D Dirac delta function by using Gaussian distribution.
+    Inspired from https://en.wikipedia.org/wiki/Dirac_delta_function
+
+    Parameters
+    ----------
+    kernel_length : list
+                    Length of the Dirac delta function along X and Y axes.
+    a             : list
+                    The scale factor in Gaussian distribution to approximate the Dirac delta function. 
+                    As a approaches zero, the Gaussian distribution becomes infinitely narrow and tall at the center (x=0), approaching the Dirac delta function.
+    mu            : list
+                    Mu of the Gaussian kernel along X and Y axes.
+    theta         : float
+                    The rotation angle of the 2D Dirac delta function.
+    normalize     : bool
+                    If set True, normalize the output.
+
+    Returns
+    ----------
+    kernel_2d     : torch.tensor
+                    Generated 2D Dirac delta function.
+    """
+    x = torch.linspace(-kernel_length[0] / 2., kernel_length[0] / 2., kernel_length[0])
+    y = torch.linspace(-kernel_length[1] / 2., kernel_length[1] / 2., kernel_length[1])
+    X, Y = torch.meshgrid(x, y, indexing='ij')
+    X = X - mu[0]
+    Y = Y - mu[1]
+    theta = torch.as_tensor(theta)
+    X_rot = X * torch.cos(theta) - Y * torch.sin(theta)
+    Y_rot = X * torch.sin(theta) + Y * torch.cos(theta)
+    kernel_2d = (1 / (abs(a[0] * a[1]) * torch.pi)) * torch.exp(-((X_rot / a[0]) ** 2 + (Y_rot / a[1]) ** 2))
+    if normalize:
+        kernel_2d = kernel_2d / kernel_2d.max()
+    return kernel_2d
+
+
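A short illustrative call (import path assumed):

import torch
from odak.learn.tools import generate_2d_dirac_delta  # assumed import path

kernel = generate_2d_dirac_delta(kernel_length = [21, 21], a = [0.5, 0.5], normalize = True)
print(kernel.shape, kernel.max())                  # 21 x 21 kernel peaking at 1 in its center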
+
+ +
+ +
+ + +

+ generate_2d_gaussian(kernel_length=[21, 21], nsigma=[3, 3], mu=[0, 0], normalize=False) + +

+ + +
+ +

Generate 2D Gaussian kernel. Inspired from https://stackoverflow.com/questions/29731726/how-to-calculate-a-gaussian-kernel-matrix-efficiently-in-numpy

+ + +

Parameters:

+
    +
  • + kernel_length + (list, default: + [21, 21] +) + – +
    +
            Length of the Gaussian kernel along X and Y axes.
    +
    +
    +
  • +
  • + nsigma + – +
    +
            Sigma of the Gaussian kernel along X and Y axes.
    +
    +
    +
  • +
  • + mu + – +
    +
            Mu of the Gaussian kernel along X and Y axes.
    +
    +
    +
  • +
  • + normalize + – +
    +
            If set True, normalize the output.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +kernel_2d ( tensor +) – +
    +

    Generated Gaussian kernel.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/matrix.py +
def generate_2d_gaussian(kernel_length = [21, 21], nsigma = [3, 3], mu = [0, 0], normalize = False):
+    """
+    Generate 2D Gaussian kernel. Inspired from https://stackoverflow.com/questions/29731726/how-to-calculate-a-gaussian-kernel-matrix-efficiently-in-numpy
+
+    Parameters
+    ----------
+    kernel_length : list
+                    Length of the Gaussian kernel along X and Y axes.
+    nsigma        : list
+                    Sigma of the Gaussian kernel along X and Y axes.
+    mu            : list
+                    Mu of the Gaussian kernel along X and Y axes.
+    normalize     : bool
+                    If set True, normalize the output.
+
+    Returns
+    ----------
+    kernel_2d     : torch.tensor
+                    Generated Gaussian kernel.
+    """
+    x = torch.linspace(-kernel_length[0]/2., kernel_length[0]/2., kernel_length[0])
+    y = torch.linspace(-kernel_length[1]/2., kernel_length[1]/2., kernel_length[1])
+    X, Y = torch.meshgrid(x, y, indexing='ij')
+    if nsigma[0] == 0:
+        nsigma[0] = 1e-5
+    if nsigma[1] == 0:
+        nsigma[1] = 1e-5
+    kernel_2d = 1. / (2. * torch.pi * nsigma[0] * nsigma[1]) * torch.exp(-((X - mu[0])**2. / (2. * nsigma[0]**2.) + (Y - mu[1])**2. / (2. * nsigma[1]**2.)))
+    if normalize:
+        kernel_2d = kernel_2d / kernel_2d.max()
+    return kernel_2d
+
+
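A minimal usage sketch (import path assumed to be odak.learn.tools):

import torch
from odak.learn.tools import generate_2d_gaussian  # assumed import path

kernel = generate_2d_gaussian(kernel_length = [21, 21], nsigma = [3, 3], normalize = True)
print(kernel.max())                                # normalize = True scales the peak value to 1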
+
+ +
+ +
+ + +

+ quantize(image_field, bits=8, limits=[0.0, 1.0]) + +

+ + +
+ +

Definition to quantize an image field (0-255, 8 bit) to a certain bit level.

+ + +

Parameters:

+
    +
  • + image_field + (tensor) + – +
    +
          Input image field between any range.
    +
    +
    +
  • +
  • + bits + – +
    +
          A value between one and eight.
    +
    +
    +
  • +
  • + limits + – +
    +
          The minimum and maximum of the image_field variable.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +new_field ( tensor +) – +
    +

    Quantized image field.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/matrix.py +
def quantize(image_field, bits = 8, limits = [0., 1.]):
+    """ 
+    Definition to quantize an image field to a given bit depth (e.g., 0-255 for 8 bits).
+
+    Parameters
+    ----------
+    image_field : torch.tensor
+                  Input image field between any range.
+    bits        : int
+                  A value between one and eight.
+    limits      : list
+                  The minimum and maximum of the image_field variable.
+
+    Returns
+    ----------
+    new_field   : torch.tensor
+                  Quantized image field.
+    """
+    normalized_field = (image_field - limits[0]) / (limits[1] - limits[0])
+    divider = 2 ** bits
+    new_field = normalized_field * divider
+    new_field = new_field.int()
+    return new_field
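A short, hedged example of how quantize might be called (assuming it is re-exported at odak.learn.tools):

import torch
from odak.learn.tools import quantize  # assumed export path

field = torch.rand(256, 256)        # values in [0, 1]
levels = quantize(field, bits = 4)  # integer codes, roughly in [0, 2 ** 4]
print(levels.dtype)                 # torch.int32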
+
+
+
+ +
+ +
+ + +

+ zero_pad(field, size=None, method='center') + +

+ + +
+ +

Definition to zero pad an MxN array into a 2Mx2N array (or a given target size).

+ + +

Parameters:

+
    +
  • + field + – +
    +
                Input field MxN or KxJxMxN or KxMxNxJ array.
    +
    +
    +
  • +
  • + size + – +
    +
                Size to be zeropadded (e.g., [m, n], last two dimensions only).
    +
    +
    +
  • +
  • + method + – +
    +
            Zero pad by placing the content at the center or at the top-left corner.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +field_zero_padded ( ndarray +) – +
    +

    Zeropadded version of the input field.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/matrix.py +
def zero_pad(field, size = None, method = 'center'):
+    """
+    Definition to zero pad an MxN array into a 2Mx2N array (or a given target size).
+
+    Parameters
+    ----------
+    field             : ndarray
+                        Input field MxN or KxJxMxN or KxMxNxJ array.
+    size              : list
+                        Size to be zeropadded (e.g., [m, n], last two dimensions only).
+    method            : str
+                        Zero pad by placing the content at the center or at the top-left corner.
+
+    Returns
+    ----------
+    field_zero_padded : ndarray
+                        Zeropadded version of the input field.
+    """
+    orig_resolution = field.shape
+    if len(field.shape) < 3:
+        field = field.unsqueeze(0)
+    if len(field.shape) < 4:
+        field = field.unsqueeze(0)
+    permute_flag = False
+    if field.shape[-1] < 5:
+        permute_flag = True
+        field = field.permute(0, 3, 1, 2)
+    if type(size) == type(None):
+        resolution = [field.shape[0], field.shape[1], 2 * field.shape[-2], 2 * field.shape[-1]]
+    else:
+        resolution = [field.shape[0], field.shape[1], size[0], size[1]]
+    field_zero_padded = torch.zeros(resolution, device = field.device, dtype = field.dtype)
+    if method == 'center':
+       start = [
+                resolution[-2] // 2 - field.shape[-2] // 2,
+                resolution[-1] // 2 - field.shape[-1] // 2
+               ]
+       field_zero_padded[
+                         :, :,
+                         start[0] : start[0] + field.shape[-2],
+                         start[1] : start[1] + field.shape[-1]
+                         ] = field
+    elif method == 'left':
+       field_zero_padded[
+                         :, :,
+                         0: field.shape[-2],
+                         0: field.shape[-1]
+                        ] = field
+    if permute_flag == True:
+        field_zero_padded = field_zero_padded.permute(0, 2, 3, 1)
+    if len(orig_resolution) == 2:
+        field_zero_padded = field_zero_padded.squeeze(0).squeeze(0)
+    if len(orig_resolution) == 3:
+        field_zero_padded = field_zero_padded.squeeze(0)
+    return field_zero_padded
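A minimal sketch of zero_pad for a single-channel field (assuming the odak.learn.tools export):

import torch
from odak.learn.tools import zero_pad  # assumed export path

field = torch.ones(100, 100)
padded = zero_pad(field)                                          # doubles the last two dimensions -> [200, 200]
padded_custom = zero_pad(field, size = [256, 256], method = 'left')
print(padded.shape, padded_custom.shape)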
+
+
+
+ +

+ grid_sample(no=[10, 10], size=[100.0, 100.0], center=[0.0, 0.0, 0.0], angles=[0.0, 0.0, 0.0]) + +

+ + +
+ +

Definition to generate samples over a surface.

+ + +

Parameters:

+
    +
  • + no + – +
    +
          Number of samples.
    +
    +
    +
  • +
  • + size + – +
    +
          Physical size of the surface.
    +
    +
    +
  • +
  • + center + – +
    +
          Center location of the surface.
    +
    +
    +
  • +
  • + angles + – +
    +
          Tilt of the surface.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +samples ( tensor +) – +
    +

    Samples generated.

    +
    +
  • +
  • +rotx ( tensor +) – +
    +

    Rotation matrix at X axis.

    +
    +
  • +
  • +roty ( tensor +) – +
    +

    Rotation matrix at Y axis.

    +
    +
  • +
  • +rotz ( tensor +) – +
    +

    Rotation matrix at Z axis.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/sample.py +
def grid_sample(
+                no = [10, 10],
+                size = [100., 100.], 
+                center = [0., 0., 0.], 
+                angles = [0., 0., 0.]):
+    """
+    Definition to generate samples over a surface.
+
+    Parameters
+    ----------
+    no          : list
+                  Number of samples.
+    size        : list
+                  Physical size of the surface.
+    center      : list
+                  Center location of the surface.
+    angles      : list
+                  Tilt of the surface.
+
+    Returns
+    -------
+    samples     : torch.tensor
+                  Samples generated.
+    rotx        : torch.tensor
+                  Rotation matrix at X axis.
+    roty        : torch.tensor
+                  Rotation matrix at Y axis.
+    rotz        : torch.tensor
+                  Rotation matrix at Z axis.
+    """
+    center = torch.tensor(center)
+    angles = torch.tensor(angles)
+    size = torch.tensor(size)
+    samples = torch.zeros((no[0], no[1], 3))
+    x = torch.linspace(-size[0] / 2., size[0] / 2., no[0])
+    y = torch.linspace(-size[1] / 2., size[1] / 2., no[1])
+    X, Y = torch.meshgrid(x, y, indexing='ij')
+    samples[:, :, 0] = X.detach().clone()
+    samples[:, :, 1] = Y.detach().clone()
+    samples = samples.reshape((samples.shape[0] * samples.shape[1], samples.shape[2]))
+    samples, rotx, roty, rotz = rotate_points(samples, angles = angles, offset = center)
+    return samples, rotx, roty, rotz
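A hedged usage sketch for grid_sample (assuming it is re-exported at odak.learn.tools; it is defined in odak/learn/tools/sample.py):

from odak.learn.tools import grid_sample  # assumed export path

samples, rotx, roty, rotz = grid_sample(
                                        no = [10, 10],
                                        size = [100., 100.],
                                        center = [0., 0., 10.],
                                        angles = [0., 45., 0.]
                                       )
print(samples.shape)  # torch.Size([100, 3]), one XYZ location per sample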
+
+
+
+ +

+ get_rotation_matrix(tilt_angles=[0.0, 0.0, 0.0], tilt_order='XYZ') + +

+ + +
+ +

Function to generate rotation matrix for given tilt angles and tilt order.

+ + +

Parameters:

+
    +
  • + tilt_angles + – +
    +
                 Tilt angles in degrees along XYZ axes.
    +
    +
    +
  • +
  • + tilt_order + – +
    +
                 Rotation order (e.g., XYZ, XZY, ZXY, YXZ, ZYX).
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +rotmat ( tensor +) – +
    +

    Rotation matrix.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/transformation.py +
def get_rotation_matrix(tilt_angles = [0., 0., 0.], tilt_order = 'XYZ'):
+    """
+    Function to generate rotation matrix for given tilt angles and tilt order.
+
+
+    Parameters
+    ----------
+    tilt_angles        : list
+                         Tilt angles in degrees along XYZ axes.
+    tilt_order         : str
+                         Rotation order (e.g., XYZ, XZY, ZXY, YXZ, ZYX).
+
+    Returns
+    -------
+    rotmat             : torch.tensor
+                         Rotation matrix.
+    """
+    rotx = rotmatx(tilt_angles[0])
+    roty = rotmaty(tilt_angles[1])
+    rotz = rotmatz(tilt_angles[2])
+    if tilt_order =='XYZ':
+        rotmat = torch.mm(rotz,torch.mm(roty, rotx))
+    elif tilt_order == 'XZY':
+        rotmat = torch.mm(roty,torch.mm(rotz, rotx))
+    elif tilt_order == 'ZXY':
+        rotmat = torch.mm(roty,torch.mm(rotx, rotz))
+    elif tilt_order == 'YXZ':
+        rotmat = torch.mm(rotz,torch.mm(rotx, roty))
+    elif tilt_order == 'ZYX':
+         rotmat = torch.mm(rotx,torch.mm(roty, rotz))
+    return rotmat
+
+
+
+ +
+ +
+ + +

+ rotate_points(point, angles=torch.tensor([[0, 0, 0]]), mode='XYZ', origin=torch.tensor([[0, 0, 0]]), offset=torch.tensor([[0, 0, 0]])) + +

+ + +
+ +

Definition to rotate a given point. The rotation is performed about the given origin (0, 0, 0 by default).

+ + +

Parameters:

+
    +
  • + point + – +
    +
           A point with size of [3] or [1, 3] or [m, 3].
    +
    +
    +
  • +
  • + angles + – +
    +
           Rotation angles in degrees.
    +
    +
    +
  • +
  • + mode + – +
    +
           Rotation mode determines ordering of the rotations at each axis.
    +       There are XYZ,YXZ,ZXY and ZYX modes.
    +
    +
    +
  • +
  • + origin + – +
    +
           Reference point for a rotation.
    +       Expected size is [3] or [1, 3].
    +
    +
    +
  • +
  • + offset + – +
    +
           Shift with the given offset.
    +       Expected size is [3] or [1, 3] or [m, 3].
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( tensor +) – +
    +

    Result of the rotation [1 x 3] or [m x 3].

    +
    +
  • +
  • +rotx ( tensor +) – +
    +

    Rotation matrix along X axis [3 x 3].

    +
    +
  • +
  • +roty ( tensor +) – +
    +

    Rotation matrix along Y axis [3 x 3].

    +
    +
  • +
  • +rotz ( tensor +) – +
    +

    Rotation matrix along Z axis [3 x 3].

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/transformation.py +
def rotate_points(
+                 point,
+                 angles = torch.tensor([[0, 0, 0]]), 
+                 mode='XYZ', 
+                 origin = torch.tensor([[0, 0, 0]]), 
+                 offset = torch.tensor([[0, 0, 0]])
+                ):
+    """
+    Definition to rotate a given point. The rotation is performed about the given origin (0, 0, 0 by default).
+
+    Parameters
+    ----------
+    point        : torch.tensor
+                   A point with size of [3] or [1, 3] or [m, 3].
+    angles       : torch.tensor
+                   Rotation angles in degrees. 
+    mode         : str
+                   Rotation mode determines ordering of the rotations at each axis.
+                   There are XYZ,YXZ,ZXY and ZYX modes.
+    origin       : torch.tensor
+                   Reference point for a rotation.
+                   Expected size is [3] or [1, 3].
+    offset       : torch.tensor
+                   Shift with the given offset.
+                   Expected size is [3] or [1, 3] or [m, 3].
+
+    Returns
+    ----------
+    result       : torch.tensor
+                   Result of the rotation [1 x 3] or [m x 3].
+    rotx         : torch.tensor
+                   Rotation matrix along X axis [3 x 3].
+    roty         : torch.tensor
+                   Rotation matrix along Y axis [3 x 3].
+    rotz         : torch.tensor
+                   Rotation matrix along Z axis [3 x 3].
+    """
+    origin = origin.to(point.device)
+    offset = offset.to(point.device)
+    if len(point.shape) == 1:
+        point = point.unsqueeze(0)
+    if len(angles.shape) == 1:
+        angles = angles.unsqueeze(0)
+    rotx = rotmatx(angles[:, 0])
+    roty = rotmaty(angles[:, 1])
+    rotz = rotmatz(angles[:, 2])
+    new_point = (point - origin).T
+    if mode == 'XYZ':
+        result = torch.mm(rotz, torch.mm(roty, torch.mm(rotx, new_point))).T
+    elif mode == 'XZY':
+        result = torch.mm(roty, torch.mm(rotz, torch.mm(rotx, new_point))).T
+    elif mode == 'YXZ':
+        result = torch.mm(rotz, torch.mm(rotx, torch.mm(roty, new_point))).T
+    elif mode == 'ZXY':
+        result = torch.mm(roty, torch.mm(rotx, torch.mm(rotz, new_point))).T
+    elif mode == 'ZYX':
+        result = torch.mm(rotx, torch.mm(roty, torch.mm(rotz, new_point))).T
+    result += origin
+    result += offset
+    return result, rotx, roty, rotz
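A minimal sketch of rotating a batch of points (assuming rotate_points is re-exported at odak.learn.tools):

import torch
from odak.learn.tools import rotate_points  # assumed export path

points = torch.rand(100, 3)
rotated, rotx, roty, rotz = rotate_points(
                                          points,
                                          angles = torch.tensor([[0., 0., 90.]]),  # degrees
                                          offset = torch.tensor([[0., 0., 5.]])
                                         )
print(rotated.shape)  # torch.Size([100, 3])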
+
+
+
+ +
+ +
+ + +

+ rotmatx(angle) + +

+ + +
+ +

Definition to generate a rotation matrix along X axis.

+ + +

Parameters:

+
    +
  • + angle + – +
    +
           Rotation angles in degrees.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +rotx ( tensor +) – +
    +

    Rotation matrix along X axis.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/transformation.py +
def rotmatx(angle):
+    """
+    Definition to generate a rotation matrix along X axis.
+
+    Parameters
+    ----------
+    angle        : torch.tensor
+                   Rotation angles in degrees.
+
+    Returns
+    ----------
+    rotx         : torch.tensor
+                   Rotation matrix along X axis.
+    """
+    angle = torch.deg2rad(angle)
+    one = torch.ones(1, device = angle.device)
+    zero = torch.zeros(1, device = angle.device)
+    rotx = torch.stack([
+                        torch.stack([ one,              zero,              zero]),
+                        torch.stack([zero,  torch.cos(angle), -torch.sin(angle)]),
+                        torch.stack([zero,  torch.sin(angle),  torch.cos(angle)])
+                       ]).reshape(3, 3)
+    return rotx
+
+
+
+ +
+ +
+ + +

+ rotmaty(angle) + +

+ + +
+ +

Definition to generate a rotation matrix along Y axis.

+ + +

Parameters:

+
    +
  • + angle + – +
    +
           Rotation angles in degrees.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +roty ( tensor +) – +
    +

    Rotation matrix along Y axis.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/transformation.py +
def rotmaty(angle):
+    """
+    Definition to generate a rotation matrix along Y axis.
+
+    Parameters
+    ----------
+    angle        : torch.tensor
+                   Rotation angles in degrees.
+
+    Returns
+    ----------
+    roty         : torch.tensor
+                   Rotation matrix along Y axis.
+    """
+    angle = torch.deg2rad(angle)
+    one = torch.ones(1, device = angle.device)
+    zero = torch.zeros(1, device = angle.device)
+    roty = torch.stack([
+                        torch.stack([ torch.cos(angle), zero, torch.sin(angle)]),
+                        torch.stack([             zero,  one,             zero]),
+                        torch.stack([-torch.sin(angle), zero, torch.cos(angle)])
+                       ]).reshape(3, 3)
+    return roty
+
+
+
+ +
+ +
+ + +

+ rotmatz(angle) + +

+ + +
+ +

Definition to generate a rotation matrix along Z axis.

+ + +

Parameters:

+
    +
  • + angle + – +
    +
           Rotation angles in degrees.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +rotz ( tensor +) – +
    +

    Rotation matrix along Z axis.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/transformation.py +
def rotmatz(angle):
+    """
+    Definition to generate a rotation matrix along Z axis.
+
+    Parameters
+    ----------
+    angle        : torch.tensor
+                   Rotation angles in degrees.
+
+    Returns
+    ----------
+    rotz         : torch.tensor
+                   Rotation matrix along Z axis.
+    """
+    angle = torch.deg2rad(angle)
+    one = torch.ones(1, device = angle.device)
+    zero = torch.zeros(1, device = angle.device)
+    rotz = torch.stack([
+                        torch.stack([torch.cos(angle), -torch.sin(angle), zero]),
+                        torch.stack([torch.sin(angle),  torch.cos(angle), zero]),
+                        torch.stack([            zero,              zero,  one])
+                       ]).reshape(3,3)
+    return rotz
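A small sketch combining rotmatx, rotmaty and rotmatz into a single XYZ rotation, much like rotate_points does internally (the import path is an assumption; note the angle is passed as a one-element tensor, matching how these helpers are called elsewhere in this module):

import torch
from odak.learn.tools import rotmatx, rotmaty, rotmatz  # assumed export paths

angle = torch.tensor([30.])  # degrees, as a one-element tensor
rotation = torch.mm(rotmatz(angle), torch.mm(rotmaty(angle), rotmatx(angle)))
print(rotation.shape)        # torch.Size([3, 3])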
+
+
+
+ +
+ +
+ + +

+ tilt_towards(location, lookat) + +

+ + +
+ +

Definition to tilt surface normal of a plane towards a point.

+ + +

Parameters:

+
    +
  • + location + – +
    +
           Center of the plane to be tilted.
    +
    +
    +
  • +
  • + lookat + – +
    +
           Tilt towards this point.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +angles ( list +) – +
    +

    Rotation angles in degrees.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/transformation.py +
def tilt_towards(location, lookat):
+    """
+    Definition to tilt surface normal of a plane towards a point.
+
+    Parameters
+    ----------
+    location     : list
+                   Center of the plane to be tilted.
+    lookat       : list
+                   Tilt towards this point.
+
+    Returns
+    ----------
+    angles       : list
+                   Rotation angles in degrees.
+    """
+    dx = location[0] - lookat[0]
+    dy = location[1] - lookat[1]
+    dz = location[2] - lookat[2]
+    dist = torch.sqrt(torch.tensor(dx ** 2 + dy ** 2 + dz ** 2))
+    phi = torch.atan2(torch.tensor(dy), torch.tensor(dx))
+    theta = torch.arccos(dz / dist)
+    angles = [0, float(torch.rad2deg(theta)), float(torch.rad2deg(phi))]
+    return angles
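A hedged sketch pairing tilt_towards with grid_sample to orient a sampling plane towards a target point (import paths assumed as above):

from odak.learn.tools import tilt_towards, grid_sample  # assumed export paths

angles = tilt_towards(location = [0., 0., 0.], lookat = [0., 5., 10.])
samples, _, _, _ = grid_sample(no = [10, 10], size = [5., 5.], angles = angles)
print(samples.shape)  # torch.Size([100, 3])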
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ cross_product(vector1, vector2) + +

+ + +
+ +

Definition to compute the cross product of two vectors and return the resultant vector. Uses the method described at: http://en.wikipedia.org/wiki/Cross_product

+ + +

Parameters:

+
    +
  • + vector1 + – +
    +
           A vector/ray.
    +
    +
    +
  • +
  • + vector2 + – +
    +
           A vector/ray.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +ray ( tensor +) – +
    +

    Array that contains starting points and cosines of a created ray.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/vector.py +
def cross_product(vector1, vector2):
+    """
+    Definition to compute the cross product of two vectors and return the resultant vector. Uses the method described at: http://en.wikipedia.org/wiki/Cross_product
+
+    Parameters
+    ----------
+    vector1      : torch.tensor
+                   A vector/ray.
+    vector2      : torch.tensor
+                   A vector/ray.
+
+    Returns
+    ----------
+    ray          : torch.tensor
+                   Array that contains starting points and cosines of a created ray.
+    """
+    angle = torch.cross(vector1[1].T, vector2[1].T)
+    angle = torch.tensor(angle)
+    ray = torch.tensor([vector1[0], angle], dtype=torch.float32)
+    return ray
+
+
+
+ +
+ +
+ + +

+ distance_between_two_points(point1, point2) + +

+ + +
+ +

Definition to calculate distance between two given points.

+ + +

Parameters:

+
    +
  • + point1 + – +
    +
          First point in X,Y,Z.
    +
    +
    +
  • +
  • + point2 + – +
    +
          Second point in X,Y,Z.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +distance ( Tensor +) – +
    +

    Distance in between given two points.

    +
    +
  • +
+ +
+ Source code in odak/learn/tools/vector.py +
def distance_between_two_points(point1, point2):
+    """
+    Definition to calculate distance between two given points.
+
+    Parameters
+    ----------
+    point1      : torch.Tensor
+                  First point in X,Y,Z.
+    point2      : torch.Tensor
+                  Second point in X,Y,Z.
+
+    Returns
+    ----------
+    distance    : torch.Tensor
+                  Distance in between given two points.
+    """
+    point1 = torch.tensor(point1) if not isinstance(point1, torch.Tensor) else point1
+    point2 = torch.tensor(point2) if not isinstance(point2, torch.Tensor) else point2
+
+    if len(point1.shape) == 1 and len(point2.shape) == 1:
+        distance = torch.sqrt(torch.sum((point1 - point2) ** 2))
+    elif len(point1.shape) == 2 or len(point2.shape) == 2:
+        distance = torch.sqrt(torch.sum((point1 - point2) ** 2, dim=-1))
+
+    return distance
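A minimal sketch of distance_between_two_points for single points and for batches (assuming the odak.learn.tools export):

import torch
from odak.learn.tools import distance_between_two_points  # assumed export path

d = distance_between_two_points(torch.tensor([0., 0., 0.]), torch.tensor([3., 4., 0.]))
print(d)  # tensor(5.)

batched = distance_between_two_points(torch.rand(10, 3), torch.rand(10, 3))
print(batched.shape)  # torch.Size([10]), one distance per point pair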
+
+
+
+ +
+ +
+ + +

+ same_side(p1, p2, a, b) + +

+ + +
+ +

Definition to determine which side of a line (defined by points a and b) a point lies on, relative to a reference point. See http://www.blackpawn.com/texts/pointinpoly/ for more. If p1 and p2 are on the same side, this definition returns True.

+ + +

Parameters:

+
    +
  • + p1 + – +
    +
          Point(s) to check.
    +
    +
    +
  • +
  • + p2 + – +
    +
          The point to check against.
    +
    +
    +
  • +
  • + a + – +
    +
          First point that forms the line.
    +
    +
    +
  • +
  • + b + – +
    +
          Second point that forms the line.
    +
    +
    +
  • +
+ +
+ Source code in odak/learn/tools/vector.py +
def same_side(p1, p2, a, b):
+    """
+    Definition to determine which side of a line (defined by points a and b) a point lies on, relative to a reference point. See http://www.blackpawn.com/texts/pointinpoly/ for more. If p1 and p2 are on the same side, this definition returns True.
+
+    Parameters
+    ----------
+    p1          : list
+                  Point(s) to check.
+    p2          : list
+                  The point to check against.
+    a           : list
+                  First point that forms the line.
+    b           : list
+                  Second point that forms the line.
+    """
+    ba = torch.subtract(b, a)
+    p1a = torch.subtract(p1, a)
+    p2a = torch.subtract(p2, a)
+    cp1 = torch.cross(ba, p1a)
+    cp2 = torch.cross(ba, p2a)
+    test = torch.dot(cp1, cp2)
+    if len(p1.shape) > 1:
+        return test >= 0
+    if test >= 0:
+        return True
+    return False
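A short, hedged example of same_side with two points above the line a-b (assuming the odak.learn.tools export):

import torch
from odak.learn.tools import same_side  # assumed export path

a = torch.tensor([0., 0., 0.])
b = torch.tensor([1., 0., 0.])
p1 = torch.tensor([0.5, 1., 0.])
p2 = torch.tensor([0.5, 2., 0.])
print(same_side(p1, p2, a, b))  # True: both points lie on the same side of the line a-b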
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/odak/learn_wave/index.html b/odak/learn_wave/index.html new file mode 100644 index 00000000..5cdef466 --- /dev/null +++ b/odak/learn_wave/index.html @@ -0,0 +1,29590 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + odak.learn.wave - Odak + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

odak.learn.wave

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ angular_spectrum(field, k, distance, dx, wavelength, zero_padding=False, aperture=1.0) + +

+ + +
+ +

A definition to calculate convolution with Angular Spectrum method for beam propagation.

+ + +

Parameters:

+
    +
  • + field + – +
    +
               Complex field [m x n].
    +
    +
    +
  • +
  • + k + – +
    +
               Wave number of a wave, see odak.wave.wavenumber for more.
    +
    +
    +
  • +
  • + distance + – +
    +
               Propagation distance.
    +
    +
    +
  • +
  • + dx + – +
    +
               Size of one single pixel in the field grid (in meters).
    +
    +
    +
  • +
  • + wavelength + – +
    +
               Wavelength of the electric field.
    +
    +
    +
  • +
  • + zero_padding + – +
    +
               Zero pad in Fourier domain.
    +
    +
    +
  • +
  • + aperture + – +
    +
               Fourier domain aperture (e.g., pinhole in a typical holographic display).
    +           The default is one, but an aperture could be as large as input field [m x n].
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( complex +) – +
    +

    Final complex field (MxN).

    +
    +
  • +
+ +
+ Source code in odak/learn/wave/classical.py +
def angular_spectrum(field, k, distance, dx, wavelength, zero_padding = False, aperture = 1.):
+    """
+    A definition to calculate convolution with Angular Spectrum method for beam propagation.
+
+    Parameters
+    ----------
+    field            : torch.complex
+                       Complex field [m x n].
+    k                : odak.wave.wavenumber
+                       Wave number of a wave, see odak.wave.wavenumber for more.
+    distance         : float
+                       Propagation distance.
+    dx               : float
+                       Size of one single pixel in the field grid (in meters).
+    wavelength       : float
+                       Wavelength of the electric field.
+    zero_padding     : bool
+                       Zero pad in Fourier domain.
+    aperture         : torch.tensor
+                       Fourier domain aperture (e.g., pinhole in a typical holographic display).
+                       The default is one, but an aperture could be as large as input field [m x n].
+
+
+    Returns
+    -------
+    result           : torch.complex
+                       Final complex field (MxN).
+
+    """
+    H = get_propagation_kernel(
+                               nu = field.shape[-2], 
+                               nv = field.shape[-1], 
+                               dx = dx, 
+                               wavelength = wavelength, 
+                               distance = distance, 
+                               propagation_type = 'Angular Spectrum',
+                               device = field.device
+                              )
+    result = custom(field, H, zero_padding = zero_padding, aperture = aperture)
+    return result
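A hedged usage sketch, assuming angular_spectrum and wavenumber are exported at odak.learn.wave (wavenumber is referenced throughout this module):

import torch
import odak.learn.wave

wavelength = 515e-9
dx = 8e-6
distance = 5e-3
field = torch.exp(1j * 2. * torch.pi * torch.rand(512, 512))  # phase-only input field
k = odak.learn.wave.wavenumber(wavelength)
result = odak.learn.wave.angular_spectrum(field, k, distance, dx, wavelength)
print(result.shape)  # torch.Size([512, 512])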
+
+
+
+ +
+ +
+ + +

+ band_limited_angular_spectrum(field, k, distance, dx, wavelength, zero_padding=False, aperture=1.0) + +

+ + +
+ +

A definition to calculate band-limited angular spectrum based beam propagation. For more, see Matsushima, Kyoji, and Tomoyoshi Shimobaba. "Band-limited angular spectrum method for numerical simulation of free-space propagation in far and near fields." Optics Express 17.22 (2009): 19662-19673.

+ + +

Parameters:

+
    +
  • + field + – +
    +
               A complex field.
    +           The expected size is [m x n].
    +
    +
    +
  • +
  • + k + – +
    +
               Wave number of a wave, see odak.wave.wavenumber for more.
    +
    +
    +
  • +
  • + distance + – +
    +
               Propagation distance.
    +
    +
    +
  • +
  • + dx + – +
    +
               Size of one single pixel in the field grid (in meters).
    +
    +
    +
  • +
  • + wavelength + – +
    +
               Wavelength of the electric field.
    +
    +
    +
  • +
  • + zero_padding + – +
    +
               Zero pad in Fourier domain.
    +
    +
    +
  • +
  • + aperture + – +
    +
               Fourier domain aperture (e.g., pinhole in a typical holographic display).
    +           The default is one, but an aperture could be as large as input field [m x n].
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( complex +) – +
    +

    Final complex field [m x n].

    +
    +
  • +
+ +
+ Source code in odak/learn/wave/classical.py +
def band_limited_angular_spectrum(
+                                  field,
+                                  k,
+                                  distance,
+                                  dx,
+                                  wavelength,
+                                  zero_padding = False,
+                                  aperture = 1.
+                                 ):
+    """
+    A definition to calculate band-limited angular spectrum based beam propagation. For more, see
+    `Matsushima, Kyoji, and Tomoyoshi Shimobaba. "Band-limited angular spectrum method for numerical simulation of free-space propagation in far and near fields." Optics express 17.22 (2009): 19662-19673`.
+
+    Parameters
+    ----------
+    field            : torch.complex
+                       A complex field.
+                       The expected size is [m x n].
+    k                : odak.wave.wavenumber
+                       Wave number of a wave, see odak.wave.wavenumber for more.
+    distance         : float
+                       Propagation distance.
+    dx               : float
+                       Size of one single pixel in the field grid (in meters).
+    wavelength       : float
+                       Wavelength of the electric field.
+    zero_padding     : bool
+                       Zero pad in Fourier domain.
+    aperture         : torch.tensor
+                       Fourier domain aperture (e.g., pinhole in a typical holographic display).
+                       The default is one, but an aperture could be as large as input field [m x n].
+
+
+    Returns
+    -------
+    result           : torch.complex
+                       Final complex field [m x n].
+    """
+    H = get_propagation_kernel(
+                               nu = field.shape[-2], 
+                               nv = field.shape[-1], 
+                               dx = dx, 
+                               wavelength = wavelength, 
+                               distance = distance, 
+                               propagation_type = 'Bandlimited Angular Spectrum',
+                               device = field.device
+                              )
+    result = custom(field, H, zero_padding = zero_padding, aperture = aperture)
+    return result
+
+
+
+ +
+ +
+ + +

+ custom(field, kernel, zero_padding=False, aperture=1.0) + +

+ + +
+ +

A definition to calculate beam propagation by multiplying the input field with a custom propagation kernel in the Fourier domain (i.e., a convolution in space).

+ + +

Parameters:

+
    +
  • + field + – +
    +
               Complex field [m x n].
    +
    +
    +
  • +
  • + kernel + – +
    +
               Custom complex kernel for beam propagation.
    +
    +
    +
  • +
  • + zero_padding + – +
    +
               Zero pad in Fourier domain.
    +
    +
    +
  • +
  • + aperture + – +
    +
               Fourier domain aperture (e.g., pinhole in a typical holographic display).
    +           The default is one, but an aperture could be as large as input field [m x n].
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( complex +) – +
    +

    Final complex field (MxN).

    +
    +
  • +
+ +
+ Source code in odak/learn/wave/classical.py +
def custom(field, kernel, zero_padding = False, aperture = 1.):
+    """
+    A definition to calculate beam propagation by multiplying the input field with a custom propagation kernel in the Fourier domain (i.e., a convolution in space).
+
+    Parameters
+    ----------
+    field            : torch.complex
+                       Complex field [m x n].
+    kernel           : torch.complex
+                       Custom complex kernel for beam propagation.
+    zero_padding     : bool
+                       Zero pad in Fourier domain.
+    aperture         : torch.tensor
+                       Fourier domain aperture (e.g., pinhole in a typical holographic display).
+                       The default is one, but an aperture could be as large as input field [m x n].
+
+    Returns
+    -------
+    result           : torch.complex
+                       Final complex field (MxN).
+
+    """
+    if type(kernel) == type(None):
+        H = torch.ones(field.shape).to(field.device)
+    else:
+        H = kernel * aperture
+    U1 = torch.fft.fftshift(torch.fft.fft2(field)) * aperture
+    if zero_padding == False:
+        U2 = H * U1
+    elif zero_padding == True:
+        U2 = zero_pad(H * U1)
+    result = torch.fft.ifft2(torch.fft.ifftshift(U2))
+    return result
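A hedged sketch pairing custom with get_propagation_kernel, which the propagation functions above call internally (the export paths at odak.learn.wave are assumptions):

import torch
import odak.learn.wave

field = torch.ones(512, 512, dtype = torch.complex64)
H = odak.learn.wave.get_propagation_kernel(
                                            nu = field.shape[-2],
                                            nv = field.shape[-1],
                                            dx = 8e-6,
                                            wavelength = 515e-9,
                                            distance = 1e-2,
                                            propagation_type = 'Angular Spectrum',
                                            device = field.device
                                           )
result = odak.learn.wave.custom(field, H)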
+
+
+
+ +
+ +
+ + +

+ fraunhofer(field, k, distance, dx, wavelength) + +

+ + +
+ +

A definition to calculate light transport using the Fraunhofer approximation.

+ + +

Parameters:

+
    +
  • + field + – +
    +
               Complex field (MxN).
    +
    +
    +
  • +
  • + k + – +
    +
               Wave number of a wave, see odak.wave.wavenumber for more.
    +
    +
    +
  • +
  • + distance + – +
    +
               Propagation distance.
    +
    +
    +
  • +
  • + dx + – +
    +
               Size of one single pixel in the field grid (in meters).
    +
    +
    +
  • +
  • + wavelength + – +
    +
               Wavelength of the electric field.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( complex +) – +
    +

    Final complex field (MxN).

    +
    +
  • +
+ +
+ Source code in odak/learn/wave/classical.py +
def fraunhofer(field, k, distance, dx, wavelength):
+    """
+    A definition to calculate light transport using the Fraunhofer approximation.
+
+    Parameters
+    ----------
+    field            : torch.complex
+                       Complex field (MxN).
+    k                : odak.wave.wavenumber
+                       Wave number of a wave, see odak.wave.wavenumber for more.
+    distance         : float
+                       Propagation distance.
+    dx               : float
+                       Size of one single pixel in the field grid (in meters).
+    wavelength       : float
+                       Wavelength of the electric field.
+
+    Returns
+    -------
+    result           : torch.complex
+                       Final complex field (MxN).
+    """
+    nv, nu = field.shape[-1], field.shape[-2]
+    x = torch.linspace(-nv*dx/2, nv*dx/2, nv, dtype=torch.float32)
+    y = torch.linspace(-nu*dx/2, nu*dx/2, nu, dtype=torch.float32)
+    Y, X = torch.meshgrid(y, x, indexing='ij')
+    Z = torch.pow(X, 2) + torch.pow(Y, 2)
+    c = 1. / (1j * wavelength * distance) * torch.exp(1j * k * 0.5 / distance * Z)
+    c = c.to(field.device)
+    result = c * torch.fft.ifftshift(torch.fft.fft2(torch.fft.fftshift(field))) * dx ** 2
+    return result
+
+
+
+ +
+ +
+ + +

+ gerchberg_saxton(field, n_iterations, distance, dx, wavelength, slm_range=6.28, propagation_type='Transfer Function Fresnel') + +

+ + +
+ +

Definition to compute a hologram using an iterative method called Gerchberg-Saxton phase retrieval algorithm. For more on the method, see: Gerchberg, Ralph W. "A practical algorithm for the determination of phase from image and diffraction plane pictures." Optik 35 (1972): 237-246.

+ + +

Parameters:

+
    +
  • + field + – +
    +
               Complex field (MxN).
    +
    +
    +
  • +
  • + distance + – +
    +
               Propagation distance.
    +
    +
    +
  • +
  • + dx + – +
    +
               Size of one single pixel in the field grid (in meters).
    +
    +
    +
  • +
  • + wavelength + – +
    +
               Wavelength of the electric field.
    +
    +
    +
  • +
  • + slm_range + – +
    +
               Typically this is equal to two pi. See odak.wave.adjust_phase_only_slm_range() for more.
    +
    +
    +
  • +
  • + propagation_type + (str, default: + 'Transfer Function Fresnel' +) + – +
    +
               Type of the propagation (see odak.learn.wave.propagate_beam).
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +hologram ( cfloat +) – +
    +

    Calculated complex hologram.

    +
    +
  • +
  • +reconstruction ( cfloat +) – +
    +

    Calculated reconstruction using calculated hologram.

    +
    +
  • +
+ +
+ Source code in odak/learn/wave/classical.py +
def gerchberg_saxton(field, n_iterations, distance, dx, wavelength, slm_range=6.28, propagation_type='Transfer Function Fresnel'):
+    """
+    Definition to compute a hologram using an iterative method called Gerchberg-Saxton phase retrieval algorithm. For more on the method, see: Gerchberg, Ralph W. "A practical algorithm for the determination of phase from image and diffraction plane pictures." Optik 35 (1972): 237-246.
+
+    Parameters
+    ----------
+    field            : torch.cfloat
+                       Complex field (MxN).
+    distance         : float
+                       Propagation distance.
+    dx               : float
+                       Size of one single pixel in the field grid (in meters).
+    wavelength       : float
+                       Wavelength of the electric field.
+    slm_range        : float
+                       Typically this is equal to two pi. See odak.wave.adjust_phase_only_slm_range() for more.
+    propagation_type : str
+                       Type of the propagation (see odak.learn.wave.propagate_beam).
+
+    Returns
+    -------
+    hologram         : torch.cfloat
+                       Calculated complex hologram.
+    reconstruction   : torch.cfloat
+                       Calculated reconstruction using calculated hologram. 
+    """
+    k = wavenumber(wavelength)
+    reconstruction = field
+    for i in range(n_iterations):
+        hologram = propagate_beam(
+            reconstruction, k, -distance, dx, wavelength, propagation_type)
+        reconstruction = propagate_beam(
+            hologram, k, distance, dx, wavelength, propagation_type)
+        reconstruction = set_amplitude(reconstruction, field)
+    reconstruction = propagate_beam(
+        hologram, k, distance, dx, wavelength, propagation_type)
+    return hologram, reconstruction
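A hedged sketch of phase retrieval with gerchberg_saxton; generate_complex_field is assumed to be exported at odak.learn.wave, as it is used elsewhere in this module:

import torch
import odak.learn.wave

target = torch.rand(256, 256)  # target amplitude
field = odak.learn.wave.generate_complex_field(target, torch.zeros_like(target))
hologram, reconstruction = odak.learn.wave.gerchberg_saxton(
                                                            field,
                                                            n_iterations = 5,
                                                            distance = 5e-3,
                                                            dx = 8e-6,
                                                            wavelength = 515e-9
                                                           )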
+
+
+
+ +
+ +
+ + +

+ get_angular_spectrum_kernel(nu, nv, dx=8e-06, wavelength=5.15e-07, distance=0.0, device=torch.device('cpu')) + +

+ + +
+ +

Helper function for odak.learn.wave.angular_spectrum.

+ + +

Parameters:

+
    +
  • + nu + – +
    +
                 Resolution at X axis in pixels.
    +
    +
    +
  • +
  • + nv + – +
    +
                 Resolution at Y axis in pixels.
    +
    +
    +
  • +
  • + dx + – +
    +
                 Pixel pitch in meters.
    +
    +
    +
  • +
  • + wavelength + – +
    +
                 Wavelength in meters.
    +
    +
    +
  • +
  • + distance + – +
    +
                 Distance in meters.
    +
    +
    +
  • +
  • + device + – +
    +
                 Device, for more see torch.device().
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +H ( float +) – +
    +

    Complex kernel in Fourier domain.

    +
    +
  • +
+ +
+ Source code in odak/learn/wave/classical.py +
def get_angular_spectrum_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9, distance = 0., device = torch.device('cpu')):
+    """
+    Helper function for odak.learn.wave.angular_spectrum.
+
+    Parameters
+    ----------
+    nu                 : int
+                         Resolution at X axis in pixels.
+    nv                 : int
+                         Resolution at Y axis in pixels.
+    dx                 : float
+                         Pixel pitch in meters.
+    wavelength         : float
+                         Wavelength in meters.
+    distance           : float
+                         Distance in meters.
+    device             : torch.device
+                         Device, for more see torch.device().
+
+
+    Returns
+    -------
+    H                  : float
+                         Complex kernel in Fourier domain.
+    """
+    distance = torch.tensor([distance]).to(device)
+    fx = torch.linspace(-1. / 2. / dx, 1. / 2. / dx, nu, dtype = torch.float32, device = device)
+    fy = torch.linspace(-1. / 2. / dx, 1. / 2. / dx, nv, dtype = torch.float32, device = device)
+    FY, FX = torch.meshgrid(fx, fy, indexing='ij')
+    H = torch.exp(1j  * distance * (2 * (torch.pi * (1 / wavelength) * torch.sqrt(1. - (wavelength * FX) ** 2 - (wavelength * FY) ** 2))))
+    H = H.to(device)
+    return H
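A minimal sketch requesting an angular spectrum transfer function directly (the odak.learn.wave export path is an assumption):

import odak.learn.wave

H = odak.learn.wave.get_angular_spectrum_kernel(
                                                nu = 512,
                                                nv = 512,
                                                dx = 8e-6,
                                                wavelength = 515e-9,
                                                distance = 1e-2
                                               )
print(H.shape, H.dtype)  # torch.Size([512, 512]) torch.complex64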
+
+
+
+ +
+ +
+ + +

+ get_band_limited_angular_spectrum_kernel(nu, nv, dx=8e-06, wavelength=5.15e-07, distance=0.0, device=torch.device('cpu')) + +

+ + +
+ +

Helper function for odak.learn.wave.band_limited_angular_spectrum.

+ + +

Parameters:

+
    +
  • + nu + – +
    +
                 Resolution at X axis in pixels.
    +
    +
    +
  • +
  • + nv + – +
    +
                 Resolution at Y axis in pixels.
    +
    +
    +
  • +
  • + dx + – +
    +
                 Pixel pitch in meters.
    +
    +
    +
  • +
  • + wavelength + – +
    +
                 Wavelength in meters.
    +
    +
    +
  • +
  • + distance + – +
    +
                 Distance in meters.
    +
    +
    +
  • +
  • + device + – +
    +
                 Device, for more see torch.device().
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +H ( complex64 +) – +
    +

    Complex kernel in Fourier domain.

    +
    +
  • +
+ +
+ Source code in odak/learn/wave/classical.py +
def get_band_limited_angular_spectrum_kernel(
+                                             nu,
+                                             nv,
+                                             dx = 8e-6,
+                                             wavelength = 515e-9,
+                                             distance = 0.,
+                                             device = torch.device('cpu')
+                                            ):
+    """
+    Helper function for odak.learn.wave.band_limited_angular_spectrum.
+
+    Parameters
+    ----------
+    nu                 : int
+                         Resolution at X axis in pixels.
+    nv                 : int
+                         Resolution at Y axis in pixels.
+    dx                 : float
+                         Pixel pitch in meters.
+    wavelength         : float
+                         Wavelength in meters.
+    distance           : float
+                         Distance in meters.
+    device             : torch.device
+                         Device, for more see torch.device().
+
+
+    Returns
+    -------
+    H                  : torch.complex64
+                         Complex kernel in Fourier domain.
+    """
+    x = dx * float(nu)
+    y = dx * float(nv)
+    fx = torch.linspace(
+                        -1 / (2 * dx) + 0.5 / (2 * x),
+                         1 / (2 * dx) - 0.5 / (2 * x),
+                         nu,
+                         dtype = torch.float32,
+                         device = device
+                        )
+    fy = torch.linspace(
+                        -1 / (2 * dx) + 0.5 / (2 * y),
+                        1 / (2 * dx) - 0.5 / (2 * y),
+                        nv,
+                        dtype = torch.float32,
+                        device = device
+                       )
+    FY, FX = torch.meshgrid(fx, fy, indexing='ij')
+    HH_exp = 2 * torch.pi * torch.sqrt(1 / wavelength ** 2 - (FX ** 2 + FY ** 2))
+    distance = torch.tensor([distance], device = device)
+    H_exp = torch.mul(HH_exp, distance)
+    fx_max = 1 / torch.sqrt((2 * distance * (1 / x))**2 + 1) / wavelength
+    fy_max = 1 / torch.sqrt((2 * distance * (1 / y))**2 + 1) / wavelength
+    H_filter = ((torch.abs(FX) < fx_max) & (torch.abs(FY) < fy_max)).clone().detach()
+    H = generate_complex_field(H_filter, H_exp)
+    return H
+
+
+
+ +
+ +
+ + +

+ get_impulse_response_fresnel_kernel(nu, nv, dx=8e-06, wavelength=5.15e-07, distance=0.0, device=torch.device('cpu'), scale=1, aperture_samples=[20, 20, 5, 5]) + +

+ + +
+ +

Helper function for odak.learn.wave.impulse_response_fresnel.

+ + +

Parameters:

+
    +
  • + nu + – +
    +
                 Resolution at X axis in pixels.
    +
    +
    +
  • +
  • + nv + – +
    +
                 Resolution at Y axis in pixels.
    +
    +
    +
  • +
  • + dx + – +
    +
                 Pixel pitch in meters.
    +
    +
    +
  • +
  • + wavelength + – +
    +
                 Wavelength in meters.
    +
    +
    +
  • +
  • + distance + – +
    +
                 Distance in meters.
    +
    +
    +
  • +
  • + device + – +
    +
                 Device, for more see torch.device().
    +
    +
    +
  • +
  • + scale + – +
    +
                 Scale with respect to nu and nv (e.g., scale = 2 leads to  2 x nu and 2 x nv resolution for H).
    +
    +
    +
  • +
  • + aperture_samples + – +
    +
                 Number of samples to represent a rectangular pixel. First two is for XY of hologram plane pixels, and second two is for image plane pixels.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +H ( complex64 +) – +
    +

    Complex kernel in Fourier domain.

    +
    +
  • +
+ +
+ Source code in odak/learn/wave/classical.py +
def get_impulse_response_fresnel_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9, distance = 0., device = torch.device('cpu'), scale = 1, aperture_samples = [20, 20, 5, 5]):
+    """
+    Helper function for odak.learn.wave.impulse_response_fresnel.
+
+    Parameters
+    ----------
+    nu                 : int
+                         Resolution at X axis in pixels.
+    nv                 : int
+                         Resolution at Y axis in pixels.
+    dx                 : float
+                         Pixel pitch in meters.
+    wavelength         : float
+                         Wavelength in meters.
+    distance           : float
+                         Distance in meters.
+    device             : torch.device
+                         Device, for more see torch.device().
+    scale              : int
+                         Scale with respect to nu and nv (e.g., scale = 2 leads to  2 x nu and 2 x nv resolution for H).
+    aperture_samples   : list
+                         Number of samples to represent a rectangular pixel. First two is for XY of hologram plane pixels, and second two is for image plane pixels.
+
+    Returns
+    -------
+    H                  : torch.complex64
+                         Complex kernel in Fourier domain.
+    """
+    k = wavenumber(wavelength)
+    distance = torch.as_tensor(distance, device = device)
+    length_x, length_y = (torch.tensor(dx * nu, device = device), torch.tensor(dx * nv, device = device))
+    x = torch.linspace(- length_x / 2., length_x / 2., nu * scale, device = device)
+    y = torch.linspace(- length_y / 2., length_y / 2., nv * scale, device = device)
+    X, Y = torch.meshgrid(x, y, indexing = 'ij')
+    wxs = torch.linspace(- dx / 2., dx / 2., aperture_samples[0], device = device)
+    wys = torch.linspace(- dx / 2., dx / 2., aperture_samples[1], device = device)
+    h = torch.zeros(nu * scale, nv * scale, dtype = torch.complex64, device = device)
+    pxs = torch.linspace(- dx / 2., dx / 2., aperture_samples[2], device = device)
+    pys = torch.linspace(- dx / 2., dx / 2., aperture_samples[3], device = device)
+    for wx in tqdm(wxs):
+        for wy in wys:
+            for px in pxs:
+                for py in pys:
+                    r = (X + px - wx) ** 2 + (Y + py - wy) ** 2
+                    h += 1. / (1j * wavelength * distance) * torch.exp(1j * k / (2 * distance) * r) 
+    H = torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(h))) * dx ** 2 / aperture_samples[0] / aperture_samples[1] / aperture_samples[2] / aperture_samples[3]
+    return H
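Note that this kernel is built with a brute-force loop over the aperture samples, so large aperture_samples values can be slow. A hedged sketch with deliberately tiny sample counts (export path assumed):

import odak.learn.wave

H = odak.learn.wave.get_impulse_response_fresnel_kernel(
                                                        nu = 128,
                                                        nv = 128,
                                                        dx = 8e-6,
                                                        wavelength = 515e-9,
                                                        distance = 5e-3,
                                                        aperture_samples = [2, 2, 1, 1]  # tiny counts only to keep this sketch fast
                                                       )
print(H.shape)  # torch.Size([128, 128])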
+
+
+
+ +
+ +
+ + +

+ get_incoherent_angular_spectrum_kernel(nu, nv, dx=8e-06, wavelength=5.15e-07, distance=0.0, device=torch.device('cpu')) + +

+ + +
+ +

Helper function for incoherent angular spectrum propagation in odak.learn.wave.

+ + +

Parameters:

+
    +
  • + nu + – +
    +
                 Resolution at X axis in pixels.
    +
    +
    +
  • +
  • + nv + – +
    +
                 Resolution at Y axis in pixels.
    +
    +
    +
  • +
  • + dx + – +
    +
                 Pixel pitch in meters.
    +
    +
    +
  • +
  • + wavelength + – +
    +
                 Wavelength in meters.
    +
    +
    +
  • +
  • + distance + – +
    +
                 Distance in meters.
    +
    +
    +
  • +
  • + device + – +
    +
                 Device, for more see torch.device().
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +H ( float +) – +
    +

    Complex kernel in Fourier domain.

    +
    +
  • +
+ +
+ Source code in odak/learn/wave/classical.py +
def get_incoherent_angular_spectrum_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9, distance = 0., device = torch.device('cpu')):
+    """
+    Helper function for odak.learn.wave.angular_spectrum.
+
+    Parameters
+    ----------
+    nu                 : int
+                         Resolution at X axis in pixels.
+    nv                 : int
+                         Resolution at Y axis in pixels.
+    dx                 : float
+                         Pixel pitch in meters.
+    wavelength         : float
+                         Wavelength in meters.
+    distance           : float
+                         Distance in meters.
+    device             : torch.device
+                         Device, for more see torch.device().
+
+
+    Returns
+    -------
+    H                  : float
+                         Complex kernel in Fourier domain.
+    """
+    distance = torch.tensor([distance]).to(device)
+    fx = torch.linspace(-1. / 2. / dx, 1. / 2. / dx, nu, dtype = torch.float32, device = device)
+    fy = torch.linspace(-1. / 2. / dx, 1. / 2. / dx, nv, dtype = torch.float32, device = device)
+    FY, FX = torch.meshgrid(fx, fy, indexing='ij')
+    H = torch.exp(1j  * distance * (2 * (torch.pi * (1 / wavelength) * torch.sqrt(1. - (wavelength * FX) ** 2 - (wavelength * FY) ** 2))))
+    H_ptime = correlation_2d(H, H)
+    H = H_ptime.to(device)
+    return H
+
+
+
+ +
+ +
+ + +

+ get_light_kernels(wavelengths, distances, pixel_pitches, resolution=[1080, 1920], resolution_factor=1, samples=[50, 50, 5, 5], propagation_type='Bandlimited Angular Spectrum', kernel_type='spatial', device=torch.device('cpu')) + +

+ + +
+ +

Utility function to request a tensor filled with light transport kernels according to the given optical configurations.

+ + +

Parameters:

+
    +
  • + wavelengths + – +
    +
                 A list of wavelengths.
    +
    +
    +
  • +
  • + distances + – +
    +
                 A list of propagation distances.
    +
    +
    +
  • +
  • + pixel_pitches + – +
    +
                 A list of pixel_pitches.
    +
    +
    +
  • +
  • + resolution + – +
    +
                 Resolution of the light transport kernel.
    +
    +
    +
  • +
  • + resolution_factor + – +
    +
                 If `Impulse Response Fresnel` propagation is used, this resolution factor could be set larger than one leading to higher resolution light transport kernels than the provided native `resolution`. For more, see odak.learn.wave.get_impulse_response_kernel().
    +
    +
    +
  • +
  • + samples + – +
    +
                 If `Impulse Response Fresnel` propagation is used, these sample counts will be used to calculate the light transport kernel. For more, see odak.learn.wave.get_impulse_response_kernel().
    +
    +
    +
  • +
  • + propagation_type + – +
    +
                 Propagation type. For more, see odak.learn.wave.propagate_beam().
    +
    +
    +
  • +
  • + kernel_type + – +
    +
                 If set to `spatial`, light transport kernels will be provided in space. But if set to `fourier`, these kernels will be provided in the Fourier domain.
    +
    +
    +
  • +
  • + device + – +
    +
                 Device used for computation (i.e., cpu, cuda).
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +light_kernels_amplitude ( tensor +) – +
    +

    Amplitudes of the light kernels generated [w x d x p x m x n].

    +
    +
  • +
  • +light_kernels_phase ( tensor +) – +
    +

    Phases of the light kernels generated [w x d x p x m x n].

    +
    +
  • +
  • +light_kernels_complex ( tensor +) – +
    +

    Complex light kernels generated [w x d x p x m x n].

    +
    +
  • +
  • +light_parameters ( tensor +) – +
    +

    Parameters of each pixel in light_kernels* [w x d x p x m x n x 5]. Last dimension contains, wavelengths, distances, pixel pitches, X and Y locations in order.

    +
    +
  • +
+ +
+ Source code in odak/learn/wave/classical.py +
def get_light_kernels(
+                      wavelengths,
+                      distances,
+                      pixel_pitches,
+                      resolution = [1080, 1920],
+                      resolution_factor = 1,
+                      samples = [50, 50, 5, 5],
+                      propagation_type = 'Bandlimited Angular Spectrum',
+                      kernel_type = 'spatial',
+                      device = torch.device('cpu')
+                     ):
+    """
+    Utility function to request a tensor filled with light transport kernels according to the given optical configurations.
+
+    Parameters
+    ----------
+    wavelengths        : list
+                         A list of wavelengths.
+    distances          : list
+                         A list of propagation distances.
+    pixel_pitches      : list
+                         A list of pixel_pitches.
+    resolution         : list
+                         Resolution of the light transport kernel.
+    resolution_factor  : int
+                         If `Impulse Response Fresnel` propagation is used, this resolution factor could be set larger than one leading to higher resolution light transport kernels than the provided native `resolution`. For more, see odak.learn.wave.get_impulse_response_kernel().
+    samples            : list
+                         If `Impulse Response Fresnel` propagation is used, these sample counts will be used to calculate the light transport kernel. For more, see odak.learn.wave.get_impulse_response_kernel().
+    propagation_type   : str
+                         Propagation type. For more, see odak.learn.wave.propagate_beam().
+    kernel_type        : str
+                         If set to `spatial`, light transport kernels will be provided in space. But if set to `fourier`, these kernels will be provided in the Fourier domain.
+    device             : torch.device
+                         Device used for computation (i.e., cpu, cuda).
+
+    Returns
+    -------
+    light_kernels_amplitude : torch.tensor
+                              Amplitudes of the light kernels generated [w x d x p x m x n].
+    light_kernels_phase     : torch.tensor
+                              Phases of the light kernels generated [w x d x p x m x n].
+    light_kernels_complex   : torch.tensor
+                              Complex light kernels generated [w x d x p x m x n].
+    light_parameters        : torch.tensor
+                              Parameters of each pixel in light_kernels* [w x d x p x m x n x 5].  Last dimension contains, wavelengths, distances, pixel pitches, X and Y locations in order.
+    """
+    if propagation_type != 'Impulse Response Fresnel':
+        resolution_factor = 1
+    light_kernels_complex = torch.zeros(            
+                                        len(wavelengths),
+                                        len(distances),
+                                        len(pixel_pitches),
+                                        resolution[0] * resolution_factor,
+                                        resolution[1] * resolution_factor,
+                                        dtype = torch.complex64,
+                                        device = device
+                                       )
+    light_parameters = torch.zeros(
+                                   len(wavelengths),
+                                   len(distances),
+                                   len(pixel_pitches),
+                                   resolution[0] * resolution_factor,
+                                   resolution[1] * resolution_factor,
+                                   5,
+                                   dtype = torch.float32,
+                                   device = device
+                                  )
+    for wavelength_id, distance_id, pixel_pitch_id in itertools.product(
+                                                                        range(len(wavelengths)),
+                                                                        range(len(distances)),
+                                                                        range(len(pixel_pitches)),
+                                                                       ):
+        pixel_pitch = pixel_pitches[pixel_pitch_id]
+        wavelength = wavelengths[wavelength_id]
+        distance = distances[distance_id]
+        kernel_fourier = get_propagation_kernel(
+                                                nu = resolution[0],
+                                                nv = resolution[1],
+                                                dx = pixel_pitch,
+                                                wavelength = wavelength,
+                                                distance = distance,
+                                                device = device,
+                                                propagation_type = propagation_type,
+                                                scale = resolution_factor,
+                                                samples = samples
+                                               )
+        if kernel_type == 'spatial':
+            kernel = torch.fft.ifftshift(torch.fft.ifft2(torch.fft.ifftshift(kernel_fourier)))
+        elif kernel_type == 'fourier':
+            kernel = kernel_fourier
+        else:
+            logging.warning('Unknown kernel type requested.')
+            raise ValueError('Unknown kernel type requested.')
+        kernel_amplitude = calculate_amplitude(kernel)
+        kernel_phase = calculate_phase(kernel) % (2 * torch.pi)
+        light_kernels_complex[wavelength_id, distance_id, pixel_pitch_id] = kernel
+        light_parameters[wavelength_id, distance_id, pixel_pitch_id, :, :, 0] = wavelength
+        light_parameters[wavelength_id, distance_id, pixel_pitch_id, :, :, 1] = distance
+        light_parameters[wavelength_id, distance_id, pixel_pitch_id, :, :, 2] = pixel_pitch
+        x = torch.linspace(-1., 1., resolution[0] * resolution_factor, device = device) * pixel_pitch / 2. * resolution[0]
+        y = torch.linspace(-1., 1., resolution[1] * resolution_factor, device = device) * pixel_pitch / 2. * resolution[1]
+        X, Y = torch.meshgrid(x, y, indexing = 'ij')
+        light_parameters[wavelength_id, distance_id, pixel_pitch_id, :, :, 3] = X
+        light_parameters[wavelength_id, distance_id, pixel_pitch_id, :, :, 4] = Y
+    light_kernels_amplitude = calculate_amplitude(light_kernels_complex)
+    light_kernels_phase = calculate_phase(light_kernels_complex) % (2. * torch.pi)
+    return light_kernels_amplitude, light_kernels_phase, light_kernels_complex, light_parameters
get_point_wise_impulse_response_fresnel_kernel(aperture_points, aperture_field, target_points, resolution, resolution_factor=1, wavelength=5.15e-07, distance=0.0, randomization=False, device=torch.device('cpu'))
This function is a freeform point spread function calculation routine for an aperture defined with a complex field, aperture_field, and locations in space, aperture_points. The point spread function is calculated over provided points, target_points. The final result is reshaped to follow the provided resolution.

+ + +

Parameters:

+
    +
  • + aperture_points + – +
    +
                       Points representing an aperture in Euler space (XYZ) [m x 3].
    +
    +
    +
  • +
  • + aperture_field + – +
    +
                       Complex field for each point provided by `aperture_points` [1 x m].
    +
    +
    +
  • +
  • + target_points + – +
    +
                       Target points where the propagated field will be calculated [n x 1].
    +
    +
    +
  • +
  • + resolution + – +
    +
                       Final resolution that the propagated field will be reshaped [X x Y].
    +
    +
    +
  • +
  • + resolution_factor + – +
    +
                       Scale with respect to `resolution` (e.g., scale = 2 leads to `2 x resolution` for the final complex field.
    +
    +
    +
  • +
  • + wavelength + – +
    +
                       Wavelength in meters.
    +
    +
    +
  • +
  • + randomization + – +
    +
                       If set `True`, this will help generate a noisy response roughly approximating a real life case, where imperfections occur.
    +
    +
    +
  • +
  • + distance + – +
    +
                       Distance in meters.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +h ( float +) – +
    +

    Complex field in spatial domain.

    +
    +
  • +
+ +
+ Source code in odak/learn/wave/classical.py +
def get_point_wise_impulse_response_fresnel_kernel(
+                                                   aperture_points,
+                                                   aperture_field,
+                                                   target_points,
+                                                   resolution,
+                                                   resolution_factor = 1,
+                                                   wavelength = 515e-9,
+                                                   distance = 0.,
+                                                   randomization = False,
+                                                   device = torch.device('cpu')
+                                                  ):
+    """
+    This function is a freeform point spread function calculation routine for an aperture defined with a complex field, `aperture_field`, and locations in space, `aperture_points`.
+    The point spread function is calculated over provided points, `target_points`.
+    The final result is reshaped to follow the provided `resolution`.
+
+    Parameters
+    ----------
+    aperture_points          : torch.tensor
+                               Points representing an aperture in Euler space (XYZ) [m x 3].
+    aperture_field           : torch.tensor
+                               Complex field for each point provided by `aperture_points` [1 x m].
+    target_points            : torch.tensor
+                               Target points where the propagated field will be calculated [n x 1].
+    resolution               : list
+                               Final resolution that the propagated field will be reshaped [X x Y].
+    resolution_factor        : int
+                               Scale with respect to `resolution` (e.g., scale = 2 leads to `2 x resolution` for the final complex field).
+    wavelength               : float
+                               Wavelength in meters.
+    randomization            : bool
+                               If set `True`, this will help generate a noisy response roughly approximating a real life case, where imperfections occur.
+    distance                 : float
+                               Distance in meters.
+
+    Returns
+    -------
+    h                        : torch.tensor
+                               Complex field in spatial domain.
+    """
+    device = aperture_field.device
+    k = wavenumber(wavelength)
+    if randomization:
+        pp = [
+              aperture_points[:, 0].max() - aperture_points[:, 0].min(),
+              aperture_points[:, 1].max() - aperture_points[:, 1].min()
+             ]
+        target_points[:, 0] = target_points[:, 0] - torch.randn(target_points[:, 0].shape) * pp[0]
+        target_points[:, 1] = target_points[:, 1] - torch.randn(target_points[:, 1].shape) * pp[1]
+    deltaX = aperture_points[:, 0].unsqueeze(0) - target_points[:, 0].unsqueeze(-1)
+    deltaY = aperture_points[:, 1].unsqueeze(0) - target_points[:, 1].unsqueeze(-1)
+    r = deltaX ** 2 + deltaY ** 2
+    h = torch.exp(1j * k / (2 * distance) * r) * aperture_field
+    h = torch.sum(h, dim = 1).reshape(resolution[0] * resolution_factor, resolution[1] * resolution_factor)
+    h = 1. / (1j * wavelength * distance) * h
+    return h
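The short sketch below illustrates how this routine can be called; it is not part of the library documentation. It assumes the function is importable from odak.learn.wave (matching the source path above), and the aperture grid, target grid, pixel pitch, wavelength and distance values are arbitrary choices.

import torch
from odak.learn.wave import get_point_wise_impulse_response_fresnel_kernel

resolution = [64, 64]
pixel_pitch = 8e-6
# Aperture: a small grid of points at z = 0 with uniform unit amplitude [m x 3].
ax = torch.linspace(-pixel_pitch, pixel_pitch, 11)
AX, AY = torch.meshgrid(ax, ax, indexing = 'ij')
aperture_points = torch.stack((AX.flatten(), AY.flatten(), torch.zeros_like(AX).flatten()), dim = -1)
aperture_field = torch.ones(1, aperture_points.shape[0], dtype = torch.complex64)
# Target points: one point per output pixel on the image plane, n = 64 * 64.
tx = torch.linspace(-1., 1., resolution[0]) * pixel_pitch * resolution[0] / 2.
ty = torch.linspace(-1., 1., resolution[1]) * pixel_pitch * resolution[1] / 2.
TX, TY = torch.meshgrid(tx, ty, indexing = 'ij')
target_points = torch.stack((TX.flatten(), TY.flatten(), torch.zeros_like(TX).flatten()), dim = -1)
h = get_point_wise_impulse_response_fresnel_kernel(
                                                   aperture_points,
                                                   aperture_field,
                                                   target_points,
                                                   resolution,
                                                   wavelength = 515e-9,
                                                   distance = 1e-3
                                                  )
# h holds the [64 x 64] complex point spread function.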
get_propagation_kernel(nu, nv, dx=8e-06, wavelength=5.15e-07, distance=0.0, device=torch.device('cpu'), propagation_type='Bandlimited Angular Spectrum', scale=1, samples=[20, 20, 5, 5])
Get propagation kernel for the propagation type.

+ + +

Parameters:

+
    +
  • + nu + – +
    +
                 Resolution at X axis in pixels.
    +
    +
    +
  • +
  • + nv + – +
    +
                 Resolution at Y axis in pixels.
    +
    +
    +
  • +
  • + dx + – +
    +
                 Pixel pitch in meters.
    +
    +
    +
  • +
  • + wavelength + – +
    +
                 Wavelength in meters.
    +
    +
    +
  • +
  • + distance + – +
    +
                 Distance in meters.
    +
    +
    +
  • +
  • + device + – +
    +
                 Device, for more see torch.device().
    +
    +
    +
  • +
  • + propagation_type + – +
    +
                 Propagation type.
    +             The options are `Angular Spectrum`, `Bandlimited Angular Spectrum` and `Transfer Function Fresnel`.
    +
    +
    +
  • +
  • + scale + – +
    +
                 Scale factor for scaled beam propagation.
    +
    +
    +
  • +
  • + samples + – +
    +
                 When using `Impulse Response Fresnel` propagation, these sample counts along X and Y will be used to represent a rectangular aperture. First two is for a hologram pixel and second two is for an image plane pixel.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +kernel ( tensor +) – +
    +

    Complex kernel for the given propagation type.

    +
    +
  • +
+ +
+ Source code in odak/learn/wave/classical.py +
def get_propagation_kernel(
+                           nu, 
+                           nv, 
+                           dx = 8e-6, 
+                           wavelength = 515e-9, 
+                           distance = 0., 
+                           device = torch.device('cpu'), 
+                           propagation_type = 'Bandlimited Angular Spectrum', 
+                           scale = 1,
+                           samples = [20, 20, 5, 5]
+                          ):
+    """
+    Get propagation kernel for the propagation type.
+
+    Parameters
+    ----------
+    nu                 : int
+                         Resolution at X axis in pixels.
+    nv                 : int
+                         Resolution at Y axis in pixels.
+    dx                 : float
+                         Pixel pitch in meters.
+    wavelength         : float
+                         Wavelength in meters.
+    distance           : float
+                         Distance in meters.
+    device             : torch.device
+                         Device, for more see torch.device().
+    propagation_type   : str
+                         Propagation type.
+                         The options are `Angular Spectrum`, `Bandlimited Angular Spectrum`, `Transfer Function Fresnel`, `Impulse Response Fresnel`, `Seperable Impulse Response Fresnel` and `Incoherent Angular Spectrum`.
+    scale              : int
+                         Scale factor for scaled beam propagation.
+    samples            : list
+                         When using `Impulse Response Fresnel` propagation, these sample counts along X and Y will be used to represent a rectangular aperture. The first two are for a hologram pixel and the last two are for an image plane pixel.
+
+
+    Returns
+    -------
+    kernel             : torch.tensor
+                         Complex kernel for the given propagation type.
+    """                                                      
+    logging.warning('Requested propagation kernel size for %s method with %s m distance, %s m pixel pitch, %s m wavelength, %s x %s resolutions, x%s scale and %s samples.', propagation_type, distance, dx, wavelength, nu, nv, scale, samples)
+    if propagation_type == 'Bandlimited Angular Spectrum':
+        kernel = get_band_limited_angular_spectrum_kernel(
+                                                          nu = nu,
+                                                          nv = nv,
+                                                          dx = dx,
+                                                          wavelength = wavelength,
+                                                          distance = distance,
+                                                          device = device
+                                                         )
+    elif propagation_type == 'Angular Spectrum':
+        kernel = get_angular_spectrum_kernel(
+                                             nu = nu,
+                                             nv = nv,
+                                             dx = dx,
+                                             wavelength = wavelength,
+                                             distance = distance,
+                                             device = device
+                                            )
+    elif propagation_type == 'Transfer Function Fresnel':
+        kernel = get_transfer_function_fresnel_kernel(
+                                                      nu = nu,
+                                                      nv = nv,
+                                                      dx = dx,
+                                                      wavelength = wavelength,
+                                                      distance = distance,
+                                                      device = device
+                                                     )
+    elif propagation_type == 'Impulse Response Fresnel':
+        kernel = get_impulse_response_fresnel_kernel(
+                                                     nu = nu, 
+                                                     nv = nv, 
+                                                     dx = dx, 
+                                                     wavelength = wavelength,
+                                                     distance = distance,
+                                                     device =  device,
+                                                     scale = scale,
+                                                     aperture_samples = samples
+                                                    )
+    elif propagation_type == 'Incoherent Angular Spectrum':
+        kernel = get_incoherent_angular_spectrum_kernel(
+                                                        nu = nu,
+                                                        nv = nv, 
+                                                        dx = dx, 
+                                                        wavelength = wavelength, 
+                                                        distance = distance,
+                                                        device = device
+                                                       )
+    elif propagation_type == 'Seperable Impulse Response Fresnel':
+        kernel, _, _, _ = get_seperable_impulse_response_fresnel_kernel(
+                                                                        nu = nu,
+                                                                        nv = nv,
+                                                                        dx = dx,
+                                                                        wavelength = wavelength,
+                                                                        distance = distance,
+                                                                        device = device,
+                                                                        scale = scale,
+                                                                        aperture_samples = samples
+                                                                       )
+    else:
+        logging.warning('Propagation type not recognized')
+        raise ValueError('Propagation type not recognized.')
+    return kernel
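A brief usage sketch follows; it is illustrative only and assumes get_propagation_kernel is importable from odak.learn.wave, with arbitrary resolution, pixel pitch, wavelength and distance values.

import torch
from odak.learn.wave import get_propagation_kernel

kernel = get_propagation_kernel(
                                nu = 1024,
                                nv = 1024,
                                dx = 8e-6,
                                wavelength = 515e-9,
                                distance = 5e-3,
                                device = torch.device('cpu'),
                                propagation_type = 'Bandlimited Angular Spectrum'
                               )
# kernel is a [1024 x 1024] complex tensor defined in the Fourier domain.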
get_seperable_impulse_response_fresnel_kernel(nu, nv, dx=3.74e-06, wavelength=5.15e-07, distance=0.0, scale=1, aperture_samples=[50, 50, 5, 5], device=torch.device('cpu'))
Returns impulse response fresnel kernel in separable form.

+ + +

Parameters:

+
    +
  • + nu + – +
    +
                 Resolution at X axis in pixels.
    +
    +
    +
  • +
  • + nv + – +
    +
                 Resolution at Y axis in pixels.
    +
    +
    +
  • +
  • + dx + – +
    +
                 Pixel pitch in meters.
    +
    +
    +
  • +
  • + wavelength + – +
    +
                 Wavelength in meters.
    +
    +
    +
  • +
  • + distance + – +
    +
                 Distance in meters.
    +
    +
    +
  • +
  • + device + – +
    +
                 Device, for more see torch.device().
    +
    +
    +
  • +
  • + scale + – +
    +
                 Scale with respect to nu and nv (e.g., scale = 2 leads to  2 x nu and 2 x nv resolution for H).
    +
    +
    +
  • +
  • + aperture_samples + – +
    +
                 Number of samples to represent a rectangular pixel. First two is for XY of hologram plane pixels, and second two is for image plane pixels.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +H ( complex64 +) – +
    +

    Complex kernel in Fourier domain.

    +
    +
  • +
  • +h ( complex64 +) – +
    +

    Complex kernel in spatial domain.

    +
    +
  • +
  • +h_x ( complex64 +) – +
    +

    1D complex kernel in spatial domain along X axis.

    +
    +
  • +
  • +h_y ( complex64 +) – +
    +

    1D complex kernel in spatial domain along Y axis.

    +
    +
  • +
+ +
+ Source code in odak/learn/wave/classical.py +
def get_seperable_impulse_response_fresnel_kernel(
+                                                  nu,
+                                                  nv,
+                                                  dx = 3.74e-6,
+                                                  wavelength = 515e-9,
+                                                  distance = 0.,
+                                                  scale = 1,
+                                                  aperture_samples = [50, 50, 5, 5],
+                                                  device = torch.device('cpu')
+                                                 ):
+    """
+    Returns impulse response fresnel kernel in separable form.
+
+    Parameters
+    ----------
+    nu                 : int
+                         Resolution at X axis in pixels.
+    nv                 : int
+                         Resolution at Y axis in pixels.
+    dx                 : float
+                         Pixel pitch in meters.
+    wavelength         : float
+                         Wavelength in meters.
+    distance           : float
+                         Distance in meters.
+    device             : torch.device
+                         Device, for more see torch.device().
+    scale              : int
+                         Scale with respect to nu and nv (e.g., scale = 2 leads to  2 x nu and 2 x nv resolution for H).
+    aperture_samples   : list
+                         Number of samples to represent a rectangular pixel. The first two are for the XY of hologram plane pixels, and the last two are for image plane pixels.
+
+    Returns
+    -------
+    H                  : torch.complex64
+                         Complex kernel in Fourier domain.
+    h                  : torch.complex64
+                         Complex kernel in spatial domain.
+    h_x                : torch.complex64
+                         1D complex kernel in spatial domain along X axis.
+    h_y                : torch.complex64
+                         1D complex kernel in spatial domain along Y axis.
+    """
+    k = wavenumber(wavelength)
+    distance = torch.as_tensor(distance, device = device)
+    length_x, length_y = (
+                          torch.tensor(dx * nu, device = device),
+                          torch.tensor(dx * nv, device = device)
+                         )
+    x = torch.linspace(- length_x / 2., length_x / 2., nu * scale, device = device)
+    y = torch.linspace(- length_y / 2., length_y / 2., nv * scale, device = device)
+    wxs = torch.linspace(- dx / 2., dx / 2., aperture_samples[0], device = device).unsqueeze(0).unsqueeze(0)
+    wys = torch.linspace(- dx / 2., dx / 2., aperture_samples[1], device = device).unsqueeze(0).unsqueeze(-1)
+    pxs = torch.linspace(- dx / 2., dx / 2., aperture_samples[2], device = device).unsqueeze(0).unsqueeze(-1)
+    pys = torch.linspace(- dx / 2., dx / 2., aperture_samples[3], device = device).unsqueeze(0).unsqueeze(0)
+    wxs = (wxs - pxs).reshape(1, -1).unsqueeze(-1)
+    wys = (wys - pys).reshape(1, -1).unsqueeze(1)
+
+    X = x.unsqueeze(-1).unsqueeze(-1)
+    Y = y[y.shape[0] // 2].unsqueeze(-1).unsqueeze(-1)
+    r_x = (X + wxs) ** 2
+    r_y = (Y + wys) ** 2
+    r = r_x + r_y
+    h_x = torch.exp(1j * k / (2 * distance) * r)
+    h_x = torch.sum(h_x, axis = (1, 2))
+
+    if nu != nv:
+        X = x[x.shape[0] // 2].unsqueeze(-1).unsqueeze(-1)
+        Y = y.unsqueeze(-1).unsqueeze(-1)
+        r_x = (X + wxs) ** 2
+        r_y = (Y + wys) ** 2
+        r = r_x + r_y
+        h_y = torch.exp(1j * k * r / (2 * distance))
+        h_y = torch.sum(h_y, axis = (1, 2))
+    else:
+        h_y = h_x.detach().clone()
+    h = torch.exp(1j * k * distance) / (1j * wavelength * distance) * h_x.unsqueeze(1) * h_y.unsqueeze(0)
+    H = torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(h))) * dx ** 2 / aperture_samples[0] / aperture_samples[1] / aperture_samples[2] / aperture_samples[3]
+    return H, h, h_x, h_y
get_transfer_function_fresnel_kernel(nu, nv, dx=8e-06, wavelength=5.15e-07, distance=0.0, device=torch.device('cpu'))
Helper function for odak.learn.wave.transfer_function_fresnel.

+ + +

Parameters:

+
    +
  • + nu + – +
    +
                 Resolution at X axis in pixels.
    +
    +
    +
  • +
  • + nv + – +
    +
                 Resolution at Y axis in pixels.
    +
    +
    +
  • +
  • + dx + – +
    +
                 Pixel pitch in meters.
    +
    +
    +
  • +
  • + wavelength + – +
    +
                 Wavelength in meters.
    +
    +
    +
  • +
  • + distance + – +
    +
                 Distance in meters.
    +
    +
    +
  • +
  • + device + – +
    +
                 Device, for more see torch.device().
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +H ( complex64 +) – +
    +

    Complex kernel in Fourier domain.

    +
    +
  • +
+ +
+ Source code in odak/learn/wave/classical.py +
def get_transfer_function_fresnel_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9, distance = 0., device = torch.device('cpu')):
+    """
+    Helper function for odak.learn.wave.transfer_function_fresnel.
+
+    Parameters
+    ----------
+    nu                 : int
+                         Resolution at X axis in pixels.
+    nv                 : int
+                         Resolution at Y axis in pixels.
+    dx                 : float
+                         Pixel pitch in meters.
+    wavelength         : float
+                         Wavelength in meters.
+    distance           : float
+                         Distance in meters.
+    device             : torch.device
+                         Device, for more see torch.device().
+
+
+    Returns
+    -------
+    H                  : torch.complex64
+                         Complex kernel in Fourier domain.
+    """
+    distance = torch.tensor([distance]).to(device)
+    fx = torch.linspace(-1. / 2. /dx, 1. / 2. /dx, nu, dtype = torch.float32, device = device)
+    fy = torch.linspace(-1. / 2. /dx, 1. / 2. /dx, nv, dtype = torch.float32, device = device)
+    FY, FX = torch.meshgrid(fx, fy, indexing = 'ij')
+    k = wavenumber(wavelength)
+    H = torch.exp(-1j * distance * (k - torch.pi * wavelength * (FX ** 2 + FY ** 2)))
+    return H
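The sketch below shows one common way such a Fourier-domain kernel is applied: multiply the centered spectrum of the field by H and transform back. It is illustrative only and assumes the helper is importable from odak.learn.wave; transfer_function_fresnel(), documented further below, applies the same kernel through the custom() propagator.

import torch
from odak.learn.wave import get_transfer_function_fresnel_kernel

field = torch.ones(512, 512, dtype = torch.complex64)
H = get_transfer_function_fresnel_kernel(
                                         nu = 512,
                                         nv = 512,
                                         dx = 8e-6,
                                         wavelength = 515e-9,
                                         distance = 1e-2
                                        )
# Pointwise multiplication in the Fourier domain, then back to the spatial domain.
U1 = torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(field)))
U2 = H * U1
result = torch.fft.ifftshift(torch.fft.ifft2(torch.fft.ifftshift(U2)))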
impulse_response_fresnel(field, k, distance, dx, wavelength, zero_padding=False, aperture=1.0, scale=1, samples=[20, 20, 5, 5])
A definition to calculate convolution based Fresnel approximation for beam propagation.

+ + +

Parameters:

+
    +
  • + field + – +
    +
               Complex field (MxN).
    +
    +
    +
  • +
  • + k + – +
    +
               Wave number of a wave, see odak.wave.wavenumber for more.
    +
    +
    +
  • +
  • + distance + – +
    +
               Propagation distance.
    +
    +
    +
  • +
  • + dx + – +
    +
               Size of one single pixel in the field grid (in meters).
    +
    +
    +
  • +
  • + wavelength + – +
    +
               Wavelength of the electric field.
    +
    +
    +
  • +
  • + zero_padding + – +
    +
               Zero pad in Fourier domain.
    +
    +
    +
  • +
  • + aperture + – +
    +
               Fourier domain aperture (e.g., pinhole in a typical holographic display).
    +           The default is one, but an aperture could be as large as input field [m x n].
    +
    +
    +
  • +
  • + scale + – +
    +
               Resolution factor to scale generated kernel.
    +
    +
    +
  • +
  • + samples + – +
    +
               When using `Impulse Response Fresnel` propagation, these sample counts along X and Y will be used to represent a rectangular aperture. First two is for hologram plane pixel and the last two is for image plane pixel.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( complex +) – +
    +

    Final complex field (MxN).

    +
    +
  • +
+ +
+ Source code in odak/learn/wave/classical.py +
def impulse_response_fresnel(field, k, distance, dx, wavelength, zero_padding = False, aperture = 1., scale = 1, samples = [20, 20, 5, 5]):
+    """
+    A definition to calculate convolution based Fresnel approximation for beam propagation.
+
+    Parameters
+    ----------
+    field            : torch.complex
+                       Complex field (MxN).
+    k                : odak.wave.wavenumber
+                       Wave number of a wave, see odak.wave.wavenumber for more.
+    distance         : float
+                       Propagation distance.
+    dx               : float
+                       Size of one single pixel in the field grid (in meters).
+    wavelength       : float
+                       Wavelength of the electric field.
+    zero_padding     : bool
+                       Zero pad in Fourier domain.
+    aperture         : torch.tensor
+                       Fourier domain aperture (e.g., pinhole in a typical holographic display).
+                       The default is one, but an aperture could be as large as input field [m x n].
+    scale            : int
+                       Resolution factor to scale generated kernel.
+    samples          : list
+                       When using `Impulse Response Fresnel` propagation, these sample counts along X and Y will be used to represent a rectangular aperture. The first two are for a hologram plane pixel and the last two are for an image plane pixel.
+
+    Returns
+    -------
+    result           : torch.complex
+                       Final complex field (MxN).
+
+    """
+    H = get_propagation_kernel(
+                               nu = field.shape[-2], 
+                               nv = field.shape[-1], 
+                               dx = dx, 
+                               wavelength = wavelength, 
+                               distance = distance, 
+                               propagation_type = 'Impulse Response Fresnel',
+                               device = field.device,
+                               scale = scale,
+                               samples = samples
+                              )
+    if scale > 1:
+        field_amplitude = calculate_amplitude(field)
+        field_phase = calculate_phase(field)
+        field_scale_amplitude = torch.zeros(field.shape[-2] * scale, field.shape[-1] * scale, device = field.device)
+        field_scale_phase = torch.zeros_like(field_scale_amplitude)
+        field_scale_amplitude[::scale, ::scale] = field_amplitude
+        field_scale_phase[::scale, ::scale] = field_phase
+        field_scale = generate_complex_field(field_scale_amplitude, field_scale_phase)
+    else:
+        field_scale = field
+    result = custom(field_scale, H, zero_padding = zero_padding, aperture = aperture)
+    return result
incoherent_angular_spectrum(field, k, distance, dx, wavelength, zero_padding=False, aperture=1.0)
A definition to calculate incoherent beam propagation with Angular Spectrum method.

+ + +

Parameters:

+
    +
  • + field + – +
    +
               Complex field [m x n].
    +
    +
    +
  • +
  • + k + – +
    +
               Wave number of a wave, see odak.wave.wavenumber for more.
    +
    +
    +
  • +
  • + distance + – +
    +
               Propagation distance.
    +
    +
    +
  • +
  • + dx + – +
    +
               Size of one single pixel in the field grid (in meters).
    +
    +
    +
  • +
  • + wavelength + – +
    +
               Wavelength of the electric field.
    +
    +
    +
  • +
  • + zero_padding + – +
    +
               Zero pad in Fourier domain.
    +
    +
    +
  • +
  • + aperture + – +
    +
               Fourier domain aperture (e.g., pinhole in a typical holographic display).
    +           The default is one, but an aperture could be as large as input field [m x n].
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( complex +) – +
    +

    Final complex field [m x n].

    +
    +
  • +
+ +
+ Source code in odak/learn/wave/classical.py +
def incoherent_angular_spectrum(field, k, distance, dx, wavelength, zero_padding = False, aperture = 1.):
+    """
+    A definition to calculate incoherent beam propagation with Angular Spectrum method.
+
+    Parameters
+    ----------
+    field            : torch.complex
+                       Complex field [m x n].
+    k                : odak.wave.wavenumber
+                       Wave number of a wave, see odak.wave.wavenumber for more.
+    distance         : float
+                       Propagation distance.
+    dx               : float
+                       Size of one single pixel in the field grid (in meters).
+    wavelength       : float
+                       Wavelength of the electric field.
+    zero_padding     : bool
+                       Zero pad in Fourier domain.
+    aperture         : torch.tensor
+                       Fourier domain aperture (e.g., pinhole in a typical holographic display).
+                       The default is one, but an aperture could be as large as input field [m x n].
+
+
+    Returns
+    -------
+    result           : torch.complex
+                       Final complex field [m x n].
+    """
+    H = get_propagation_kernel(
+                               nu = field.shape[-2], 
+                               nv = field.shape[-1], 
+                               dx = dx, 
+                               wavelength = wavelength, 
+                               distance = distance, 
+                               propagation_type = 'Incoherent Angular Spectrum',
+                               device = field.device
+                              )
+    result = custom(field, H, zero_padding = zero_padding, aperture = aperture)
+    return result
point_wise(target, wavelength, distance, dx, device, lens_size=401)
Naive point-wise hologram calculation method. For more information, refer to Maimone, Andrew, Andreas Georgiou, and Joel S. Kollin. "Holographic near-eye displays for virtual and augmented reality." ACM Transactions on Graphics (TOG) 36.4 (2017): 1-16.

+ + +

Parameters:

+
    +
  • + target + – +
    +
               float input target to be converted into a hologram (Target should be in range of 0 and 1).
    +
    +
    +
  • +
  • + wavelength + – +
    +
               Wavelength of the electric field.
    +
    +
    +
  • +
  • + distance + – +
    +
               Propagation distance.
    +
    +
    +
  • +
  • + dx + – +
    +
               Size of one single pixel in the field grid (in meters).
    +
    +
    +
  • +
  • + device + – +
    +
               Device type (cuda or cpu)`.
    +
    +
    +
  • +
  • + lens_size + – +
    +
               Size of lens for masking sub holograms(in pixels).
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +hologram ( cfloat +) – +
    +

    Calculated complex hologram.

    +
    +
  • +
+ +
+ Source code in odak/learn/wave/classical.py +
def point_wise(target, wavelength, distance, dx, device, lens_size=401):
+    """
+    Naive point-wise hologram calculation method. For more information, refer to Maimone, Andrew, Andreas Georgiou, and Joel S. Kollin. "Holographic near-eye displays for virtual and augmented reality." ACM Transactions on Graphics (TOG) 36.4 (2017): 1-16.
+
+    Parameters
+    ----------
+    target           : torch.float
+                       Input target to be converted into a hologram (target values should be between zero and one).
+    wavelength       : float
+                       Wavelength of the electric field.
+    distance         : float
+                       Propagation distance.
+    dx               : float
+                       Size of one single pixel in the field grid (in meters).
+    device           : torch.device
+                       Device type (cuda or cpu).
+    lens_size        : int
+                       Size of the lens for masking sub-holograms (in pixels).
+
+    Returns
+    -------
+    hologram         : torch.cfloat
+                       Calculated complex hologram.
+    """
+    target = zero_pad(target)
+    nx, ny = target.shape
+    k = wavenumber(wavelength)
+    ones = torch.ones(target.shape, requires_grad=False).to(device)
+    x = torch.linspace(-nx/2, nx/2, nx).to(device)
+    y = torch.linspace(-ny/2, ny/2, ny).to(device)
+    X, Y = torch.meshgrid(x, y, indexing='ij')
+    Z = (X**2+Y**2)**0.5
+    mask = (torch.abs(Z) <= lens_size)
+    mask[mask > 1] = 1
+    fz = quadratic_phase_function(nx, ny, k, focal=-distance, dx=dx).to(device)
+    A = torch.nan_to_num(target**0.5, nan=0.0)
+    fz = mask*fz
+    FA = torch.fft.fft2(torch.fft.fftshift(A))
+    FFZ = torch.fft.fft2(torch.fft.fftshift(fz))
+    H = torch.mul(FA, FFZ)
+    hologram = torch.fft.ifftshift(torch.fft.ifft2(H))
+    hologram = crop_center(hologram)
+    return hologram
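A short usage sketch, illustrative only, assuming point_wise is importable from odak.learn.wave; the target, pixel pitch and distance values are arbitrary.

import torch
from odak.learn.wave import point_wise

device = torch.device('cpu')
target = torch.rand(512, 512, device = device)  # target amplitudes between zero and one
hologram = point_wise(
                      target,
                      wavelength = 515e-9,
                      distance = 1e-2,
                      dx = 8e-6,
                      device = device,
                      lens_size = 401
                     )
# hologram is a [512 x 512] complex tensor.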
propagate_beam(field, k, distance, dx, wavelength, propagation_type='Bandlimited Angular Spectrum', kernel=None, zero_padding=[True, False, True], aperture=1.0, scale=1, samples=[20, 20, 5, 5])
Definitions for various beam propagation methods, mostly in accordance with "Computational Fourier Optics" by David Voelz.

+ + +

Parameters:

+
    +
  • + field + – +
    +
               Complex field [m x n].
    +
    +
    +
  • +
  • + k + – +
    +
               Wave number of a wave, see odak.wave.wavenumber for more.
    +
    +
    +
  • +
  • + distance + – +
    +
               Propagation distance.
    +
    +
    +
  • +
  • + dx + – +
    +
               Size of one single pixel in the field grid (in meters).
    +
    +
    +
  • +
  • + wavelength + – +
    +
               Wavelength of the electric field.
    +
    +
    +
  • +
  • + propagation_type + (str, default: + 'Bandlimited Angular Spectrum' +) + – +
    +
               Type of the propagation.
    +           The options are Impulse Response Fresnel, Transfer Function Fresnel, Angular Spectrum, Bandlimited Angular Spectrum, Fraunhofer.
    +
    +
    +
  • +
  • + kernel + – +
    +
               Custom complex kernel.
    +
    +
    +
  • +
  • + zero_padding + – +
    +
               Zero padding the input field if the first item in the list set True.
    +           Zero padding in the Fourier domain if the second item in the list set to True.
    +           Cropping the result with half resolution if the third item in the list is set to true.
    +           Note that in Fraunhofer propagation, setting the second item True or False will have no effect.
    +
    +
    +
  • +
  • + aperture + – +
    +
               Aperture at Fourier domain default:[2m x 2n], otherwise depends on `zero_padding`.
    +           If provided as a floating point 1, there will be no aperture in Fourier domain.
    +
    +
    +
  • +
  • + scale + – +
    +
               Resolution factor to scale generated kernel.
    +
    +
    +
  • +
  • + samples + – +
    +
               When using `Impulse Response Fresnel` propagation, these sample counts along X and Y will be used to represent a rectangular aperture. First two is for a hologram pixel and second two is for an image plane pixel.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( complex +) – +
    +

    Final complex field [m x n].

    +
    +
  • +
+ +
+ Source code in odak/learn/wave/classical.py +
def propagate_beam(
+                   field,
+                   k,
+                   distance,
+                   dx,
+                   wavelength,
+                   propagation_type='Bandlimited Angular Spectrum',
+                   kernel = None,
+                   zero_padding = [True, False, True],
+                   aperture = 1.,
+                   scale = 1,
+                   samples = [20, 20, 5, 5]
+                  ):
+    """
+    Definitions for various beam propagation methods, mostly in accordance with "Computational Fourier Optics" by David Voelz.
+
+    Parameters
+    ----------
+    field            : torch.complex
+                       Complex field [m x n].
+    k                : odak.wave.wavenumber
+                       Wave number of a wave, see odak.wave.wavenumber for more.
+    distance         : float
+                       Propagation distance.
+    dx               : float
+                       Size of one single pixel in the field grid (in meters).
+    wavelength       : float
+                       Wavelength of the electric field.
+    propagation_type : str
+                       Type of the propagation.
+                       The options are `Impulse Response Fresnel`, `Seperable Impulse Response Fresnel`, `Transfer Function Fresnel`, `Angular Spectrum`, `Bandlimited Angular Spectrum`, `Incoherent Angular Spectrum`, `Fraunhofer` and `custom`.
+    kernel           : torch.complex
+                       Custom complex kernel.
+    zero_padding     : list
+                       Zero pads the input field if the first item in the list is set to True.
+                       Zero pads in the Fourier domain if the second item in the list is set to True.
+                       Crops the result back to half resolution if the third item in the list is set to True.
+                       Note that in Fraunhofer propagation, setting the second item True or False will have no effect.
+    aperture         : torch.tensor
+                       Aperture at the Fourier domain, [2m x 2n] by default; the exact size depends on `zero_padding`.
+                       If provided as a floating point 1, there will be no aperture in the Fourier domain.
+    scale            : int
+                       Resolution factor to scale generated kernel.
+    samples          : list
+                       When using `Impulse Response Fresnel` propagation, these sample counts along X and Y will be used to represent a rectangular aperture. The first two are for a hologram pixel and the last two are for an image plane pixel.
+
+    Returns
+    -------
+    result           : torch.complex
+                       Final complex field [m x n].
+    """
+    if zero_padding[0]:
+        field = zero_pad(field)
+    if propagation_type == 'Angular Spectrum':
+        result = angular_spectrum(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture)
+    elif propagation_type == 'Bandlimited Angular Spectrum':
+        result = band_limited_angular_spectrum(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture)
+    elif propagation_type == 'Impulse Response Fresnel':
+        result = impulse_response_fresnel(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture, scale = scale, samples = samples)
+    elif propagation_type == 'Seperable Impulse Response Fresnel':
+        result = seperable_impulse_response_fresnel(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture, scale = scale, samples = samples)
+    elif propagation_type == 'Transfer Function Fresnel':
+        result = transfer_function_fresnel(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture)
+    elif propagation_type == 'custom':
+        result = custom(field, kernel, zero_padding[1], aperture = aperture)
+    elif propagation_type == 'Fraunhofer':
+        result = fraunhofer(field, k, distance, dx, wavelength)
+    elif propagation_type == 'Incoherent Angular Spectrum':
+        result = incoherent_angular_spectrum(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture)
+    else:
+        logging.warning('Propagation type not recognized')
+        raise ValueError('Propagation type not recognized.')
+    if zero_padding[2]:
+        result = crop_center(result)
+    return result
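A minimal end-to-end sketch for propagate_beam follows. It is illustrative only; it assumes propagate_beam, wavenumber and generate_complex_field are importable from odak.learn.wave (matching the source paths above), and the field size, pixel pitch, wavelength and distance are arbitrary.

import torch
from odak.learn.wave import propagate_beam, wavenumber, generate_complex_field

wavelength = 515e-9
pixel_pitch = 8e-6
distance = 5e-3
k = wavenumber(wavelength)
amplitude = torch.ones(512, 512)
phase = torch.rand(512, 512) * 2 * torch.pi
field = generate_complex_field(amplitude, phase)
result = propagate_beam(
                        field,
                        k,
                        distance,
                        pixel_pitch,
                        wavelength,
                        propagation_type = 'Bandlimited Angular Spectrum',
                        zero_padding = [True, False, True]
                       )
intensity = result.abs() ** 2  # [512 x 512] after cropping back to the input resolution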
seperable_impulse_response_fresnel(field, k, distance, dx, wavelength, zero_padding=False, aperture=1.0, scale=1, samples=[20, 20, 5, 5])
A definition to calculate convolution based Fresnel approximation for beam propagation for a rectangular aperture using the separable property.

+ + +

Parameters:

+
    +
  • + field + – +
    +
               Complex field (MxN).
    +
    +
    +
  • +
  • + k + – +
    +
               Wave number of a wave, see odak.wave.wavenumber for more.
    +
    +
    +
  • +
  • + distance + – +
    +
               Propagation distance.
    +
    +
    +
  • +
  • + dx + – +
    +
               Size of one single pixel in the field grid (in meters).
    +
    +
    +
  • +
  • + wavelength + – +
    +
               Wavelength of the electric field.
    +
    +
    +
  • +
  • + zero_padding + – +
    +
               Zero pad in Fourier domain.
    +
    +
    +
  • +
  • + aperture + – +
    +
               Fourier domain aperture (e.g., pinhole in a typical holographic display).
    +           The default is one, but an aperture could be as large as input field [m x n].
    +
    +
    +
  • +
  • + scale + – +
    +
               Resolution factor to scale generated kernel.
    +
    +
    +
  • +
  • + samples + – +
    +
               When using `Impulse Response Fresnel` propagation, these sample counts along X and Y will be used to represent a rectangular aperture. First two is for hologram plane pixel and the last two is for image plane pixel.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( complex +) – +
    +

    Final complex field (MxN).

    +
    +
  • +
+ +
+ Source code in odak/learn/wave/classical.py +
def seperable_impulse_response_fresnel(field, k, distance, dx, wavelength, zero_padding = False, aperture = 1., scale = 1, samples = [20, 20, 5, 5]):
+    """
+    A definition to calculate convolution based Fresnel approximation for beam propagation for a rectangular aperture using the separable property.
+
+    Parameters
+    ----------
+    field            : torch.complex
+                       Complex field (MxN).
+    k                : odak.wave.wavenumber
+                       Wave number of a wave, see odak.wave.wavenumber for more.
+    distance         : float
+                       Propagation distance.
+    dx               : float
+                       Size of one single pixel in the field grid (in meters).
+    wavelength       : float
+                       Wavelength of the electric field.
+    zero_padding     : bool
+                       Zero pad in Fourier domain.
+    aperture         : torch.tensor
+                       Fourier domain aperture (e.g., pinhole in a typical holographic display).
+                       The default is one, but an aperture could be as large as input field [m x n].
+    scale            : int
+                       Resolution factor to scale generated kernel.
+    samples          : list
+                       When using `Impulse Response Fresnel` propagation, these sample counts along X and Y will be used to represent a rectangular aperture. The first two are for a hologram plane pixel and the last two are for an image plane pixel.
+
+    Returns
+    -------
+    result           : torch.complex
+                       Final complex field (MxN).
+
+    """
+    H = get_propagation_kernel(
+                               nu = field.shape[-2], 
+                               nv = field.shape[-1], 
+                               dx = dx, 
+                               wavelength = wavelength, 
+                               distance = distance, 
+                               propagation_type = 'Seperable Impulse Response Fresnel',
+                               device = field.device,
+                               scale = scale,
+                               samples = samples
+                              )
+    if scale > 1:
+        field_amplitude = calculate_amplitude(field)
+        field_phase = calculate_phase(field)
+        field_scale_amplitude = torch.zeros(field.shape[-2] * scale, field.shape[-1] * scale, device = field.device)
+        field_scale_phase = torch.zeros_like(field_scale_amplitude)
+        field_scale_amplitude[::scale, ::scale] = field_amplitude
+        field_scale_phase[::scale, ::scale] = field_phase
+        field_scale = generate_complex_field(field_scale_amplitude, field_scale_phase)
+    else:
+        field_scale = field
+    result = custom(field_scale, H, zero_padding = zero_padding, aperture = aperture)
+    return result
shift_w_double_phase(phase, depth_shift, pixel_pitch, wavelength, propagation_type='Transfer Function Fresnel', kernel_length=4, sigma=0.5, amplitude=None)
Shift a phase-only hologram by propagating the complex hologram and double phase principle. Coded following the implementation at https://github.com/liangs111/tensor_holography/blob/6fdb26561a4e554136c579fa57788bb5fc3cac62/optics.py#L131-L207 and Shi, L., Li, B., Kim, C., Kellnhofer, P., & Matusik, W. (2021). Towards real-time photorealistic 3D holography with deep neural networks. Nature, 591(7849), 234-239.

+ + +

Parameters:

+
    +
  • + phase + – +
    +
               Phase value of a phase-only hologram.
    +
    +
    +
  • +
  • + depth_shift + – +
    +
               Distance in meters.
    +
    +
    +
  • +
  • + pixel_pitch + – +
    +
               Pixel pitch size in meters.
    +
    +
    +
  • +
  • + wavelength + – +
    +
               Wavelength of light.
    +
    +
    +
  • +
  • + propagation_type + (str, default: + 'Transfer Function Fresnel' +) + – +
    +
               Beam propagation type. For more see odak.learn.wave.propagate_beam().
    +
    +
    +
  • +
  • + kernel_length + – +
    +
               Kernel length for the Gaussian blur kernel.
    +
    +
    +
  • +
  • + sigma + – +
    +
               Standard deviation for the Gaussian blur kernel.
    +
    +
    +
  • +
  • + amplitude + – +
    +
               Amplitude value of a complex hologram.
    +
    +
    +
  • +
+ +
+ Source code in odak/learn/wave/classical.py +
def shift_w_double_phase(phase, depth_shift, pixel_pitch, wavelength, propagation_type='Transfer Function Fresnel', kernel_length=4, sigma=0.5, amplitude=None):
+    """
+    Shift a phase-only hologram by propagating the complex hologram and double phase principle. Coded following the implementation [here](https://github.com/liangs111/tensor_holography/blob/6fdb26561a4e554136c579fa57788bb5fc3cac62/optics.py#L131-L207) and Shi, L., Li, B., Kim, C., Kellnhofer, P., & Matusik, W. (2021). Towards real-time photorealistic 3D holography with deep neural networks. Nature, 591(7849), 234-239.
+
+    Parameters
+    ----------
+    phase            : torch.tensor
+                       Phase value of a phase-only hologram.
+    depth_shift      : float
+                       Distance in meters.
+    pixel_pitch      : float
+                       Pixel pitch size in meters.
+    wavelength       : float
+                       Wavelength of light.
+    propagation_type : str
+                       Beam propagation type. For more see odak.learn.wave.propagate_beam().
+    kernel_length    : int
+                       Kernel length for the Gaussian blur kernel.
+    sigma            : float
+                       Standard deviation for the Gaussian blur kernel.
+    amplitude        : torch.tensor
+                       Amplitude value of a complex hologram.
+    """
+    if type(amplitude) == type(None):
+        amplitude = torch.ones_like(phase)
+    hologram = generate_complex_field(amplitude, phase)
+    k = wavenumber(wavelength)
+    hologram_padded = zero_pad(hologram)
+    shifted_field_padded = propagate_beam(
+                                          hologram_padded,
+                                          k,
+                                          depth_shift,
+                                          pixel_pitch,
+                                          wavelength,
+                                          propagation_type
+                                         )
+    shifted_field = crop_center(shifted_field_padded)
+    phase_shift = torch.exp(torch.tensor([-2 * torch.pi * depth_shift / wavelength]).to(phase.device))
+    shift = torch.cos(phase_shift) + 1j * torch.sin(phase_shift)
+    shifted_complex_hologram = shifted_field * shift
+
+    if kernel_length > 0 and sigma >0:
+        blur_kernel = generate_2d_gaussian(
+                                           [kernel_length, kernel_length],
+                                           [sigma, sigma]
+                                          ).to(phase.device)
+        blur_kernel = blur_kernel.unsqueeze(0)
+        blur_kernel = blur_kernel.unsqueeze(0)
+        field_imag = torch.imag(shifted_complex_hologram)
+        field_real = torch.real(shifted_complex_hologram)
+        field_imag = field_imag.unsqueeze(0)
+        field_imag = field_imag.unsqueeze(0)
+        field_real = field_real.unsqueeze(0)
+        field_real = field_real.unsqueeze(0)
+        field_imag = torch.nn.functional.conv2d(field_imag, blur_kernel, padding='same')
+        field_real = torch.nn.functional.conv2d(field_real, blur_kernel, padding='same')
+        shifted_complex_hologram = torch.complex(field_real, field_imag)
+        shifted_complex_hologram = shifted_complex_hologram.squeeze(0)
+        shifted_complex_hologram = shifted_complex_hologram.squeeze(0)
+
+    shifted_amplitude = calculate_amplitude(shifted_complex_hologram)
+    shifted_amplitude = shifted_amplitude / torch.amax(shifted_amplitude, [0,1])
+
+    shifted_phase = calculate_phase(shifted_complex_hologram)
+    phase_zero_mean = shifted_phase - torch.mean(shifted_phase)
+
+    phase_offset = torch.arccos(shifted_amplitude)
+    phase_low = phase_zero_mean - phase_offset
+    phase_high = phase_zero_mean + phase_offset
+
+    phase_only = torch.zeros_like(phase)
+    phase_only[0::2, 0::2] = phase_low[0::2, 0::2]
+    phase_only[0::2, 1::2] = phase_high[0::2, 1::2]
+    phase_only[1::2, 0::2] = phase_high[1::2, 0::2]
+    phase_only[1::2, 1::2] = phase_low[1::2, 1::2]
+    return phase_only
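An illustrative usage sketch, assuming shift_w_double_phase is importable from odak.learn.wave; the input phase pattern, shift distance, pixel pitch and wavelength are arbitrary.

import torch
from odak.learn.wave import shift_w_double_phase

phase = torch.rand(512, 512) * 2 * torch.pi
phase_only = shift_w_double_phase(
                                  phase,
                                  depth_shift = 1e-3,
                                  pixel_pitch = 8e-6,
                                  wavelength = 515e-9
                                 )
# phase_only is the double phase encoded [512 x 512] phase pattern.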
stochastic_gradient_descent(target, wavelength, distance, pixel_pitch, propagation_type='Bandlimited Angular Spectrum', n_iteration=100, loss_function=None, learning_rate=0.1)
Definition to generate phase and reconstruction from target image via stochastic gradient descent.

+ + +

Parameters:

+
    +
  • + target + – +
    +
                        Target field amplitude [m x n].
    +                    Keep the target values between zero and one.
    +
    +
    +
  • +
  • + wavelength + – +
    +
                        Set if the converted array requires gradient.
    +
    +
    +
  • +
  • + distance + – +
    +
                        Hologram plane distance wrt SLM plane.
    +
    +
    +
  • +
  • + pixel_pitch + – +
    +
                        SLM pixel pitch in meters.
    +
    +
    +
  • +
  • + propagation_type + – +
    +
                        Type of the propagation (see odak.learn.wave.propagate_beam()).
    +
    +
    +
  • +
  • + n_iteration + – +
    +
                        Number of iteration.
    +
    +
    +
  • +
  • + loss_function + – +
    +
                        If none it is set to be l2 loss.
    +
    +
    +
  • +
  • + learning_rate + – +
    +
                        Learning rate.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +hologram ( Tensor +) – +
    +

    Phase only hologram as torch array

    +
    +
  • +
  • +reconstruction_intensity ( Tensor +) – +
    +

    Reconstruction as torch array

    +
    +
  • +
+ +
+ Source code in odak/learn/wave/classical.py +
def stochastic_gradient_descent(target, wavelength, distance, pixel_pitch, propagation_type = 'Bandlimited Angular Spectrum', n_iteration = 100, loss_function = None, learning_rate = 0.1):
+    """
+    Definition to generate phase and reconstruction from target image via stochastic gradient descent.
+
+    Parameters
+    ----------
+    target                    : torch.Tensor
+                                Target field amplitude [m x n].
+                                Keep the target values between zero and one.
+    wavelength                : double
+                                Wavelength of the electric field.
+    distance                  : double
+                                Hologram plane distance wrt SLM plane.
+    pixel_pitch               : float
+                                SLM pixel pitch in meters.
+    propagation_type          : str
+                                Type of the propagation (see odak.learn.wave.propagate_beam()).
+    n_iteration               : int
+                                Number of iterations.
+    loss_function             : function
+                                If None, it is set to be the L2 loss.
+    learning_rate             : float
+                                Learning rate.
+
+    Returns
+    -------
+    hologram                  : torch.Tensor
+                                Phase only hologram as torch array
+
+    reconstruction            : torch.Tensor
+                                Reconstruction as torch array
+
+    """
+    phase = torch.randn_like(target, requires_grad = True)
+    k = wavenumber(wavelength)
+    optimizer = torch.optim.Adam([phase], lr = learning_rate)
+    if type(loss_function) == type(None):
+        loss_function = torch.nn.MSELoss()
+    t = tqdm(range(n_iteration), leave = False, dynamic_ncols = True)
+    for i in t:
+        optimizer.zero_grad()
+        hologram = generate_complex_field(1., phase)
+        reconstruction = propagate_beam(
+                                        hologram, 
+                                        k, 
+                                        distance, 
+                                        pixel_pitch, 
+                                        wavelength, 
+                                        propagation_type, 
+                                        zero_padding = [True, False, True]
+                                       )
+        reconstruction_intensity = calculate_amplitude(reconstruction) ** 2
+        loss = loss_function(reconstruction_intensity, target)
+        description = "Loss:{:.4f}".format(loss.item())
+        loss.backward(retain_graph = True)
+        optimizer.step()
+        t.set_description(description)
+    logging.warning(description)
+    with torch.no_grad():
+        hologram = generate_complex_field(1., phase)
+        reconstruction = propagate_beam(
+                                        hologram,
+                                        k,
+                                        distance,
+                                        pixel_pitch,
+                                        wavelength,
+                                        propagation_type,
+                                        zero_padding = [True, False, True]
+                                       )
+    return hologram, reconstruction
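A compact usage sketch, illustrative only, assuming stochastic_gradient_descent is importable from odak.learn.wave; the target, wavelength, distance, pixel pitch and iteration count are arbitrary.

import torch
from odak.learn.wave import stochastic_gradient_descent

target = torch.rand(512, 512)  # target intensity between zero and one
hologram, reconstruction = stochastic_gradient_descent(
                                                       target,
                                                       wavelength = 515e-9,
                                                       distance = 5e-3,
                                                       pixel_pitch = 8e-6,
                                                       n_iteration = 60,
                                                       learning_rate = 0.1
                                                      )
phase_pattern = hologram.angle()                      # phase-only hologram to display
reconstruction_intensity = reconstruction.abs() ** 2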
transfer_function_fresnel(field, k, distance, dx, wavelength, zero_padding=False, aperture=1.0)
A definition to calculate convolution based Fresnel approximation for beam propagation.

+ + +

Parameters:

+
    +
  • + field + – +
    +
               Complex field (MxN).
    +
    +
    +
  • +
  • + k + – +
    +
               Wave number of a wave, see odak.wave.wavenumber for more.
    +
    +
    +
  • +
  • + distance + – +
    +
               Propagation distance.
    +
    +
    +
  • +
  • + dx + – +
    +
               Size of one single pixel in the field grid (in meters).
    +
    +
    +
  • +
  • + wavelength + – +
    +
               Wavelength of the electric field.
    +
    +
    +
  • +
  • + zero_padding + – +
    +
               Zero pad in Fourier domain.
    +
    +
    +
  • +
  • + aperture + – +
    +
               Fourier domain aperture (e.g., pinhole in a typical holographic display).
    +           The default is one, but an aperture could be as large as input field [m x n].
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( complex +) – +
    +

    Final complex field (MxN).

    +
    +
  • +
+ +
+ Source code in odak/learn/wave/classical.py +
def transfer_function_fresnel(field, k, distance, dx, wavelength, zero_padding = False, aperture = 1.):
+    """
+    A definition to calculate convolution based Fresnel approximation for beam propagation.
+
+    Parameters
+    ----------
+    field            : torch.complex
+                       Complex field (MxN).
+    k                : odak.wave.wavenumber
+                       Wave number of a wave, see odak.wave.wavenumber for more.
+    distance         : float
+                       Propagation distance.
+    dx               : float
+                       Size of one single pixel in the field grid (in meters).
+    wavelength       : float
+                       Wavelength of the electric field.
+    zero_padding     : bool
+                       Zero pad in Fourier domain.
+    aperture         : torch.tensor
+                       Fourier domain aperture (e.g., pinhole in a typical holographic display).
+                       The default is one, but an aperture could be as large as the input field [m x n].
+
+
+    Returns
+    -------
+    result           : torch.complex
+                       Final complex field (MxN).
+
+    """
+    H = get_propagation_kernel(
+                               nu = field.shape[-2], 
+                               nv = field.shape[-1], 
+                               dx = dx, 
+                               wavelength = wavelength, 
+                               distance = distance, 
+                               propagation_type = 'Transfer Function Fresnel',
+                               device = field.device
+                              )
+    result = custom(field, H, zero_padding = zero_padding, aperture = aperture)
+    return result
+
+
+
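A minimal usage sketch (assuming transfer_function_fresnel is importable from odak.learn.wave, as the source path above suggests; the random phase-only field and the numbers below are placeholders):

import torch
from odak.learn.wave import transfer_function_fresnel

wavelength = 515e-9                                           # wavelength in meters
dx = 8e-6                                                     # pixel pitch in meters
distance = 1e-1                                               # propagation distance in meters
k = 2 * torch.pi / wavelength                                 # wave number
field = torch.exp(1j * 2 * torch.pi * torch.rand(512, 512))   # placeholder phase-only field
result = transfer_function_fresnel(field, k, distance, dx, wavelength, zero_padding = False)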
+ +
+ + + +
+ +
+ +
+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ blazed_grating(nx, ny, levels=2, axis='x') + +

+ + +
+ +

A definition to generate a blazed grating (also known as a ramp grating). For more, consult de Blas, Mario García, et al. "High resolution 2D beam steerer made from cascaded 1D liquid crystal phase gratings." Scientific Reports 12.1 (2022): 5145 and Igasaki, Yasunori, et al. "High efficiency electrically-addressable phase-only spatial light modulator." Optical Review 6 (1999): 339-344.

+ + +

Parameters:

+
    +
  • + nx + – +
    +
           Size of the output along X.
    +
    +
    +
  • +
  • + ny + – +
    +
           Size of the output along Y.
    +
    +
    +
  • +
  • + levels + – +
    +
           Number of phase levels in the grating.
    +
    +
    +
  • +
  • + axis + – +
    +
           Axis of the blazed grating. It could be `x` or `y`.
    +
    +
    +
  • +
+ +
+ Source code in odak/learn/wave/lens.py +
def blazed_grating(nx, ny, levels = 2, axis = 'x'):
+    """
+    A definition to generate a blazed grating (also known as a ramp grating). For more, consult de Blas, Mario García, et al. "High resolution 2D beam steerer made from cascaded 1D liquid crystal phase gratings." Scientific Reports 12.1 (2022): 5145 and Igasaki, Yasunori, et al. "High efficiency electrically-addressable phase-only spatial light modulator." Optical Review 6 (1999): 339-344.
+
+
+    Parameters
+    ----------
+    nx           : int
+                   Size of the output along X.
+    ny           : int
+                   Size of the output along Y.
+    levels       : int
+                   Number of phase levels in the grating.
+    axis         : str
+                   Axis of the blazed grating. It could be `x` or `y`.
+
+    """
+    if levels < 2:
+        levels = 2
+    x = (torch.abs(torch.arange(-nx, 0)) % levels) / levels * (2 * np.pi)
+    y = (torch.abs(torch.arange(-ny, 0)) % levels) / levels * (2 * np.pi)
+    X, Y = torch.meshgrid(x, y, indexing='ij')
+    if axis == 'x':
+        blazed_grating = torch.exp(1j * X)
+    elif axis == 'y':
+        blazed_grating = torch.exp(1j * Y)
+    return blazed_grating
+
+
+
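A minimal usage sketch (assuming blazed_grating is importable from odak.learn.wave; the sizes below are placeholders):

import torch
from odak.learn.wave import blazed_grating

grating = blazed_grating(nx = 512, ny = 512, levels = 8, axis = 'x')  # unit-amplitude complex field
phase = torch.angle(grating)                                          # wrapped ramp phase along x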
+ +
+ +
+ + +

+ linear_grating(nx, ny, every=2, add=None, axis='x') + +

+ + +
+ +

A definition to generate a linear grating. This could also be interpreted as a two-level blazed grating. For more on blazed gratings, see the odak.learn.wave.blazed_grating() function.

+ + +

Parameters:

+
    +
  • + nx + – +
    +
         Size of the output along X.
    +
    +
    +
  • +
  • + ny + – +
    +
         Size of the output along Y.
    +
    +
    +
  • +
  • + every + – +
    +
         Add the `add` value at every given number of pixels.
    +
    +
    +
  • +
  • + add + – +
    +
         Angle to be added.
    +
    +
    +
  • +
  • + axis + – +
    +
         Axis: either `x`, `y`, or both (`xy`).
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +field ( tensor +) – +
    +

    Linear grating term.

    +
    +
  • +
+ +
+ Source code in odak/learn/wave/lens.py +
def linear_grating(nx, ny, every = 2, add = None, axis = 'x'):
+    """
+    A definition to generate a linear grating. This could also be interpreted as a two-level blazed grating. For more on blazed gratings, see the odak.learn.wave.blazed_grating() function.
+
+    Parameters
+    ----------
+    nx         : int
+                 Size of the output along X.
+    ny         : int
+                 Size of the output along Y.
+    every      : int
+                 Add the `add` value at every given number of pixels.
+    add        : float
+                 Angle to be added.
+    axis       : string
+                 Axis: either `x`, `y`, or both (`xy`).
+
+    Returns
+    ----------
+    field      : torch.tensor
+                 Linear grating term.
+    """
+    if isinstance(add, type(None)):
+        add = np.pi
+    grating = torch.zeros((nx, ny), dtype=torch.complex64)
+    if axis == 'x':
+        grating[::every, :] = torch.exp(torch.tensor(1j*add))
+    if axis == 'y':
+        grating[:, ::every] = torch.exp(torch.tensor(1j*add))
+    if axis == 'xy':
+        checker = np.indices((nx, ny)).sum(axis=0) % every
+        checker = torch.from_numpy(checker)
+        checker += 1
+        checker = checker % 2
+        grating = torch.exp(1j*checker*add)
+    return grating
+
+
+
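A minimal usage sketch (assuming linear_grating is importable from odak.learn.wave; the values below are placeholders):

import torch
from odak.learn.wave import linear_grating

grating = linear_grating(nx = 512, ny = 512, every = 2, add = torch.pi, axis = 'xy')  # 0 / pi checkerboard phase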
+ +
+ +
+ + +

+ prism_grating(nx, ny, k, angle, dx=0.001, axis='x', phase_offset=0.0) + +

+ + +
+ +

A definition to generate a 2D phase function that represents a prism. For more, see Goodman's Introduction to Fourier Optics or Engström, David, et al. "Improved beam steering accuracy of a single beam with a 1D phase-only spatial light modulator." Optics Express 16.22 (2008): 18275-18287.

+ + +

Parameters:

+
    +
  • + nx + – +
    +
           Size of the output along X.
    +
    +
    +
  • +
  • + ny + – +
    +
           Size of the output along Y.
    +
    +
    +
  • +
  • + k + – +
    +
           See odak.wave.wavenumber for more.
    +
    +
    +
  • +
  • + angle + – +
    +
           Tilt angle of the prism in degrees.
    +
    +
    +
  • +
  • + dx + – +
    +
           Pixel pitch.
    +
    +
    +
  • +
  • + axis + – +
    +
           Axis of the prism.
    +
    +
    +
  • +
  • + phase_offset + (float, default: + 0.0 +) + – +
    +
           Phase offset in degrees. Default is zero.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +prism ( tensor +) – +
    +

    Generated phase function for a prism.

    +
    +
  • +
+ +
+ Source code in odak/learn/wave/lens.py +
def prism_grating(nx, ny, k, angle, dx = 0.001, axis = 'x', phase_offset = 0.):
+    """
+    A definition to generate a 2D phase function that represents a prism. For more, see Goodman's Introduction to Fourier Optics or Engström, David, et al. "Improved beam steering accuracy of a single beam with a 1D phase-only spatial light modulator." Optics Express 16.22 (2008): 18275-18287.
+
+    Parameters
+    ----------
+    nx           : int
+                   Size of the output along X.
+    ny           : int
+                   Size of the output along Y.
+    k            : odak.wave.wavenumber
+                   See odak.wave.wavenumber for more.
+    angle        : float
+                   Tilt angle of the prism in degrees.
+    dx           : float
+                   Pixel pitch.
+    axis         : str
+                   Axis of the prism.
+    phase_offset : float
+                   Phase offset in degrees. Default is zero.
+
+    Returns
+    ----------
+    prism        : torch.tensor
+                   Generated phase function for a prism.
+    """
+    angle = torch.deg2rad(torch.tensor([angle]))
+    phase_offset = torch.deg2rad(torch.tensor([phase_offset]))
+    x = torch.arange(0, nx) * dx
+    y = torch.arange(0, ny) * dx
+    X, Y = torch.meshgrid(x, y, indexing='ij')
+    if axis == 'y':
+        phase = k * torch.sin(angle) * Y + phase_offset
+        prism = torch.exp(-1j * phase)
+    elif axis == 'x':
+        phase = k * torch.sin(angle) * X + phase_offset
+        prism = torch.exp(-1j * phase)
+    return prism
+
+
+
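A minimal usage sketch (assuming prism_grating is importable from odak.learn.wave; the numbers below are placeholders):

import torch
from odak.learn.wave import prism_grating

wavelength = 515e-9
k = 2 * torch.pi / wavelength
prism = prism_grating(nx = 512, ny = 512, k = k, angle = 0.1, dx = 8e-6, axis = 'x')
# Multiplying a complex field with `prism` steers the beam by roughly 0.1 degrees along x.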
+ +
+ +
+ + +

+ quadratic_phase_function(nx, ny, k, focal=0.4, dx=0.001, offset=[0, 0]) + +

+ + +
+ +

A definition to generate a 2D quadratic phase function, which is typically used to represent lenses.

+ + +

Parameters:

+
    +
  • + nx + – +
    +
         Size of the output along X.
    +
    +
    +
  • +
  • + ny + – +
    +
         Size of the output along Y.
    +
    +
    +
  • +
  • + k + – +
    +
         See odak.wave.wavenumber for more.
    +
    +
    +
  • +
  • + focal + – +
    +
         Focal length of the quadratic phase function.
    +
    +
    +
  • +
  • + dx + – +
    +
         Pixel pitch.
    +
    +
    +
  • +
  • + offset + – +
    +
         Deviation from the center along X and Y axes.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +function ( tensor +) – +
    +

    Generated quadratic phase function.

    +
    +
  • +
+ +
+ Source code in odak/learn/wave/lens.py +
def quadratic_phase_function(nx, ny, k, focal=0.4, dx=0.001, offset=[0, 0]):
+    """ 
+    A definition to generate a 2D quadratic phase function, which is typically used to represent lenses.
+
+    Parameters
+    ----------
+    nx         : int
+                 Size of the output along X.
+    ny         : int
+                 Size of the output along Y.
+    k          : odak.wave.wavenumber
+                 See odak.wave.wavenumber for more.
+    focal      : float
+                 Focal length of the quadratic phase function.
+    dx         : float
+                 Pixel pitch.
+    offset     : list
+                 Deviation from the center along X and Y axes.
+
+    Returns
+    -------
+    function   : torch.tensor
+                 Generated quadratic phase function.
+    """
+    size = [nx, ny]
+    x = torch.linspace(-size[0] * dx / 2, size[0] * dx / 2, size[0]) - offset[1] * dx
+    y = torch.linspace(-size[1] * dx / 2, size[1] * dx / 2, size[1]) - offset[0] * dx
+    X, Y = torch.meshgrid(x, y, indexing='ij')
+    Z = X**2 + Y**2
+    qwf = torch.exp(-0.5j * k / focal * Z)
+    return qwf
+
+
+
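A minimal usage sketch (assuming quadratic_phase_function is importable from odak.learn.wave; the numbers below are placeholders):

import torch
from odak.learn.wave import quadratic_phase_function

wavelength = 515e-9
k = 2 * torch.pi / wavelength
lens = quadratic_phase_function(nx = 512, ny = 512, k = k, focal = 0.3, dx = 8e-6)
# Multiplying a complex field with `lens` mimics passing it through a thin lens with a 0.3 m focal length.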
+ +
+ + + +
+ +
+ +
+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ multiplane_loss + + +

+ + +
+ + +

Loss function for computing loss in multiplanar images. Unlike previous methods, this loss function accounts for defocused parts of an image.

+ + + + + + +
+ Source code in odak/learn/wave/loss.py +
class multiplane_loss():
+    """
+    Loss function for computing loss in multiplanar images. Unlike previous methods, this loss function accounts for defocused parts of an image.
+    """
+
+    def __init__(self, target_image, target_depth, blur_ratio = 0.25, 
+                 target_blur_size = 10, number_of_planes = 4, weights = [1., 2.1, 0.6], 
+                 multiplier = 1., scheme = 'defocus', reduction = 'mean', device = torch.device('cpu')):
+        """
+        Parameters
+        ----------
+        target_image      : torch.tensor
+                            Color target image [3 x m x n].
+        target_depth      : torch.tensor
+                            Monochrome target depth, same resolution as target_image.
+        target_blur_size  : int
+                            Maximum target blur size.
+        blur_ratio        : float
+                            Blur ratio, a value between zero and one.
+        number_of_planes  : int
+                            Number of planes.
+        weights           : list
+                            Weights of the loss function.
+        multiplier        : float
+                            Multiplier to multiply with targets.
+        scheme            : str
+                            The type of the loss, `naive` without defocus or `defocus` with defocus.
+        reduction         : str
+                            Reduction can either be 'mean', 'none' or 'sum'. For more see: https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss
+        device            : torch.device
+                            Device to be used (e.g., cuda, cpu, opencl).
+        """
+        self.device = device
+        self.target_image     = target_image.float().to(self.device)
+        self.target_depth     = target_depth.float().to(self.device)
+        self.target_blur_size = target_blur_size
+        if self.target_blur_size % 2 == 0:
+            self.target_blur_size += 1
+        self.number_of_planes = number_of_planes
+        self.multiplier       = multiplier
+        self.weights          = weights
+        self.reduction        = reduction
+        self.blur_ratio       = blur_ratio
+        self.set_targets()
+        if scheme == 'defocus':
+            self.add_defocus_blur()
+        self.loss_function = torch.nn.MSELoss(reduction = self.reduction)
+
+    def get_targets(self):
+        """
+        Returns
+        -------
+        targets           : torch.tensor
+                            Returns a copy of the targets.
+        target_depth      : torch.tensor
+                            Returns a copy of the normalized quantized depth map.
+
+        """
+        divider = self.number_of_planes - 1
+        if divider == 0:
+            divider = 1
+        return self.targets.detach().clone(), self.focus_target.detach().clone(), self.target_depth.detach().clone() / divider
+
+
+    def set_targets(self):
+        """
+        Internal function for slicing the depth into planes without considering defocus. Users can query the results with get_targets() within the same class.
+        """
+        self.target_depth = self.target_depth * (self.number_of_planes - 1)
+        self.target_depth = torch.round(self.target_depth, decimals = 0)
+        self.targets      = torch.zeros(
+                                        self.number_of_planes,
+                                        self.target_image.shape[0],
+                                        self.target_image.shape[1],
+                                        self.target_image.shape[2],
+                                        requires_grad = False,
+                                        device = self.device
+                                       )
+        self.focus_target = torch.zeros_like(self.target_image, requires_grad = False)
+        self.masks        = torch.zeros_like(self.targets)
+        for i in range(self.number_of_planes):
+            for ch in range(self.target_image.shape[0]):
+                mask_zeros = torch.zeros_like(self.target_image[ch], dtype = torch.int)
+                mask_ones = torch.ones_like(self.target_image[ch], dtype = torch.int)
+                mask = torch.where(self.target_depth == i, mask_ones, mask_zeros)
+                new_target = self.target_image[ch] * mask
+                self.focus_target = self.focus_target + new_target.squeeze(0).squeeze(0).detach().clone()
+                self.targets[i, ch] = new_target.squeeze(0).squeeze(0)
+                self.masks[i, ch] = mask.detach().clone() 
+
+
+    def add_defocus_blur(self):
+        """
+        Internal function for adding defocus blur to the multiplane targets. Users can query the results with get_targets() within the same class.
+        """
+        kernel_length = [self.target_blur_size, self.target_blur_size ]
+        for ch in range(self.target_image.shape[0]):
+            targets_cache = self.targets[:, ch].detach().clone()
+            target = torch.sum(targets_cache, axis = 0)
+            for i in range(self.number_of_planes):
+                defocus = torch.zeros_like(targets_cache[i])
+                for j in range(self.number_of_planes):
+                    nsigma = [int(abs(i - j) * self.blur_ratio), int(abs(i -j) * self.blur_ratio)]
+                    if torch.sum(targets_cache[j]) > 0:
+                        if i == j:
+                            nsigma = [0., 0.]
+                        kernel = generate_2d_gaussian(kernel_length, nsigma).to(self.device)
+                        kernel = kernel / torch.sum(kernel)
+                        kernel = kernel.unsqueeze(0).unsqueeze(0)
+                        target_current = target.detach().clone().unsqueeze(0).unsqueeze(0)
+                        defocus_plane = torch.nn.functional.conv2d(target_current, kernel, padding = 'same')
+                        defocus_plane = defocus_plane.view(defocus_plane.shape[-2], defocus_plane.shape[-1])
+                        defocus = defocus + defocus_plane * torch.abs(self.masks[j, ch])
+                self.targets[i, ch] = defocus
+        self.targets = self.targets.detach().clone() * self.multiplier
+
+
+    def __call__(self, image, target, plane_id = None):
+        """
+        Calculates the multiplane loss against a given target.
+
+        Parameters
+        ----------
+        image         : torch.tensor
+                        Image to compare with a target [3 x m x n].
+        target        : torch.tensor
+                        Target image for comparison [3 x m x n].
+        plane_id      : int
+                        Number of the plane under test.
+
+        Returns
+        -------
+        loss          : torch.tensor
+                        Computed loss.
+        """
+        l2 = self.weights[0] * self.loss_function(image, target)
+        if isinstance(plane_id, type(None)):
+            mask = self.masks
+        else:
+            mask= self.masks[plane_id, :]
+        l2_mask = self.weights[1] * self.loss_function(image * mask, target * mask)
+        l2_cor = self.weights[2] * self.loss_function(image * target, target * target)
+        loss = l2 + l2_mask + l2_cor
+        return loss
+
+
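A minimal usage sketch (assuming multiplane_loss is importable from odak.learn.wave; the random target image, depth map, and reconstruction below are placeholders):

import torch
from odak.learn.wave import multiplane_loss

target_image = torch.rand(3, 256, 256)                # color target [3 x m x n]
target_depth = torch.rand(256, 256)                   # normalized depth map, same resolution
loss_function = multiplane_loss(target_image, target_depth, number_of_planes = 4, scheme = 'defocus')
targets, focus_target, depth = loss_function.get_targets()
reconstruction = torch.rand_like(targets[0])          # placeholder reconstruction for plane 0
loss = loss_function(reconstruction, targets[0], plane_id = 0)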
+ + + +
+ + + + + + + + + +
+ + +

+ __call__(image, target, plane_id=None) + +

+ + +
+ +

Calculates the multiplane loss against a given target.

+ + +

Parameters:

+
    +
  • + image + – +
    +
            Image to compare with a target [3 x m x n].
    +
    +
    +
  • +
  • + target + – +
    +
            Target image for comparison [3 x m x n].
    +
    +
    +
  • +
  • + plane_id + – +
    +
            Number of the plane under test.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +loss ( tensor +) – +
    +

    Computed loss.

    +
    +
  • +
+ +
+ Source code in odak/learn/wave/loss.py +
def __call__(self, image, target, plane_id = None):
+    """
+    Calculates the multiplane loss against a given target.
+
+    Parameters
+    ----------
+    image         : torch.tensor
+                    Image to compare with a target [3 x m x n].
+    target        : torch.tensor
+                    Target image for comparison [3 x m x n].
+    plane_id      : int
+                    Number of the plane under test.
+
+    Returns
+    -------
+    loss          : torch.tensor
+                    Computed loss.
+    """
+    l2 = self.weights[0] * self.loss_function(image, target)
+    if isinstance(plane_id, type(None)):
+        mask = self.masks
+    else:
+        mask= self.masks[plane_id, :]
+    l2_mask = self.weights[1] * self.loss_function(image * mask, target * mask)
+    l2_cor = self.weights[2] * self.loss_function(image * target, target * target)
+    loss = l2 + l2_mask + l2_cor
+    return loss
+
+
+
+ +
+ +
+ + +

+ __init__(target_image, target_depth, blur_ratio=0.25, target_blur_size=10, number_of_planes=4, weights=[1.0, 2.1, 0.6], multiplier=1.0, scheme='defocus', reduction='mean', device=torch.device('cpu')) + +

+ + +
+ + + +

Parameters:

+
    +
  • + target_image + – +
    +
                Color target image [3 x m x n].
    +
    +
    +
  • +
  • + target_depth + – +
    +
                Monochrome target depth, same resolution as target_image.
    +
    +
    +
  • +
  • + target_blur_size + – +
    +
                Maximum target blur size.
    +
    +
    +
  • +
  • + blur_ratio + – +
    +
                Blur ratio, a value between zero and one.
    +
    +
    +
  • +
  • + number_of_planes + – +
    +
                Number of planes.
    +
    +
    +
  • +
  • + weights + – +
    +
                Weights of the loss function.
    +
    +
    +
  • +
  • + multiplier + – +
    +
                Multiplier to multiply with targets.
    +
    +
    +
  • +
  • + scheme + – +
    +
                The type of the loss, `naive` without defocus or `defocus` with defocus.
    +
    +
    +
  • +
  • + reduction + – +
    +
                Reduction can either be 'mean', 'none' or 'sum'. For more see: https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss
    +
    +
    +
  • +
  • + device + – +
    +
                Device to be used (e.g., cuda, cpu, opencl).
    +
    +
    +
  • +
+ +
+ Source code in odak/learn/wave/loss.py +
def __init__(self, target_image, target_depth, blur_ratio = 0.25, 
+             target_blur_size = 10, number_of_planes = 4, weights = [1., 2.1, 0.6], 
+             multiplier = 1., scheme = 'defocus', reduction = 'mean', device = torch.device('cpu')):
+    """
+    Parameters
+    ----------
+    target_image      : torch.tensor
+                        Color target image [3 x m x n].
+    target_depth      : torch.tensor
+                        Monochrome target depth, same resolution as target_image.
+    target_blur_size  : int
+                        Maximum target blur size.
+    blur_ratio        : float
+                        Blur ratio, a value between zero and one.
+    number_of_planes  : int
+                        Number of planes.
+    weights           : list
+                        Weights of the loss function.
+    multiplier        : float
+                        Multiplier to multiply with targets.
+    scheme            : str
+                        The type of the loss, `naive` without defocus or `defocus` with defocus.
+    reduction         : str
+                        Reduction can either be 'mean', 'none' or 'sum'. For more see: https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss
+    device            : torch.device
+                        Device to be used (e.g., cuda, cpu, opencl).
+    """
+    self.device = device
+    self.target_image     = target_image.float().to(self.device)
+    self.target_depth     = target_depth.float().to(self.device)
+    self.target_blur_size = target_blur_size
+    if self.target_blur_size % 2 == 0:
+        self.target_blur_size += 1
+    self.number_of_planes = number_of_planes
+    self.multiplier       = multiplier
+    self.weights          = weights
+    self.reduction        = reduction
+    self.blur_ratio       = blur_ratio
+    self.set_targets()
+    if scheme == 'defocus':
+        self.add_defocus_blur()
+    self.loss_function = torch.nn.MSELoss(reduction = self.reduction)
+
+
+
+ +
+ +
+ + +

+ add_defocus_blur() + +

+ + +
+ +

Internal function for adding defocus blur to the multiplane targets. Users can query the results with get_targets() within the same class.

+ +
+ Source code in odak/learn/wave/loss.py +
def add_defocus_blur(self):
+    """
+    Internal function for adding defocus blur to the multiplane targets. Users can query the results with get_targets() within the same class.
+    """
+    kernel_length = [self.target_blur_size, self.target_blur_size ]
+    for ch in range(self.target_image.shape[0]):
+        targets_cache = self.targets[:, ch].detach().clone()
+        target = torch.sum(targets_cache, axis = 0)
+        for i in range(self.number_of_planes):
+            defocus = torch.zeros_like(targets_cache[i])
+            for j in range(self.number_of_planes):
+                nsigma = [int(abs(i - j) * self.blur_ratio), int(abs(i -j) * self.blur_ratio)]
+                if torch.sum(targets_cache[j]) > 0:
+                    if i == j:
+                        nsigma = [0., 0.]
+                    kernel = generate_2d_gaussian(kernel_length, nsigma).to(self.device)
+                    kernel = kernel / torch.sum(kernel)
+                    kernel = kernel.unsqueeze(0).unsqueeze(0)
+                    target_current = target.detach().clone().unsqueeze(0).unsqueeze(0)
+                    defocus_plane = torch.nn.functional.conv2d(target_current, kernel, padding = 'same')
+                    defocus_plane = defocus_plane.view(defocus_plane.shape[-2], defocus_plane.shape[-1])
+                    defocus = defocus + defocus_plane * torch.abs(self.masks[j, ch])
+            self.targets[i, ch] = defocus
+    self.targets = self.targets.detach().clone() * self.multiplier
+
+
+
+ +
+ +
+ + +

+ get_targets() + +

+ + +
+ + + +

Returns:

+
    +
  • +targets ( tensor +) – +
    +

    Returns a copy of the targets.

    +
    +
  • +
  • +target_depth ( tensor +) – +
    +

    Returns a copy of the normalized quantized depth map.

    +
    +
  • +
+ +
+ Source code in odak/learn/wave/loss.py +
def get_targets(self):
+    """
+    Returns
+    -------
+    targets           : torch.tensor
+                        Returns a copy of the targets.
+    target_depth      : torch.tensor
+                        Returns a copy of the normalized quantized depth map.
+
+    """
+    divider = self.number_of_planes - 1
+    if divider == 0:
+        divider = 1
+    return self.targets.detach().clone(), self.focus_target.detach().clone(), self.target_depth.detach().clone() / divider
+
+
+
+ +
+ +
+ + +

+ set_targets() + +

+ + +
+ +

Internal function for slicing the depth into planes without considering defocus. Users can query the results with get_targets() within the same class.

+ +
+ Source code in odak/learn/wave/loss.py +
def set_targets(self):
+    """
+    Internal function for slicing the depth into planes without considering defocus. Users can query the results with get_targets() within the same class.
+    """
+    self.target_depth = self.target_depth * (self.number_of_planes - 1)
+    self.target_depth = torch.round(self.target_depth, decimals = 0)
+    self.targets      = torch.zeros(
+                                    self.number_of_planes,
+                                    self.target_image.shape[0],
+                                    self.target_image.shape[1],
+                                    self.target_image.shape[2],
+                                    requires_grad = False,
+                                    device = self.device
+                                   )
+    self.focus_target = torch.zeros_like(self.target_image, requires_grad = False)
+    self.masks        = torch.zeros_like(self.targets)
+    for i in range(self.number_of_planes):
+        for ch in range(self.target_image.shape[0]):
+            mask_zeros = torch.zeros_like(self.target_image[ch], dtype = torch.int)
+            mask_ones = torch.ones_like(self.target_image[ch], dtype = torch.int)
+            mask = torch.where(self.target_depth == i, mask_ones, mask_zeros)
+            new_target = self.target_image[ch] * mask
+            self.focus_target = self.focus_target + new_target.squeeze(0).squeeze(0).detach().clone()
+            self.targets[i, ch] = new_target.squeeze(0).squeeze(0)
+            self.masks[i, ch] = mask.detach().clone() 
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ perceptual_multiplane_loss + + +

+ + +
+ + +

Perceptual loss function for computing loss in multiplanar images. Unlike previous methods, this loss function accounts for defocused parts of an image.

+ + + + + + +
+ Source code in odak/learn/wave/loss.py +
class perceptual_multiplane_loss():
+    """
+    Perceptual loss function for computing loss in multiplanar images. Unlike previous methods, this loss function accounts for defocused parts of an image.
+    """
+
+    def __init__(self, target_image, target_depth, blur_ratio = 0.25, 
+                 target_blur_size = 10, number_of_planes = 4, multiplier = 1., scheme = 'defocus', 
+                 base_loss_weights = {'base_l2_loss': 1., 'loss_l2_mask': 1., 'loss_l2_cor': 1., 'base_l1_loss': 1., 'loss_l1_mask': 1., 'loss_l1_cor': 1.},
+                 additional_loss_weights = {'cvvdp': 1.}, reduction = 'mean', return_components = False, device = torch.device('cpu')):
+        """
+        Parameters
+        ----------
+        target_image            : torch.tensor
+                                    Color target image [3 x m x n].
+        target_depth            : torch.tensor
+                                    Monochrome target depth, same resolution as target_image.
+        target_blur_size        : int
+                                    Maximum target blur size.
+        blur_ratio              : float
+                                    Blur ratio, a value between zero and one.
+        number_of_planes        : int
+                                    Number of planes.
+        multiplier              : float
+                                    Multiplier to multiply with targets.
+        scheme                  : str
+                                    The type of the loss, `naive` without defocus or `defocus` with defocus.
+        base_loss_weights       : dict
+                                    Weights of the base loss functions. Default is {'base_l2_loss': 1., 'loss_l2_mask': 1., 'loss_l2_cor': 1., 'base_l1_loss': 1., 'loss_l1_mask': 1., 'loss_l1_cor': 1.}.
+        additional_loss_weights : dict
+                                    Additional loss terms and their weights (e.g., {'cvvdp': 1.}). Supported loss terms are 'cvvdp', 'fvvdp', 'lpips', 'psnr', 'ssim', 'msssim'.
+        reduction               : str
+                                    Reduction can either be 'mean', 'none' or 'sum'. For more see: https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss
+        return_components       : bool
+                                    If True (False by default), returns the components of the loss as a dict.
+        device                  : torch.device
+                                    Device to be used (e.g., cuda, cpu, opencl).
+        """
+        self.device = device
+        self.target_image     = target_image.float().to(self.device)
+        self.target_depth     = target_depth.float().to(self.device)
+        self.target_blur_size = target_blur_size
+        if self.target_blur_size % 2 == 0:
+            self.target_blur_size += 1
+        self.number_of_planes = number_of_planes
+        self.multiplier       = multiplier
+        self.reduction        = reduction
+        if self.reduction == 'none' and len(list(additional_loss_weights.keys())) > 0:
+            logging.warning("Reduction cannot be 'none' for additional loss functions. Changing reduction to 'mean'.")
+            self.reduction = 'mean'
+        self.blur_ratio       = blur_ratio
+        self.set_targets()
+        if scheme == 'defocus':
+            self.add_defocus_blur()
+        self.base_loss_weights = base_loss_weights
+        self.additional_loss_weights = additional_loss_weights
+        self.return_components = return_components
+        self.l1_loss_fn = torch.nn.L1Loss(reduction = self.reduction)
+        self.l2_loss_fn = torch.nn.MSELoss(reduction = self.reduction)
+        for key in self.additional_loss_weights.keys():
+            if key == 'cvvdp':
+                self.cvvdp = CVVDP()
+            if key == 'fvvdp':
+                self.fvvdp = FVVDP()
+            if key == 'lpips':
+                self.lpips = LPIPS()
+            if key == 'psnr':
+                self.psnr = PSNR()
+            if key == 'ssim':
+                self.ssim = SSIM()
+            if key == 'msssim':
+                self.msssim = MSSSIM()
+
+    def get_targets(self):
+        """
+        Returns
+        -------
+        targets           : torch.tensor
+                            Returns a copy of the targets.
+        target_depth      : torch.tensor
+                            Returns a copy of the normalized quantized depth map.
+
+        """
+        divider = self.number_of_planes - 1
+        if divider == 0:
+            divider = 1
+        return self.targets.detach().clone(), self.focus_target.detach().clone(), self.target_depth.detach().clone() / divider
+
+
+    def set_targets(self):
+        """
+        Internal function for slicing the depth into planes without considering defocus. Users can query the results with get_targets() within the same class.
+        """
+        self.target_depth = self.target_depth * (self.number_of_planes - 1)
+        self.target_depth = torch.round(self.target_depth, decimals = 0)
+        self.targets      = torch.zeros(
+                                        self.number_of_planes,
+                                        self.target_image.shape[0],
+                                        self.target_image.shape[1],
+                                        self.target_image.shape[2],
+                                        requires_grad = False,
+                                        device = self.device
+                                       )
+        self.focus_target = torch.zeros_like(self.target_image, requires_grad = False)
+        self.masks        = torch.zeros_like(self.targets)
+        for i in range(self.number_of_planes):
+            for ch in range(self.target_image.shape[0]):
+                mask_zeros = torch.zeros_like(self.target_image[ch], dtype = torch.int)
+                mask_ones = torch.ones_like(self.target_image[ch], dtype = torch.int)
+                mask = torch.where(self.target_depth == i, mask_ones, mask_zeros)
+                new_target = self.target_image[ch] * mask
+                self.focus_target = self.focus_target + new_target.squeeze(0).squeeze(0).detach().clone()
+                self.targets[i, ch] = new_target.squeeze(0).squeeze(0)
+                self.masks[i, ch] = mask.detach().clone() 
+
+
+    def add_defocus_blur(self):
+        """
+        Internal function for adding defocus blur to the multiplane targets. Users can query the results with get_targets() within the same class.
+        """
+        kernel_length = [self.target_blur_size, self.target_blur_size ]
+        for ch in range(self.target_image.shape[0]):
+            targets_cache = self.targets[:, ch].detach().clone()
+            target = torch.sum(targets_cache, axis = 0)
+            for i in range(self.number_of_planes):
+                defocus = torch.zeros_like(targets_cache[i])
+                for j in range(self.number_of_planes):
+                    nsigma = [int(abs(i - j) * self.blur_ratio), int(abs(i -j) * self.blur_ratio)]
+                    if torch.sum(targets_cache[j]) > 0:
+                        if i == j:
+                            nsigma = [0., 0.]
+                        kernel = generate_2d_gaussian(kernel_length, nsigma).to(self.device)
+                        kernel = kernel / torch.sum(kernel)
+                        kernel = kernel.unsqueeze(0).unsqueeze(0)
+                        target_current = target.detach().clone().unsqueeze(0).unsqueeze(0)
+                        defocus_plane = torch.nn.functional.conv2d(target_current, kernel, padding = 'same')
+                        defocus_plane = defocus_plane.view(defocus_plane.shape[-2], defocus_plane.shape[-1])
+                        defocus = defocus + defocus_plane * torch.abs(self.masks[j, ch])
+                self.targets[i, ch] = defocus
+        self.targets = self.targets.detach().clone() * self.multiplier
+
+
+    def __call__(self, image, target, plane_id = None):
+        """
+        Calculates the multiplane loss against a given target.
+
+        Parameters
+        ----------
+        image         : torch.tensor
+                        Image to compare with a target [3 x m x n].
+        target        : torch.tensor
+                        Target image for comparison [3 x m x n].
+        plane_id      : int
+                        Number of the plane under test.
+
+        Returns
+        -------
+        loss          : torch.tensor
+                        Computed loss.
+        """
+        loss_components = {}
+        if isinstance(plane_id, type(None)):
+            mask = self.masks
+        else:
+            mask= self.masks[plane_id, :]
+        l2 = self.base_loss_weights['base_l2_loss'] * self.l2_loss_fn(image, target)
+        l2_mask = self.base_loss_weights['loss_l2_mask'] * self.l2_loss_fn(image * mask, target * mask)
+        l2_cor = self.base_loss_weights['loss_l2_cor'] * self.l2_loss_fn(image * target, target * target)
+        loss_components['l2'] = l2
+        loss_components['l2_mask'] = l2_mask
+        loss_components['l2_cor'] = l2_cor
+
+        l1 = self.base_loss_weights['base_l1_loss'] * self.l1_loss_fn(image, target)
+        l1_mask = self.base_loss_weights['loss_l1_mask'] * self.l1_loss_fn(image * mask, target * mask)
+        l1_cor = self.base_loss_weights['loss_l1_cor'] * self.l1_loss_fn(image * target, target * target)
+        loss_components['l1'] = l1
+        loss_components['l1_mask'] = l1_mask
+        loss_components['l1_cor'] = l1_cor
+
+        for key in self.additional_loss_weights.keys():
+            if key == 'cvvdp':
+                loss_cvvdp = self.additional_loss_weights['cvvdp'] * self.cvvdp(image, target)
+                loss_components['cvvdp'] = loss_cvvdp
+            if key == 'fvvdp':
+                loss_fvvdp = self.additional_loss_weights['fvvdp'] * self.fvvdp(image, target)
+                loss_components['fvvdp'] = loss_fvvdp
+            if key == 'lpips':
+                loss_lpips = self.additional_loss_weights['lpips'] * self.lpips(image, target)
+                loss_components['lpips'] = loss_lpips
+            if key == 'psnr':
+                loss_psnr = self.additional_loss_weights['psnr'] * self.psnr(image, target)
+                loss_components['psnr'] = loss_psnr
+            if key == 'ssim':
+                loss_ssim = self.additional_loss_weights['ssim'] * self.ssim(image, target)
+                loss_components['ssim'] = loss_ssim
+            if key == 'msssim':
+                loss_msssim = self.additional_loss_weights['msssim'] * self.msssim(image, target)
+                loss_components['msssim'] = loss_msssim
+
+        loss = torch.sum(torch.stack(list(loss_components.values())), dim = 0)
+
+        if self.return_components:
+            return loss, loss_components
+        return loss
+
+
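A minimal usage sketch (assuming perceptual_multiplane_loss is importable from odak.learn.wave; an empty additional_loss_weights dictionary is passed here so that only the L1/L2 terms are active and no extra metric packages are required; the random inputs are placeholders):

import torch
from odak.learn.wave import perceptual_multiplane_loss

target_image = torch.rand(3, 256, 256)
target_depth = torch.rand(256, 256)
loss_function = perceptual_multiplane_loss(
                                           target_image,
                                           target_depth,
                                           number_of_planes = 4,
                                           additional_loss_weights = {},
                                           return_components = True
                                          )
targets, focus_target, depth = loss_function.get_targets()
reconstruction = torch.rand_like(targets[0])
loss, components = loss_function(reconstruction, targets[0], plane_id = 0)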
+ + + +
+ + + + + + + + + +
+ + +

+ __call__(image, target, plane_id=None) + +

+ + +
+ +

Calculates the multiplane loss against a given target.

+ + +

Parameters:

+
    +
  • + image + – +
    +
            Image to compare with a target [3 x m x n].
    +
    +
    +
  • +
  • + target + – +
    +
            Target image for comparison [3 x m x n].
    +
    +
    +
  • +
  • + plane_id + – +
    +
            Number of the plane under test.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +loss ( tensor +) – +
    +

    Computed loss.

    +
    +
  • +
+ +
+ Source code in odak/learn/wave/loss.py +
def __call__(self, image, target, plane_id = None):
+    """
+    Calculates the multiplane loss against a given target.
+
+    Parameters
+    ----------
+    image         : torch.tensor
+                    Image to compare with a target [3 x m x n].
+    target        : torch.tensor
+                    Target image for comparison [3 x m x n].
+    plane_id      : int
+                    Number of the plane under test.
+
+    Returns
+    -------
+    loss          : torch.tensor
+                    Computed loss.
+    """
+    loss_components = {}
+    if isinstance(plane_id, type(None)):
+        mask = self.masks
+    else:
+        mask= self.masks[plane_id, :]
+    l2 = self.base_loss_weights['base_l2_loss'] * self.l2_loss_fn(image, target)
+    l2_mask = self.base_loss_weights['loss_l2_mask'] * self.l2_loss_fn(image * mask, target * mask)
+    l2_cor = self.base_loss_weights['loss_l2_cor'] * self.l2_loss_fn(image * target, target * target)
+    loss_components['l2'] = l2
+    loss_components['l2_mask'] = l2_mask
+    loss_components['l2_cor'] = l2_cor
+
+    l1 = self.base_loss_weights['base_l1_loss'] * self.l1_loss_fn(image, target)
+    l1_mask = self.base_loss_weights['loss_l1_mask'] * self.l1_loss_fn(image * mask, target * mask)
+    l1_cor = self.base_loss_weights['loss_l1_cor'] * self.l1_loss_fn(image * target, target * target)
+    loss_components['l1'] = l1
+    loss_components['l1_mask'] = l1_mask
+    loss_components['l1_cor'] = l1_cor
+
+    for key in self.additional_loss_weights.keys():
+        if key == 'cvvdp':
+            loss_cvvdp = self.additional_loss_weights['cvvdp'] * self.cvvdp(image, target)
+            loss_components['cvvdp'] = loss_cvvdp
+        if key == 'fvvdp':
+            loss_fvvdp = self.additional_loss_weights['fvvdp'] * self.fvvdp(image, target)
+            loss_components['fvvdp'] = loss_fvvdp
+        if key == 'lpips':
+            loss_lpips = self.additional_loss_weights['lpips'] * self.lpips(image, target)
+            loss_components['lpips'] = loss_lpips
+        if key == 'psnr':
+            loss_psnr = self.additional_loss_weights['psnr'] * self.psnr(image, target)
+            loss_components['psnr'] = loss_psnr
+        if key == 'ssim':
+            loss_ssim = self.additional_loss_weights['ssim'] * self.ssim(image, target)
+            loss_components['ssim'] = loss_ssim
+        if key == 'msssim':
+            loss_msssim = self.additional_loss_weights['msssim'] * self.msssim(image, target)
+            loss_components['msssim'] = loss_msssim
+
+    loss = torch.sum(torch.stack(list(loss_components.values())), dim = 0)
+
+    if self.return_components:
+        return loss, loss_components
+    return loss
+
+
+
+ +
+ +
+ + +

+ __init__(target_image, target_depth, blur_ratio=0.25, target_blur_size=10, number_of_planes=4, multiplier=1.0, scheme='defocus', base_loss_weights={'base_l2_loss': 1.0, 'loss_l2_mask': 1.0, 'loss_l2_cor': 1.0, 'base_l1_loss': 1.0, 'loss_l1_mask': 1.0, 'loss_l1_cor': 1.0}, additional_loss_weights={'cvvdp': 1.0}, reduction='mean', return_components=False, device=torch.device('cpu')) + +

+ + +
+ + + +

Parameters:

+
    +
  • + target_image + – +
    +
                        Color target image [3 x m x n].
    +
    +
    +
  • +
  • + target_depth + – +
    +
                        Monochrome target depth, same resolution as target_image.
    +
    +
    +
  • +
  • + target_blur_size + – +
    +
                        Maximum target blur size.
    +
    +
    +
  • +
  • + blur_ratio + – +
    +
                        Blur ratio, a value between zero and one.
    +
    +
    +
  • +
  • + number_of_planes + – +
    +
                        Number of planes.
    +
    +
    +
  • +
  • + multiplier + – +
    +
                        Multiplier to multiply with targets.
    +
    +
    +
  • +
  • + scheme + – +
    +
                        The type of the loss, `naive` without defocus or `defocus` with defocus.
    +
    +
    +
  • +
  • + base_loss_weights + – +
    +
                        Weights of the base loss functions. Default is {'base_l2_loss': 1., 'loss_l2_mask': 1., 'loss_l2_cor': 1., 'base_l1_loss': 1., 'loss_l1_mask': 1., 'loss_l1_cor': 1.}.
    +
    +
    +
  • +
  • + additional_loss_weights + (dict, default: + {'cvvdp': 1.0} +) + – +
    +
                        Additional loss terms and their weights (e.g., {'cvvdp': 1.}). Supported loss terms are 'cvvdp', 'fvvdp', 'lpips', 'psnr', 'ssim', 'msssim'.
    +
    +
    +
  • +
  • + reduction + – +
    +
                        Reduction can either be 'mean', 'none' or 'sum'. For more see: https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss
    +
    +
    +
  • +
  • + return_components + – +
    +
                        If True (False by default), returns the components of the loss as a dict.
    +
    +
    +
  • +
  • + device + – +
    +
                        Device to be used (e.g., cuda, cpu, opencl).
    +
    +
    +
  • +
+ +
+ Source code in odak/learn/wave/loss.py +
def __init__(self, target_image, target_depth, blur_ratio = 0.25, 
+             target_blur_size = 10, number_of_planes = 4, multiplier = 1., scheme = 'defocus', 
+             base_loss_weights = {'base_l2_loss': 1., 'loss_l2_mask': 1., 'loss_l2_cor': 1., 'base_l1_loss': 1., 'loss_l1_mask': 1., 'loss_l1_cor': 1.},
+             additional_loss_weights = {'cvvdp': 1.}, reduction = 'mean', return_components = False, device = torch.device('cpu')):
+    """
+    Parameters
+    ----------
+    target_image            : torch.tensor
+                                Color target image [3 x m x n].
+    target_depth            : torch.tensor
+                                Monochrome target depth, same resolution as target_image.
+    target_blur_size        : int
+                                Maximum target blur size.
+    blur_ratio              : float
+                                Blur ratio, a value between zero and one.
+    number_of_planes        : int
+                                Number of planes.
+    multiplier              : float
+                                Multiplier to multiply with targets.
+    scheme                  : str
+                                The type of the loss, `naive` without defocus or `defocus` with defocus.
+    base_loss_weights       : dict
+                                Weights of the base loss functions. Default is {'base_l2_loss': 1., 'loss_l2_mask': 1., 'loss_l2_cor': 1., 'base_l1_loss': 1., 'loss_l1_mask': 1., 'loss_l1_cor': 1.}.
+    additional_loss_weights : dict
+                                Additional loss terms and their weights (e.g., {'cvvdp': 1.}). Supported loss terms are 'cvvdp', 'fvvdp', 'lpips', 'psnr', 'ssim', 'msssim'.
+    reduction               : str
+                                Reduction can either be 'mean', 'none' or 'sum'. For more see: https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss
+    return_components       : bool
+                                If True (False by default), returns the components of the loss as a dict.
+    device                  : torch.device
+                                Device to be used (e.g., cuda, cpu, opencl).
+    """
+    self.device = device
+    self.target_image     = target_image.float().to(self.device)
+    self.target_depth     = target_depth.float().to(self.device)
+    self.target_blur_size = target_blur_size
+    if self.target_blur_size % 2 == 0:
+        self.target_blur_size += 1
+    self.number_of_planes = number_of_planes
+    self.multiplier       = multiplier
+    self.reduction        = reduction
+    if self.reduction == 'none' and len(list(additional_loss_weights.keys())) > 0:
+        logging.warning("Reduction cannot be 'none' for additional loss functions. Changing reduction to 'mean'.")
+        self.reduction = 'mean'
+    self.blur_ratio       = blur_ratio
+    self.set_targets()
+    if scheme == 'defocus':
+        self.add_defocus_blur()
+    self.base_loss_weights = base_loss_weights
+    self.additional_loss_weights = additional_loss_weights
+    self.return_components = return_components
+    self.l1_loss_fn = torch.nn.L1Loss(reduction = self.reduction)
+    self.l2_loss_fn = torch.nn.MSELoss(reduction = self.reduction)
+    for key in self.additional_loss_weights.keys():
+        if key == 'cvvdp':
+            self.cvvdp = CVVDP()
+        if key == 'fvvdp':
+            self.fvvdp = FVVDP()
+        if key == 'lpips':
+            self.lpips = LPIPS()
+        if key == 'psnr':
+            self.psnr = PSNR()
+        if key == 'ssim':
+            self.ssim = SSIM()
+        if key == 'msssim':
+            self.msssim = MSSSIM()
+
+
+
+ +
+ +
+ + +

+ add_defocus_blur() + +

+ + +
+ +

Internal function for adding defocus blur to the multiplane targets. Users can query the results with get_targets() within the same class.

+ +
+ Source code in odak/learn/wave/loss.py +
def add_defocus_blur(self):
+    """
+    Internal function for adding defocus blur to the multiplane targets. Users can query the results with get_targets() within the same class.
+    """
+    kernel_length = [self.target_blur_size, self.target_blur_size ]
+    for ch in range(self.target_image.shape[0]):
+        targets_cache = self.targets[:, ch].detach().clone()
+        target = torch.sum(targets_cache, axis = 0)
+        for i in range(self.number_of_planes):
+            defocus = torch.zeros_like(targets_cache[i])
+            for j in range(self.number_of_planes):
+                nsigma = [int(abs(i - j) * self.blur_ratio), int(abs(i -j) * self.blur_ratio)]
+                if torch.sum(targets_cache[j]) > 0:
+                    if i == j:
+                        nsigma = [0., 0.]
+                    kernel = generate_2d_gaussian(kernel_length, nsigma).to(self.device)
+                    kernel = kernel / torch.sum(kernel)
+                    kernel = kernel.unsqueeze(0).unsqueeze(0)
+                    target_current = target.detach().clone().unsqueeze(0).unsqueeze(0)
+                    defocus_plane = torch.nn.functional.conv2d(target_current, kernel, padding = 'same')
+                    defocus_plane = defocus_plane.view(defocus_plane.shape[-2], defocus_plane.shape[-1])
+                    defocus = defocus + defocus_plane * torch.abs(self.masks[j, ch])
+            self.targets[i, ch] = defocus
+    self.targets = self.targets.detach().clone() * self.multiplier
+
+
+
+ +
+ +
+ + +

+ get_targets() + +

+ + +
+ + + +

Returns:

+
    +
  • +targets ( tensor +) – +
    +

    Returns a copy of the targets.

    +
    +
  • +
  • +target_depth ( tensor +) – +
    +

    Returns a copy of the normalized quantized depth map.

    +
    +
  • +
+ +
+ Source code in odak/learn/wave/loss.py +
def get_targets(self):
+    """
+    Returns
+    -------
+    targets           : torch.tensor
+                        Returns a copy of the targets.
+    target_depth      : torch.tensor
+                        Returns a copy of the normalized quantized depth map.
+
+    """
+    divider = self.number_of_planes - 1
+    if divider == 0:
+        divider = 1
+    return self.targets.detach().clone(), self.focus_target.detach().clone(), self.target_depth.detach().clone() / divider
+
+
+
+ +
+ +
+ + +

+ set_targets() + +

+ + +
+ +

Internal function for slicing the depth into planes without considering defocus. Users can query the results with get_targets() within the same class.

+ +
+ Source code in odak/learn/wave/loss.py +
def set_targets(self):
+    """
+    Internal function for slicing the depth into planes without considering defocus. Users can query the results with get_targets() within the same class.
+    """
+    self.target_depth = self.target_depth * (self.number_of_planes - 1)
+    self.target_depth = torch.round(self.target_depth, decimals = 0)
+    self.targets      = torch.zeros(
+                                    self.number_of_planes,
+                                    self.target_image.shape[0],
+                                    self.target_image.shape[1],
+                                    self.target_image.shape[2],
+                                    requires_grad = False,
+                                    device = self.device
+                                   )
+    self.focus_target = torch.zeros_like(self.target_image, requires_grad = False)
+    self.masks        = torch.zeros_like(self.targets)
+    for i in range(self.number_of_planes):
+        for ch in range(self.target_image.shape[0]):
+            mask_zeros = torch.zeros_like(self.target_image[ch], dtype = torch.int)
+            mask_ones = torch.ones_like(self.target_image[ch], dtype = torch.int)
+            mask = torch.where(self.target_depth == i, mask_ones, mask_zeros)
+            new_target = self.target_image[ch] * mask
+            self.focus_target = self.focus_target + new_target.squeeze(0).squeeze(0).detach().clone()
+            self.targets[i, ch] = new_target.squeeze(0).squeeze(0)
+            self.masks[i, ch] = mask.detach().clone() 
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ phase_gradient + + +

+ + +
+

+ Bases: Module

+ + +

The class 'phase_gradient' provides a regularization function to measure the variation (gradient or Laplacian) of the phase of the complex amplitude.

+

This implements a convolution of the phase with a kernel.

+

The kernel is a simple 3 by 3 Laplacian kernel here, but you can also try other edge detection methods.

+ + + + + + +
+ Source code in odak/learn/wave/loss.py +
class phase_gradient(nn.Module):
+
+    """
+    The class 'phase_gradient' provides a regularization function to measure the variation (gradient or Laplacian) of the phase of the complex amplitude.
+
+    This implements a convolution of the phase with a kernel.
+
+    The kernel is a simple 3 by 3 Laplacian kernel here, but you can also try other edge detection methods.
+    """
+
+
+    def __init__(self, kernel = None, loss = nn.MSELoss(), device = torch.device("cpu")):
+        """
+        Parameters
+        ----------
+        kernel                  : torch.tensor
+                                    Convolution filter kernel, 3 by 3 Laplacian kernel by default.
+        loss                    : torch.nn.Module
+                                    loss function, L2 Loss by default.
+        """
+        super(phase_gradient, self).__init__()
+        self.device = device
+        self.loss = loss
+        if kernel is None:
+            self.kernel = torch.tensor([[[[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]]], dtype=torch.float32) / 8
+        else:
+            if len(kernel.shape) == 4:
+                self.kernel = kernel
+            else:
+                self.kernel = kernel.reshape((1, 1, kernel.shape[0], kernel.shape[1]))
+        self.kernel = Variable(self.kernel.to(self.device))
+
+
+    def forward(self, phase):
+        """
+        Calculates the phase gradient Loss.
+
+        Parameters
+        ----------
+        phase                  : torch.tensor
+                                    Phase of the complex amplitude.
+
+        Returns
+        -------
+
+        loss_value              : torch.tensor
+                                    The computed loss.
+        """
+
+        if len(phase.shape) == 2:
+            phase = phase.reshape((1, 1, phase.shape[0], phase.shape[1]))
+        edge_detect = self.functional_conv2d(phase)
+        loss_value = self.loss(edge_detect, torch.zeros_like(edge_detect))
+        return loss_value
+
+
+    def functional_conv2d(self, phase):
+        """
+        Calculates the gradient of the phase.
+
+        Parameters
+        ----------
+        phase                  : torch.tensor
+                                    Phase of the complex amplitude.
+
+        Returns
+        -------
+
+        edge_detect              : torch.tensor
+                                    The computed phase gradient.
+        """
+        edge_detect = F.conv2d(phase, self.kernel, padding = self.kernel.shape[-1] // 2)
+        return edge_detect
+
+
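A minimal usage sketch (assuming phase_gradient is importable from odak.learn.wave; the random phase below is a placeholder):

import torch
from odak.learn.wave import phase_gradient

regularizer = phase_gradient()                        # default 3 x 3 Laplacian kernel on CPU
phase = 2 * torch.pi * torch.rand(1, 1, 256, 256)     # placeholder phase of a complex amplitude
loss_value = regularizer(phase)                       # penalizes rapid phase variations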
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(kernel=None, loss=nn.MSELoss(), device=torch.device('cpu')) + +

+ + +
+ + + +

Parameters:

+
    +
  • + kernel + – +
    +
                        Convolution filter kernel, 3 by 3 Laplacian kernel by default.
    +
    +
    +
  • +
  • + loss + – +
    +
                        loss function, L2 Loss by default.
    +
    +
    +
  • +
+ +
+ Source code in odak/learn/wave/loss.py +
def __init__(self, kernel = None, loss = nn.MSELoss(), device = torch.device("cpu")):
+    """
+    Parameters
+    ----------
+    kernel                  : torch.tensor
+                                Convolution filter kernel, 3 by 3 Laplacian kernel by default.
+    loss                    : torch.nn.Module
+                                loss function, L2 Loss by default.
+    """
+    super(phase_gradient, self).__init__()
+    self.device = device
+    self.loss = loss
+    if kernel is None:
+        self.kernel = torch.tensor([[[[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]]], dtype=torch.float32) / 8
+    else:
+        if len(kernel.shape) == 4:
+            self.kernel = kernel
+        else:
+            self.kernel = kernel.reshape((1, 1, kernel.shape[0], kernel.shape[1]))
+    self.kernel = Variable(self.kernel.to(self.device))
+
+
+
+ +
+ +
+ + +

+ forward(phase) + +

+ + +
+ +

Calculates the phase gradient Loss.

+ + +

Parameters:

  • phase – Phase of the complex amplitude.

Returns:

  • loss_value (tensor) – The computed loss.
+ +
+ Source code in odak/learn/wave/loss.py +
def forward(self, phase):
+    """
+    Calculates the phase gradient Loss.
+
+    Parameters
+    ----------
+    phase                  : torch.tensor
+                                Phase of the complex amplitude.
+
+    Returns
+    -------
+
+    loss_value              : torch.tensor
+                                The computed loss.
+    """
+
+    if len(phase.shape) == 2:
+        phase = phase.reshape((1, 1, phase.shape[0], phase.shape[1]))
+    edge_detect = self.functional_conv2d(phase)
+    loss_value = self.loss(edge_detect, torch.zeros_like(edge_detect))
+    return loss_value
+
+
+
+ +
+ +
+ + +

+ functional_conv2d(phase) + +

+ + +
+ +

Calculates the gradient of the phase.

+ + +

Parameters:

  • phase – Phase of the complex amplitude.

Returns:

  • edge_detect (tensor) – The computed phase gradient.
+ +
+ Source code in odak/learn/wave/loss.py +
def functional_conv2d(self, phase):
+    """
+    Calculates the gradient of the phase.
+
+    Parameters
+    ----------
+    phase                  : torch.tensor
+                                Phase of the complex amplitude.
+
+    Returns
+    -------
+
+    edge_detect              : torch.tensor
+                                The computed phase gradient.
+    """
+    edge_detect = F.conv2d(phase, self.kernel, padding = self.kernel.shape[-1] // 2)
+    return edge_detect
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ speckle_contrast + + +

+ + +
+

+ Bases: Module

+ + +

The class 'speckle_contrast' provides a regularization function to measure the speckle contrast of the intensity of the complex amplitude using C = sigma / mean, where C is the speckle contrast and sigma and mean are the standard deviation and mean of the intensity.

+

We refer to the following paper:

+

Kim et al. (2020). Light source optimization for partially coherent holographic displays with consideration of speckle contrast, resolution, and depth of field. Scientific Reports, 10, 18832. doi:10.1038/s41598-020-75947-0.

+ + + + + + +
+ Source code in odak/learn/wave/loss.py +
class speckle_contrast(nn.Module):
+
+    """
+    The class 'speckle_contrast' provides a regularization function to measure the speckle contrast of the intensity of the complex amplitude using C=sigma/mean. Where C is the speckle contrast, mean and sigma are mean and standard deviation of the intensity.
+
+    We refer to the following paper:
+
+    Kim et al.(2020). Light source optimization for partially coherent holographic displays with consideration of speckle contrast, resolution, and depth of field. Scientific Reports. 10. 18832. 10.1038/s41598-020-75947-0. 
+    """
+
+
+    def __init__(self, kernel_size = 11, step_size = (1, 1), loss = nn.MSELoss(), device=torch.device("cpu")):
+        """
+        Parameters
+        ----------
+        kernel_size             : torch.tensor
+                                    Convolution filter kernel size, 11 by 11 average kernel by default.
+        step_size               : tuple
+                                    Convolution stride in height and width direction.
+        loss                    : torch.nn.Module
+                                    loss function, L2 Loss by default.
+        """
+        super(speckle_contrast, self).__init__()
+        self.device = device
+        self.loss = loss
+        self.step_size = step_size
+        self.kernel_size = kernel_size
+        self.kernel = torch.ones((1, 1, self.kernel_size, self.kernel_size)) / (self.kernel_size ** 2)
+        self.kernel = Variable(self.kernel.type(torch.FloatTensor).to(self.device))
+
+
+    def forward(self, intensity):
+        """
+        Calculates the speckle contrast Loss.
+
+        Parameters
+        ----------
+        intensity               : torch.tensor
+                                    intensity of the complex amplitude.
+
+        Returns
+        -------
+
+        loss_value              : torch.tensor
+                                    The computed loss.
+        """
+
+        if len(intensity.shape) == 2:
+            intensity = intensity.reshape((1, 1, intensity.shape[0], intensity.shape[1]))
+        Speckle_C = self.functional_conv2d(intensity)
+        loss_value = self.loss(Speckle_C, torch.zeros_like(Speckle_C))
+        return loss_value
+
+
+    def functional_conv2d(self, intensity):
+        """
+        Calculates the speckle contrast of the intensity.
+
+        Parameters
+        ----------
+        intensity                : torch.tensor
+                                    Intensity of the complex field.
+
+        Returns
+        -------
+
+        Speckle_C               : torch.tensor
+                                    The computed speckle contrast.
+        """
+        mean = F.conv2d(intensity, self.kernel, stride = self.step_size)
+        var = torch.sqrt(F.conv2d(torch.pow(intensity, 2), self.kernel, stride = self.step_size) - torch.pow(mean, 2))
+        Speckle_C = var / mean
+        return Speckle_C
+
+
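A minimal usage sketch (not from the source; the import path and tensor shapes are assumptions):

import torch
from odak.learn.wave.loss import speckle_contrast   # module path assumed from the source location note

intensity = torch.rand(1, 1, 512, 512) + 0.5         # hypothetical reconstructed intensity, kept positive
criterion = speckle_contrast(kernel_size = 11, step_size = (1, 1))
loss_value = criterion(intensity)                    # local C = sigma / mean is driven towards zero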

+ __init__(kernel_size=11, step_size=(1, 1), loss=nn.MSELoss(), device=torch.device('cpu')) + +

+ + +
+ + + +

Parameters:

  • kernel_size – Convolution filter kernel size, 11 by 11 average kernel by default.
  • step_size – Convolution stride in height and width direction.
  • loss – loss function, L2 Loss by default.
+ +
+ Source code in odak/learn/wave/loss.py +
def __init__(self, kernel_size = 11, step_size = (1, 1), loss = nn.MSELoss(), device=torch.device("cpu")):
+    """
+    Parameters
+    ----------
+    kernel_size             : torch.tensor
+                                Convolution filter kernel size, 11 by 11 average kernel by default.
+    step_size               : tuple
+                                Convolution stride in height and width direction.
+    loss                    : torch.nn.Module
+                                loss function, L2 Loss by default.
+    """
+    super(speckle_contrast, self).__init__()
+    self.device = device
+    self.loss = loss
+    self.step_size = step_size
+    self.kernel_size = kernel_size
+    self.kernel = torch.ones((1, 1, self.kernel_size, self.kernel_size)) / (self.kernel_size ** 2)
+    self.kernel = Variable(self.kernel.type(torch.FloatTensor).to(self.device))
+
+
+
+ +
+ +
+ + +

+ forward(intensity) + +

+ + +
+ +

Calculates the speckle contrast Loss.

+ + +

Parameters:

  • intensity – Intensity of the complex amplitude.

Returns:

  • loss_value (tensor) – The computed loss.
+ +
+ Source code in odak/learn/wave/loss.py +
def forward(self, intensity):
+    """
+    Calculates the speckle contrast Loss.
+
+    Parameters
+    ----------
+    intensity               : torch.tensor
+                                intensity of the complex amplitude.
+
+    Returns
+    -------
+
+    loss_value              : torch.tensor
+                                The computed loss.
+    """
+
+    if len(intensity.shape) == 2:
+        intensity = intensity.reshape((1, 1, intensity.shape[0], intensity.shape[1]))
+    Speckle_C = self.functional_conv2d(intensity)
+    loss_value = self.loss(Speckle_C, torch.zeros_like(Speckle_C))
+    return loss_value
+
+
+
+ +
+ +
+ + +

+ functional_conv2d(intensity) + +

+ + +
+ +

Calculates the speckle contrast of the intensity.

+ + +

Parameters:

  • intensity – Intensity of the complex field.

Returns:

  • Speckle_C (tensor) – The computed speckle contrast.
+ +
+ Source code in odak/learn/wave/loss.py +
def functional_conv2d(self, intensity):
+    """
+    Calculates the speckle contrast of the intensity.
+
+    Parameters
+    ----------
+    intensity                : torch.tensor
+                                Intensity of the complex field.
+
+    Returns
+    -------
+
+    Speckle_C               : torch.tensor
+                                The computed speckle contrast.
+    """
+    mean = F.conv2d(intensity, self.kernel, stride = self.step_size)
+    var = torch.sqrt(F.conv2d(torch.pow(intensity, 2), self.kernel, stride = self.step_size) - torch.pow(mean, 2))
+    Speckle_C = var / mean
+    return Speckle_C
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + +
+ +
+ +
+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ channel_gate + + +

+ + +
+

+ Bases: Module

+ + +

Channel attention module with various pooling strategies. This class is heavily inspired by https://github.com/Jongchan/attention-module/commit/e4ee180f1335c09db14d39a65d97c8ca3d1f7b16 (MIT License).

+ + + + + + +
+ Source code in odak/learn/models/components.py +
class channel_gate(torch.nn.Module):
+    """
+    Channel attention module with various pooling strategies.
+    This class is heavily inspired https://github.com/Jongchan/attention-module/commit/e4ee180f1335c09db14d39a65d97c8ca3d1f7b16 (MIT License).
+    """
+    def __init__(
+                 self, 
+                 gate_channels, 
+                 reduction_ratio = 16, 
+                 pool_types = ['avg', 'max']
+                ):
+        """
+        Initializes the channel gate module.
+
+        Parameters
+        ----------
+        gate_channels   : int
+                          Number of channels of the input feature map.
+        reduction_ratio : int
+                          Reduction ratio for the intermediate layer.
+        pool_types      : list
+                          List of pooling operations to apply.
+        """
+        super().__init__()
+        self.gate_channels = gate_channels
+        hidden_channels = gate_channels // reduction_ratio
+        if hidden_channels == 0:
+            hidden_channels = 1
+        self.mlp = torch.nn.Sequential(
+                                       convolutional_block_attention.Flatten(),
+                                       torch.nn.Linear(gate_channels, hidden_channels),
+                                       torch.nn.ReLU(),
+                                       torch.nn.Linear(hidden_channels, gate_channels)
+                                      )
+        self.pool_types = pool_types
+
+
+    def forward(self, x):
+        """
+        Forward pass of the ChannelGate module.
+
+        Applies channel-wise attention to the input tensor.
+
+        Parameters
+        ----------
+        x            : torch.tensor
+                       Input tensor to the ChannelGate module.
+
+        Returns
+        -------
+        output       : torch.tensor
+                       Output tensor after applying channel attention.
+        """
+        channel_att_sum = None
+        for pool_type in self.pool_types:
+            if pool_type == 'avg':
+                pool = torch.nn.functional.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
+            elif pool_type == 'max':
+                pool = torch.nn.functional.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
+            channel_att_raw = self.mlp(pool)
+            channel_att_sum = channel_att_raw if channel_att_sum is None else channel_att_sum + channel_att_raw
+        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
+        output = x * scale
+        return output
+
+
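A brief usage sketch (not from the source; the import path and shapes are assumptions):

import torch
from odak.learn.models.components import channel_gate   # module path assumed from the source location note

x = torch.rand(4, 32, 64, 64)                                   # hypothetical feature maps
gate = channel_gate(gate_channels = 32, reduction_ratio = 16)   # default pooling: ['avg', 'max']
y = gate(x)                                                     # same shape, channels re-weighted by attention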

+ __init__(gate_channels, reduction_ratio=16, pool_types=['avg', 'max']) + +

+ + +
+ +

Initializes the channel gate module.

+ + +

Parameters:

  • gate_channels – Number of channels of the input feature map.
  • reduction_ratio (int, default: 16) – Reduction ratio for the intermediate layer.
  • pool_types – List of pooling operations to apply.
+ +
+ Source code in odak/learn/models/components.py +
def __init__(
+             self, 
+             gate_channels, 
+             reduction_ratio = 16, 
+             pool_types = ['avg', 'max']
+            ):
+    """
+    Initializes the channel gate module.
+
+    Parameters
+    ----------
+    gate_channels   : int
+                      Number of channels of the input feature map.
+    reduction_ratio : int
+                      Reduction ratio for the intermediate layer.
+    pool_types      : list
+                      List of pooling operations to apply.
+    """
+    super().__init__()
+    self.gate_channels = gate_channels
+    hidden_channels = gate_channels // reduction_ratio
+    if hidden_channels == 0:
+        hidden_channels = 1
+    self.mlp = torch.nn.Sequential(
+                                   convolutional_block_attention.Flatten(),
+                                   torch.nn.Linear(gate_channels, hidden_channels),
+                                   torch.nn.ReLU(),
+                                   torch.nn.Linear(hidden_channels, gate_channels)
+                                  )
+    self.pool_types = pool_types
+
+
+
+ +
+ +
+ + +

+ forward(x) + +

+ + +
+ +

Forward pass of the ChannelGate module.

+

Applies channel-wise attention to the input tensor.

+ + +

Parameters:

  • x – Input tensor to the ChannelGate module.

Returns:

  • output (tensor) – Output tensor after applying channel attention.
+ +
+ Source code in odak/learn/models/components.py +
def forward(self, x):
+    """
+    Forward pass of the ChannelGate module.
+
+    Applies channel-wise attention to the input tensor.
+
+    Parameters
+    ----------
+    x            : torch.tensor
+                   Input tensor to the ChannelGate module.
+
+    Returns
+    -------
+    output       : torch.tensor
+                   Output tensor after applying channel attention.
+    """
+    channel_att_sum = None
+    for pool_type in self.pool_types:
+        if pool_type == 'avg':
+            pool = torch.nn.functional.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
+        elif pool_type == 'max':
+            pool = torch.nn.functional.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
+        channel_att_raw = self.mlp(pool)
+        channel_att_sum = channel_att_raw if channel_att_sum is None else channel_att_sum + channel_att_raw
+    scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
+    output = x * scale
+    return output
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ convolution_layer + + +

+ + +
+

+ Bases: Module

+ + +

A convolution layer.

+ + + + + + +
+ Source code in odak/learn/models/components.py +
class convolution_layer(torch.nn.Module):
+    """
+    A convolution layer.
+    """
+    def __init__(
+                 self,
+                 input_channels = 2,
+                 output_channels = 2,
+                 kernel_size = 3,
+                 bias = False,
+                 stride = 1,
+                 normalization = True,
+                 activation = torch.nn.ReLU()
+                ):
+        """
+        A convolutional layer class.
+
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        output_channels : int
+                          Number of output channels.
+        kernel_size     : int
+                          Kernel size.
+        bias            : bool
+                          Set to True to let convolutional layers have bias term.
+        normalization   : bool
+                          If True, adds a Batch Normalization layer after the convolutional layer.
+        activation      : torch.nn
+                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+        """
+        super().__init__()
+        layers = [
+            torch.nn.Conv2d(
+                            input_channels,
+                            output_channels,
+                            kernel_size = kernel_size,
+                            stride = stride,
+                            padding = kernel_size // 2,
+                            bias = bias
+                           )
+        ]
+        if normalization:
+            layers.append(torch.nn.BatchNorm2d(output_channels))
+        if activation:
+            layers.append(activation)
+        self.model = torch.nn.Sequential(*layers)
+
+
+    def forward(self, x):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x             : torch.tensor
+                        Input data.
+
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Estimated output.
+        """
+        result = self.model(x)
+        return result
+
+
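A brief usage sketch (not from the source; the import path, channel counts and sizes are assumptions):

import torch
from odak.learn.models.components import convolution_layer   # module path assumed from the source location note

layer = convolution_layer(input_channels = 3, output_channels = 16, kernel_size = 3)
x = torch.rand(1, 3, 128, 128)   # hypothetical input
y = layer(x)                     # -> [1, 16, 128, 128]; padding keeps the spatial size at stride 1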

+ __init__(input_channels=2, output_channels=2, kernel_size=3, bias=False, stride=1, normalization=True, activation=torch.nn.ReLU()) + +

+ + +
+ +

A convolutional layer class.

+ + +

Parameters:

  • input_channels – Number of input channels.
  • output_channels (int, default: 2) – Number of output channels.
  • kernel_size – Kernel size.
  • bias – Set to True to let convolutional layers have bias term.
  • normalization – If True, adds a Batch Normalization layer after the convolutional layer.
  • activation – Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+ +
+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             input_channels = 2,
+             output_channels = 2,
+             kernel_size = 3,
+             bias = False,
+             stride = 1,
+             normalization = True,
+             activation = torch.nn.ReLU()
+            ):
+    """
+    A convolutional layer class.
+
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    output_channels : int
+                      Number of output channels.
+    kernel_size     : int
+                      Kernel size.
+    bias            : bool
+                      Set to True to let convolutional layers have bias term.
+    normalization   : bool
+                      If True, adds a Batch Normalization layer after the convolutional layer.
+    activation      : torch.nn
+                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+    """
+    super().__init__()
+    layers = [
+        torch.nn.Conv2d(
+                        input_channels,
+                        output_channels,
+                        kernel_size = kernel_size,
+                        stride = stride,
+                        padding = kernel_size // 2,
+                        bias = bias
+                       )
+    ]
+    if normalization:
+        layers.append(torch.nn.BatchNorm2d(output_channels))
+    if activation:
+        layers.append(activation)
+    self.model = torch.nn.Sequential(*layers)
+
+
+
+ +
+ +
+ + +

+ forward(x) + +

+ + +
+ +

Forward model.

+ + +

Parameters:

  • x – Input data.

Returns:

  • result (tensor) – Estimated output.
+ +
+ Source code in odak/learn/models/components.py +
def forward(self, x):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x             : torch.tensor
+                    Input data.
+
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Estimated output.
+    """
+    result = self.model(x)
+    return result
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ convolutional_block_attention + + +

+ + +
+

+ Bases: Module

+ + +

Convolutional Block Attention Module (CBAM) class. This class is heavily inspired by https://github.com/Jongchan/attention-module/commit/e4ee180f1335c09db14d39a65d97c8ca3d1f7b16 (MIT License).

+ + + + + + +
+ Source code in odak/learn/models/components.py +
class convolutional_block_attention(torch.nn.Module):
+    """
+    Convolutional Block Attention Module (CBAM) class. 
+    This class is heavily inspired https://github.com/Jongchan/attention-module/commit/e4ee180f1335c09db14d39a65d97c8ca3d1f7b16 (MIT License).
+    """
+    def __init__(
+                 self, 
+                 gate_channels, 
+                 reduction_ratio = 16, 
+                 pool_types = ['avg', 'max'], 
+                 no_spatial = False
+                ):
+        """
+        Initializes the convolutional block attention module.
+
+        Parameters
+        ----------
+        gate_channels   : int
+                          Number of channels of the input feature map.
+        reduction_ratio : int
+                          Reduction ratio for the channel attention.
+        pool_types      : list
+                          List of pooling operations to apply for channel attention.
+        no_spatial      : bool
+                          If True, spatial attention is not applied.
+        """
+        super(convolutional_block_attention, self).__init__()
+        self.channel_gate = channel_gate(gate_channels, reduction_ratio, pool_types)
+        self.no_spatial = no_spatial
+        if not no_spatial:
+            self.spatial_gate = spatial_gate()
+
+
+    class Flatten(torch.nn.Module):
+        """
+        Flattens the input tensor to a 2D matrix.
+        """
+        def forward(self, x):
+            return x.view(x.size(0), -1)
+
+
+    def forward(self, x):
+        """
+        Forward pass of the convolutional block attention module.
+
+        Parameters
+        ----------
+        x            : torch.tensor
+                       Input tensor to the CBAM module.
+
+        Returns
+        -------
+        x_out        : torch.tensor
+                       Output tensor after applying channel and spatial attention.
+        """
+        x_out = self.channel_gate(x)
+        if not self.no_spatial:
+            x_out = self.spatial_gate(x_out)
+        return x_out
+
+
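A brief usage sketch (not from the source; the import path and shapes are assumptions):

import torch
from odak.learn.models.components import convolutional_block_attention   # module path assumed from the source location note

cbam = convolutional_block_attention(gate_channels = 64, reduction_ratio = 16)
x = torch.rand(2, 64, 32, 32)   # hypothetical feature maps
y = cbam(x)                     # channel attention, then spatial attention; same shape as x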

+ Flatten + + +

+ + +
+

+ Bases: Module

+ + +

Flattens the input tensor to a 2D matrix.

+ + + + + + +
+ Source code in odak/learn/models/components.py +
class Flatten(torch.nn.Module):
+    """
+    Flattens the input tensor to a 2D matrix.
+    """
+    def forward(self, x):
+        return x.view(x.size(0), -1)
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ + +
+ + +

+ __init__(gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False) + +

+ + +
+ +

Initializes the convolutional block attention module.

+ + +

Parameters:

  • gate_channels – Number of channels of the input feature map.
  • reduction_ratio (int, default: 16) – Reduction ratio for the channel attention.
  • pool_types – List of pooling operations to apply for channel attention.
  • no_spatial – If True, spatial attention is not applied.
+ +
+ Source code in odak/learn/models/components.py +
def __init__(
+             self, 
+             gate_channels, 
+             reduction_ratio = 16, 
+             pool_types = ['avg', 'max'], 
+             no_spatial = False
+            ):
+    """
+    Initializes the convolutional block attention module.
+
+    Parameters
+    ----------
+    gate_channels   : int
+                      Number of channels of the input feature map.
+    reduction_ratio : int
+                      Reduction ratio for the channel attention.
+    pool_types      : list
+                      List of pooling operations to apply for channel attention.
+    no_spatial      : bool
+                      If True, spatial attention is not applied.
+    """
+    super(convolutional_block_attention, self).__init__()
+    self.channel_gate = channel_gate(gate_channels, reduction_ratio, pool_types)
+    self.no_spatial = no_spatial
+    if not no_spatial:
+        self.spatial_gate = spatial_gate()
+
+
+
+ +
+ +
+ + +

+ forward(x) + +

+ + +
+ +

Forward pass of the convolutional block attention module.

+ + +

Parameters:

  • x – Input tensor to the CBAM module.

Returns:

  • x_out (tensor) – Output tensor after applying channel and spatial attention.
+ +
+ Source code in odak/learn/models/components.py +
def forward(self, x):
+    """
+    Forward pass of the convolutional block attention module.
+
+    Parameters
+    ----------
+    x            : torch.tensor
+                   Input tensor to the CBAM module.
+
+    Returns
+    -------
+    x_out        : torch.tensor
+                   Output tensor after applying channel and spatial attention.
+    """
+    x_out = self.channel_gate(x)
+    if not self.no_spatial:
+        x_out = self.spatial_gate(x_out)
+    return x_out
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ double_convolution + + +

+ + +
+

+ Bases: Module

+ + +

A double convolution layer.

+ + + + + + +
+ Source code in odak/learn/models/components.py +
class double_convolution(torch.nn.Module):
+    """
+    A double convolution layer.
+    """
+    def __init__(
+                 self,
+                 input_channels = 2,
+                 mid_channels = None,
+                 output_channels = 2,
+                 kernel_size = 3, 
+                 bias = False,
+                 normalization = True,
+                 activation = torch.nn.ReLU()
+                ):
+        """
+        Double convolution model.
+
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        mid_channels    : int
+                          Number of channels in the hidden layer between two convolutions.
+        output_channels : int
+                          Number of output channels.
+        kernel_size     : int
+                          Kernel size.
+        bias            : bool 
+                          Set to True to let convolutional layers have bias term.
+        normalization   : bool
+                          If True, adds a Batch Normalization layer after the convolutional layer.
+        activation      : torch.nn
+                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+        """
+        super().__init__()
+        if isinstance(mid_channels, type(None)):
+            mid_channels = output_channels
+        self.activation = activation
+        self.model = torch.nn.Sequential(
+                                         convolution_layer(
+                                                           input_channels = input_channels,
+                                                           output_channels = mid_channels,
+                                                           kernel_size = kernel_size,
+                                                           bias = bias,
+                                                           normalization = normalization,
+                                                           activation = self.activation
+                                                          ),
+                                         convolution_layer(
+                                                           input_channels = mid_channels,
+                                                           output_channels = output_channels,
+                                                           kernel_size = kernel_size,
+                                                           bias = bias,
+                                                           normalization = normalization,
+                                                           activation = self.activation
+                                                          )
+                                        )
+
+
+    def forward(self, x):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x             : torch.tensor
+                        Input data.
+
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Estimated output.      
+        """
+        result = self.model(x)
+        return result
+
+
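A brief usage sketch (not from the source; the import path and channel counts are assumptions):

import torch
from odak.learn.models.components import double_convolution   # module path assumed from the source location note

block = double_convolution(input_channels = 8, mid_channels = 16, output_channels = 16)
x = torch.rand(1, 8, 64, 64)   # hypothetical input
y = block(x)                   # two convolution_layer blocks back to back -> [1, 16, 64, 64]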

+ __init__(input_channels=2, mid_channels=None, output_channels=2, kernel_size=3, bias=False, normalization=True, activation=torch.nn.ReLU()) + +

+ + +
+ +

Double convolution model.

+ + +

Parameters:

  • input_channels – Number of input channels.
  • mid_channels – Number of channels in the hidden layer between two convolutions.
  • output_channels (int, default: 2) – Number of output channels.
  • kernel_size – Kernel size.
  • bias – Set to True to let convolutional layers have bias term.
  • normalization – If True, adds a Batch Normalization layer after the convolutional layer.
  • activation – Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+ +
+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             input_channels = 2,
+             mid_channels = None,
+             output_channels = 2,
+             kernel_size = 3, 
+             bias = False,
+             normalization = True,
+             activation = torch.nn.ReLU()
+            ):
+    """
+    Double convolution model.
+
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    mid_channels    : int
+                      Number of channels in the hidden layer between two convolutions.
+    output_channels : int
+                      Number of output channels.
+    kernel_size     : int
+                      Kernel size.
+    bias            : bool 
+                      Set to True to let convolutional layers have bias term.
+    normalization   : bool
+                      If True, adds a Batch Normalization layer after the convolutional layer.
+    activation      : torch.nn
+                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+    """
+    super().__init__()
+    if isinstance(mid_channels, type(None)):
+        mid_channels = output_channels
+    self.activation = activation
+    self.model = torch.nn.Sequential(
+                                     convolution_layer(
+                                                       input_channels = input_channels,
+                                                       output_channels = mid_channels,
+                                                       kernel_size = kernel_size,
+                                                       bias = bias,
+                                                       normalization = normalization,
+                                                       activation = self.activation
+                                                      ),
+                                     convolution_layer(
+                                                       input_channels = mid_channels,
+                                                       output_channels = output_channels,
+                                                       kernel_size = kernel_size,
+                                                       bias = bias,
+                                                       normalization = normalization,
+                                                       activation = self.activation
+                                                      )
+                                    )
+
+
+
+ +
+ +
+ + +

+ forward(x) + +

+ + +
+ +

Forward model.

+ + +

Parameters:

  • x – Input data.

Returns:

  • result (tensor) – Estimated output.
+ +
+ Source code in odak/learn/models/components.py +
def forward(self, x):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x             : torch.tensor
+                    Input data.
+
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Estimated output.      
+    """
+    result = self.model(x)
+    return result
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ downsample_layer + + +

+ + +
+

+ Bases: Module

+ + +

A downscaling component followed by a double convolution.

+ + + + + + +
+ Source code in odak/learn/models/components.py +
class downsample_layer(torch.nn.Module):
+    """
+    A downscaling component followed by a double convolution.
+    """
+    def __init__(
+                 self,
+                 input_channels,
+                 output_channels,
+                 kernel_size = 3,
+                 bias = False,
+                 activation = torch.nn.ReLU()
+                ):
+        """
+        A downscaling component with a double convolution.
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        output_channels : int
+                          Number of output channels.
+        kernel_size     : int
+                          Kernel size.
+        bias            : bool 
+                          Set to True to let convolutional layers have bias term.
+        activation      : torch.nn
+                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+        """
+        super().__init__()
+        self.maxpool_conv = torch.nn.Sequential(
+                                                torch.nn.MaxPool2d(2),
+                                                double_convolution(
+                                                                   input_channels = input_channels,
+                                                                   mid_channels = output_channels,
+                                                                   output_channels = output_channels,
+                                                                   kernel_size = kernel_size,
+                                                                   bias = bias,
+                                                                   activation = activation
+                                                                  )
+                                               )
+
+
+    def forward(self, x):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x              : torch.tensor
+                         First input data.
+
+
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Estimated output.      
+        """
+        result = self.maxpool_conv(x)
+        return result
+
+
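A brief usage sketch (not from the source; the import path and shapes are assumptions):

import torch
from odak.learn.models.components import downsample_layer   # module path assumed from the source location note

down = downsample_layer(input_channels = 16, output_channels = 32)
x = torch.rand(1, 16, 64, 64)   # hypothetical input
y = down(x)                     # 2 x 2 max pooling then a double convolution -> [1, 32, 32, 32]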

+ __init__(input_channels, output_channels, kernel_size=3, bias=False, activation=torch.nn.ReLU()) + +

+ + +
+ +

A downscaling component with a double convolution.

+ + +

Parameters:

  • input_channels – Number of input channels.
  • output_channels (int) – Number of output channels.
  • kernel_size – Kernel size.
  • bias – Set to True to let convolutional layers have bias term.
  • activation – Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+ +
+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             input_channels,
+             output_channels,
+             kernel_size = 3,
+             bias = False,
+             activation = torch.nn.ReLU()
+            ):
+    """
+    A downscaling component with a double convolution.
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    output_channels : int
+                      Number of output channels.
+    kernel_size     : int
+                      Kernel size.
+    bias            : bool 
+                      Set to True to let convolutional layers have bias term.
+    activation      : torch.nn
+                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+    """
+    super().__init__()
+    self.maxpool_conv = torch.nn.Sequential(
+                                            torch.nn.MaxPool2d(2),
+                                            double_convolution(
+                                                               input_channels = input_channels,
+                                                               mid_channels = output_channels,
+                                                               output_channels = output_channels,
+                                                               kernel_size = kernel_size,
+                                                               bias = bias,
+                                                               activation = activation
+                                                              )
+                                           )
+
+
+
+ +
+ +
+ + +

+ forward(x) + +

+ + +
+ +

Forward model.

+ + +

Parameters:

  • x – First input data.

Returns:

  • result (tensor) – Estimated output.
+ +
+ Source code in odak/learn/models/components.py +
def forward(self, x):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x              : torch.tensor
+                     First input data.
+
+
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Estimated output.      
+    """
+    result = self.maxpool_conv(x)
+    return result
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ focal_surface_light_propagation + + +

+ + +
+

+ Bases: Module

+ + +

focal_surface_light_propagation model.

+ + +
+ References +

Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Akşit. "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions." SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24), December 2024.

+
+ + + + + +
+ Source code in odak/learn/wave/models.py +
class focal_surface_light_propagation(torch.nn.Module):
+    """
+    focal_surface_light_propagation model.
+
+    References
+    ----------
+
+    Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Akşit}. "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions." SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24),December,2024.
+    """
+    def __init__(
+                 self,
+                 depth = 3,
+                 dimensions = 8,
+                 input_channels = 6,
+                 out_channels = 6,
+                 kernel_size = 3,
+                 bias = True,
+                 device = torch.device('cpu'),
+                 activation = torch.nn.LeakyReLU(0.2, inplace = True)
+                ):
+        """
+        Initializes the focal surface light propagation model.
+
+        Parameters
+        ----------
+        depth             : int
+                            Number of downsampling and upsampling layers.
+        dimensions        : int
+                            Number of dimensions/features in the model.
+        input_channels    : int
+                            Number of input channels.
+        out_channels      : int
+                            Number of output channels.
+        kernel_size       : int
+                            Size of the convolution kernel.
+        bias              : bool
+                            If True, allows convolutional layers to learn a bias term.
+        device            : torch.device
+                            Default device is CPU.
+        activation        : torch.nn.Module
+                            Activation function (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
+        """
+        super().__init__()
+        self.depth = depth
+        self.device = device
+        self.sv_kernel_generation = spatially_varying_kernel_generation_model(
+            depth = depth,
+            dimensions = dimensions,
+            input_channels = input_channels + 1,  # +1 to account for an extra channel
+            kernel_size = kernel_size,
+            bias = bias,
+            activation = activation
+        )
+        self.light_propagation = spatially_adaptive_unet(
+            depth = depth,
+            dimensions = dimensions,
+            input_channels = input_channels,
+            out_channels = out_channels,
+            kernel_size = kernel_size,
+            bias = bias,
+            activation = activation
+        )
+
+
+    def forward(self, focal_surface, phase_only_hologram):
+        """
+        Forward pass through the model.
+
+        Parameters
+        ----------
+        focal_surface         : torch.Tensor
+                                Input focal surface.
+        phase_only_hologram   : torch.Tensor
+                                Input phase-only hologram.
+
+        Returns
+        ----------
+        result                : torch.Tensor
+                                Output tensor after light propagation.
+        """
+        input_field = self.generate_input_field(phase_only_hologram)
+        sv_kernel = self.sv_kernel_generation(focal_surface, input_field)
+        output_field = self.light_propagation(sv_kernel, input_field)
+        final = (output_field[:, 0:3, :, :] + 1j * output_field[:, 3:6, :, :])
+        result = calculate_amplitude(final) ** 2
+        return result
+
+
+    def generate_input_field(self, phase_only_hologram):
+        """
+        Generates an input field by combining the real and imaginary parts.
+
+        Parameters
+        ----------
+        phase_only_hologram   : torch.Tensor
+                                Input phase-only hologram.
+
+        Returns
+        ----------
+        input_field           : torch.Tensor
+                                Concatenated real and imaginary parts of the complex field.
+        """
+        [b, c, h, w] = phase_only_hologram.size()
+        input_phase = phase_only_hologram * 2 * np.pi
+        hologram_amplitude = torch.ones(b, c, h, w, requires_grad = False).to(self.device)
+        field = generate_complex_field(hologram_amplitude, input_phase)
+        input_field = torch.cat((field.real, field.imag), dim = 1)
+        return input_field
+
+
+    def load_weights(self, weight_filename, key_mapping_filename):
+        """
+        Function to load weights for this multi-layer perceptron from a file.
+
+        Parameters
+        ----------
+        weight_filename      : str
+                               Path to the old model's weight file.
+        key_mapping_filename : str
+                               Path to the JSON file containing the key mappings.
+        """
+        # Load old model weights
+        old_model_weights = torch.load(weight_filename, map_location = self.device)
+
+        # Load key mappings from JSON file
+        with open(key_mapping_filename, 'r') as json_file:
+            key_mappings = json.load(json_file)
+
+        # Extract the key mappings for sv_kernel_generation and light_prop
+        sv_kernel_generation_key_mapping = key_mappings['sv_kernel_generation_key_mapping']
+        light_prop_key_mapping = key_mappings['light_prop_key_mapping']
+
+        # Initialize new state dicts
+        sv_kernel_generation_new_state_dict = {}
+        light_prop_new_state_dict = {}
+
+        # Map and load sv_kernel_generation_model weights
+        for old_key, value in old_model_weights.items():
+            if old_key in sv_kernel_generation_key_mapping:
+                # Map the old key to the new key
+                new_key = sv_kernel_generation_key_mapping[old_key]
+                sv_kernel_generation_new_state_dict[new_key] = value
+
+        self.sv_kernel_generation.to(self.device)
+        self.sv_kernel_generation.load_state_dict(sv_kernel_generation_new_state_dict)
+
+        # Map and load light_prop model weights
+        for old_key, value in old_model_weights.items():
+            if old_key in light_prop_key_mapping:
+                # Map the old key to the new key
+                new_key = light_prop_key_mapping[old_key]
+                light_prop_new_state_dict[new_key] = value
+        self.light_propagation.to(self.device)
+        self.light_propagation.load_state_dict(light_prop_new_state_dict)
+
+
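A minimal end-to-end sketch (not from the source; the import path, the input shapes, the [0, 1] hologram normalization, and the commented weight/key-mapping file names are assumptions — in practice load_weights would point at the released files):

import torch
from odak.learn.wave.models import focal_surface_light_propagation   # module path assumed from the source location note

device = torch.device('cpu')
model = focal_surface_light_propagation(device = device)
# model.load_weights('model_weights.pt', 'key_mappings.json')        # hypothetical file names
focal_surface = torch.rand(1, 1, 512, 512, device = device)          # hypothetical focal surface
hologram = torch.rand(1, 3, 512, 512, device = device)               # hypothetical phase-only hologram in [0, 1]
reconstruction = model(focal_surface, hologram)                      # reconstructed intensity, [1, 3, 512, 512]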

+ __init__(depth=3, dimensions=8, input_channels=6, out_channels=6, kernel_size=3, bias=True, device=torch.device('cpu'), activation=torch.nn.LeakyReLU(0.2, inplace=True)) + +

+ + +
+ +

Initializes the focal surface light propagation model.

+ + +

Parameters:

  • depth – Number of downsampling and upsampling layers.
  • dimensions – Number of dimensions/features in the model.
  • input_channels – Number of input channels.
  • out_channels – Number of output channels.
  • kernel_size – Size of the convolution kernel.
  • bias – If True, allows convolutional layers to learn a bias term.
  • device – Default device is CPU.
  • activation – Activation function (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
+ +
+ Source code in odak/learn/wave/models.py +
def __init__(
+             self,
+             depth = 3,
+             dimensions = 8,
+             input_channels = 6,
+             out_channels = 6,
+             kernel_size = 3,
+             bias = True,
+             device = torch.device('cpu'),
+             activation = torch.nn.LeakyReLU(0.2, inplace = True)
+            ):
+    """
+    Initializes the focal surface light propagation model.
+
+    Parameters
+    ----------
+    depth             : int
+                        Number of downsampling and upsampling layers.
+    dimensions        : int
+                        Number of dimensions/features in the model.
+    input_channels    : int
+                        Number of input channels.
+    out_channels      : int
+                        Number of output channels.
+    kernel_size       : int
+                        Size of the convolution kernel.
+    bias              : bool
+                        If True, allows convolutional layers to learn a bias term.
+    device            : torch.device
+                        Default device is CPU.
+    activation        : torch.nn.Module
+                        Activation function (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
+    """
+    super().__init__()
+    self.depth = depth
+    self.device = device
+    self.sv_kernel_generation = spatially_varying_kernel_generation_model(
+        depth = depth,
+        dimensions = dimensions,
+        input_channels = input_channels + 1,  # +1 to account for an extra channel
+        kernel_size = kernel_size,
+        bias = bias,
+        activation = activation
+    )
+    self.light_propagation = spatially_adaptive_unet(
+        depth = depth,
+        dimensions = dimensions,
+        input_channels = input_channels,
+        out_channels = out_channels,
+        kernel_size = kernel_size,
+        bias = bias,
+        activation = activation
+    )
+
+
+
+ +
+ +
+ + +

+ forward(focal_surface, phase_only_hologram) + +

+ + +
+ +

Forward pass through the model.

+ + +

Parameters:

  • focal_surface – Input focal surface.
  • phase_only_hologram – Input phase-only hologram.

Returns:

  • result (Tensor) – Output tensor after light propagation.
+ +
+ Source code in odak/learn/wave/models.py +
def forward(self, focal_surface, phase_only_hologram):
+    """
+    Forward pass through the model.
+
+    Parameters
+    ----------
+    focal_surface         : torch.Tensor
+                            Input focal surface.
+    phase_only_hologram   : torch.Tensor
+                            Input phase-only hologram.
+
+    Returns
+    ----------
+    result                : torch.Tensor
+                            Output tensor after light propagation.
+    """
+    input_field = self.generate_input_field(phase_only_hologram)
+    sv_kernel = self.sv_kernel_generation(focal_surface, input_field)
+    output_field = self.light_propagation(sv_kernel, input_field)
+    final = (output_field[:, 0:3, :, :] + 1j * output_field[:, 3:6, :, :])
+    result = calculate_amplitude(final) ** 2
+    return result
+
+
+
+ +
+ +
+ + +

+ generate_input_field(phase_only_hologram) + +

+ + +
+ +

Generates an input field by combining the real and imaginary parts.

+ + +

Parameters:

  • phase_only_hologram – Input phase-only hologram.

Returns:

  • input_field (Tensor) – Concatenated real and imaginary parts of the complex field.
+ +
+ Source code in odak/learn/wave/models.py +
def generate_input_field(self, phase_only_hologram):
+    """
+    Generates an input field by combining the real and imaginary parts.
+
+    Parameters
+    ----------
+    phase_only_hologram   : torch.Tensor
+                            Input phase-only hologram.
+
+    Returns
+    ----------
+    input_field           : torch.Tensor
+                            Concatenated real and imaginary parts of the complex field.
+    """
+    [b, c, h, w] = phase_only_hologram.size()
+    input_phase = phase_only_hologram * 2 * np.pi
+    hologram_amplitude = torch.ones(b, c, h, w, requires_grad = False).to(self.device)
+    field = generate_complex_field(hologram_amplitude, input_phase)
+    input_field = torch.cat((field.real, field.imag), dim = 1)
+    return input_field
+
+
+
+ +
+ +
+ + +

+ load_weights(weight_filename, key_mapping_filename) + +

+ + +
+ +

Function to load weights for this multi-layer perceptron from a file.

+ + +

Parameters:

  • weight_filename – Path to the old model's weight file.
  • key_mapping_filename (str) – Path to the JSON file containing the key mappings.
+ +
+ Source code in odak/learn/wave/models.py +
def load_weights(self, weight_filename, key_mapping_filename):
+    """
+    Function to load weights for this multi-layer perceptron from a file.
+
+    Parameters
+    ----------
+    weight_filename      : str
+                           Path to the old model's weight file.
+    key_mapping_filename : str
+                           Path to the JSON file containing the key mappings.
+    """
+    # Load old model weights
+    old_model_weights = torch.load(weight_filename, map_location = self.device)
+
+    # Load key mappings from JSON file
+    with open(key_mapping_filename, 'r') as json_file:
+        key_mappings = json.load(json_file)
+
+    # Extract the key mappings for sv_kernel_generation and light_prop
+    sv_kernel_generation_key_mapping = key_mappings['sv_kernel_generation_key_mapping']
+    light_prop_key_mapping = key_mappings['light_prop_key_mapping']
+
+    # Initialize new state dicts
+    sv_kernel_generation_new_state_dict = {}
+    light_prop_new_state_dict = {}
+
+    # Map and load sv_kernel_generation_model weights
+    for old_key, value in old_model_weights.items():
+        if old_key in sv_kernel_generation_key_mapping:
+            # Map the old key to the new key
+            new_key = sv_kernel_generation_key_mapping[old_key]
+            sv_kernel_generation_new_state_dict[new_key] = value
+
+    self.sv_kernel_generation.to(self.device)
+    self.sv_kernel_generation.load_state_dict(sv_kernel_generation_new_state_dict)
+
+    # Map and load light_prop model weights
+    for old_key, value in old_model_weights.items():
+        if old_key in light_prop_key_mapping:
+            # Map the old key to the new key
+            new_key = light_prop_key_mapping[old_key]
+            light_prop_new_state_dict[new_key] = value
+    self.light_propagation.to(self.device)
+    self.light_propagation.load_state_dict(light_prop_new_state_dict)
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ global_feature_module + + +

+ + +
+

+ Bases: Module

+ + +

A global feature layer that processes global features from input channels and applies them to another input tensor via learned transformations.

+ + + + + + +
+ Source code in odak/learn/models/components.py +
class global_feature_module(torch.nn.Module):
+    """
+    A global feature layer that processes global features from input channels and
+    applies them to another input tensor via learned transformations.
+    """
+    def __init__(
+                 self,
+                 input_channels,
+                 mid_channels,
+                 output_channels,
+                 kernel_size,
+                 bias = False,
+                 normalization = False,
+                 activation = torch.nn.ReLU()
+                ):
+        """
+        A global feature layer.
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        mid_channels  : int
+                          Number of mid channels.
+        output_channels : int
+                          Number of output channels.
+        kernel_size     : int
+                          Kernel size.
+        bias            : bool
+                          Set to True to let convolutional layers have bias term.
+        normalization   : bool
+                          If True, adds a Batch Normalization layer after the convolutional layer.
+        activation      : torch.nn
+                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+        """
+        super().__init__()
+        self.transformations_1 = global_transformations(input_channels, output_channels)
+        self.global_features_1 = double_convolution(
+                                                    input_channels = input_channels,
+                                                    mid_channels = mid_channels,
+                                                    output_channels = output_channels,
+                                                    kernel_size = kernel_size,
+                                                    bias = bias,
+                                                    normalization = normalization,
+                                                    activation = activation
+                                                   )
+        self.global_features_2 = double_convolution(
+                                                    input_channels = input_channels,
+                                                    mid_channels = mid_channels,
+                                                    output_channels = output_channels,
+                                                    kernel_size = kernel_size,
+                                                    bias = bias,
+                                                    normalization = normalization,
+                                                    activation = activation
+                                                   )
+        self.transformations_2 = global_transformations(input_channels, output_channels)
+
+
+    def forward(self, x1, x2):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x1             : torch.tensor
+                         First input data.
+        x2             : torch.tensor
+                         Second input data.
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Estimated output.
+        """
+        global_tensor_1 = self.transformations_1(x1, x2)
+        y1 = self.global_features_1(global_tensor_1)
+        y2 = self.global_features_2(y1)
+        global_tensor_2 = self.transformations_2(y1, y2)
+        return global_tensor_2
+
+
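A minimal usage sketch, assuming the class is importable from odak.learn.models.components as the source path above suggests; shapes and channel counts are illustrative, with input and output channels kept equal so the internal blocks chain correctly:

import torch
from odak.learn.models.components import global_feature_module

# Illustrative channel counts and resolution.
layer = global_feature_module(
                              input_channels = 8,
                              mid_channels = 16,
                              output_channels = 8,
                              kernel_size = 3
                             )
x1 = torch.randn(1, 8, 32, 32)
x2 = torch.randn(1, 8, 32, 32)
y = layer(x1, x2)
print(y.shape)  # torch.Size([1, 8, 32, 32])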

+ __init__(input_channels, mid_channels, output_channels, kernel_size, bias=False, normalization=False, activation=torch.nn.ReLU()) + +

+ + +
+ +

A global feature layer.

+ + +

Parameters:

  • input_channels – Number of input channels.
  • mid_channels – Number of mid channels.
  • output_channels (int) – Number of output channels.
  • kernel_size – Kernel size.
  • bias – Set to True to let convolutional layers have bias term.
  • normalization – If True, adds a Batch Normalization layer after the convolutional layer.
  • activation – Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             input_channels,
+             mid_channels,
+             output_channels,
+             kernel_size,
+             bias = False,
+             normalization = False,
+             activation = torch.nn.ReLU()
+            ):
+    """
+    A global feature layer.
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    mid_channels  : int
+                      Number of mid channels.
+    output_channels : int
+                      Number of output channels.
+    kernel_size     : int
+                      Kernel size.
+    bias            : bool
+                      Set to True to let convolutional layers have bias term.
+    normalization   : bool
+                      If True, adds a Batch Normalization layer after the convolutional layer.
+    activation      : torch.nn
+                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+    """
+    super().__init__()
+    self.transformations_1 = global_transformations(input_channels, output_channels)
+    self.global_features_1 = double_convolution(
+                                                input_channels = input_channels,
+                                                mid_channels = mid_channels,
+                                                output_channels = output_channels,
+                                                kernel_size = kernel_size,
+                                                bias = bias,
+                                                normalization = normalization,
+                                                activation = activation
+                                               )
+    self.global_features_2 = double_convolution(
+                                                input_channels = input_channels,
+                                                mid_channels = mid_channels,
+                                                output_channels = output_channels,
+                                                kernel_size = kernel_size,
+                                                bias = bias,
+                                                normalization = normalization,
+                                                activation = activation
+                                               )
+    self.transformations_2 = global_transformations(input_channels, output_channels)
+
+
+
+ +
+ +
+ + +

+ forward(x1, x2) + +

+ + +
+ +

Forward model.

+ + +

Parameters:

  • x1 – First input data.
  • x2 – Second input data.

Returns:

  • result (tensor) – Estimated output.
+ Source code in odak/learn/models/components.py +
def forward(self, x1, x2):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x1             : torch.tensor
+                     First input data.
+    x2             : torch.tensor
+                     Second input data.
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Estimated output.
+    """
+    global_tensor_1 = self.transformations_1(x1, x2)
+    y1 = self.global_features_1(global_tensor_1)
+    y2 = self.global_features_2(y1)
+    global_tensor_2 = self.transformations_2(y1, y2)
+    return global_tensor_2
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ global_transformations + + +

+ + +
+

+ Bases: Module

+ + +

A global feature layer that processes global features from input channels and applies learned transformations to another input tensor.

+

This implementation is adapted from RSGUnet: https://github.com/MTLab/rsgunet_image_enhance.

+

Reference: J. Huang, P. Zhu, M. Geng et al. "Range Scaling Global U-Net for Perceptual Image Enhancement on Mobile Devices."

+ + + + + + +
+ Source code in odak/learn/models/components.py +
class global_transformations(torch.nn.Module):
+    """
+    A global feature layer that processes global features from input channels and
+    applies learned transformations to another input tensor.
+
+    This implementation is adapted from RSGUnet:
+    https://github.com/MTLab/rsgunet_image_enhance.
+
+    Reference:
+    J. Huang, P. Zhu, M. Geng et al. "Range Scaling Global U-Net for Perceptual Image Enhancement on Mobile Devices."
+    """
+    def __init__(
+                 self,
+                 input_channels,
+                 output_channels
+                ):
+        """
+        A global feature layer.
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        output_channels : int
+                          Number of output channels.
+        """
+        super().__init__()
+        self.global_feature_1 = torch.nn.Sequential(
+            torch.nn.Linear(input_channels, output_channels),
+            torch.nn.LeakyReLU(0.2, inplace = True),
+        )
+        self.global_feature_2 = torch.nn.Sequential(
+            torch.nn.Linear(output_channels, output_channels),
+            torch.nn.LeakyReLU(0.2, inplace = True)
+        )
+
+
+    def forward(self, x1, x2):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x1             : torch.tensor
+                         First input data.
+        x2             : torch.tensor
+                         Second input data.
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Estimated output.
+        """
+        y = torch.mean(x2, dim = (2, 3))
+        y1 = self.global_feature_1(y)
+        y2 = self.global_feature_2(y1)
+        y1 = y1.unsqueeze(2).unsqueeze(3)
+        y2 = y2.unsqueeze(2).unsqueeze(3)
+        result = x1 * y1 + y2
+        return result
+
+
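In effect, x2 is reduced to per-channel statistics that scale and shift x1. A minimal sketch, assuming the class is importable from odak.learn.models.components and using equal input and output channel counts:

import torch
from odak.learn.models.components import global_transformations

layer = global_transformations(input_channels = 8, output_channels = 8)
x1 = torch.randn(1, 8, 32, 32)
x2 = torch.randn(1, 8, 32, 32)
y = layer(x1, x2)  # x1 scaled and shifted by features pooled from x2
print(y.shape)     # torch.Size([1, 8, 32, 32])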

+ __init__(input_channels, output_channels) + +

+ + +
+ +

A global feature layer.

+ + +

Parameters:

  • input_channels – Number of input channels.
  • output_channels (int) – Number of output channels.
+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             input_channels,
+             output_channels
+            ):
+    """
+    A global feature layer.
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    output_channels : int
+                      Number of output channels.
+    """
+    super().__init__()
+    self.global_feature_1 = torch.nn.Sequential(
+        torch.nn.Linear(input_channels, output_channels),
+        torch.nn.LeakyReLU(0.2, inplace = True),
+    )
+    self.global_feature_2 = torch.nn.Sequential(
+        torch.nn.Linear(output_channels, output_channels),
+        torch.nn.LeakyReLU(0.2, inplace = True)
+    )
+
+
+
+ +
+ +
+ + +

+ forward(x1, x2) + +

+ + +
+ +

Forward model.

+ + +

Parameters:

  • x1 – First input data.
  • x2 – Second input data.

Returns:

  • result (tensor) – Estimated output.
+ Source code in odak/learn/models/components.py +
def forward(self, x1, x2):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x1             : torch.tensor
+                     First input data.
+    x2             : torch.tensor
+                     Second input data.
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Estimated output.
+    """
+    y = torch.mean(x2, dim = (2, 3))
+    y1 = self.global_feature_1(y)
+    y2 = self.global_feature_2(y1)
+    y1 = y1.unsqueeze(2).unsqueeze(3)
+    y2 = y2.unsqueeze(2).unsqueeze(3)
+    result = x1 * y1 + y2
+    return result
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ holobeam_multiholo + + +

+ + +
+

+ Bases: Module

+ + +

The learned holography model used in the paper, Akşit, Kaan, and Yuta Itoh. "HoloBeam: Paper-Thin Near-Eye Displays." In 2023 IEEE Conference Virtual Reality and 3D User Interfaces (VR), pp. 581-591. IEEE, 2023.

+ + +

Parameters:

  • n_input – Number of channels in the input.
  • n_hidden – Number of channels in the hidden layers.
  • n_output – Number of channels in the output layer.
  • device – Default device is CPU.
  • reduction – Reduction used for torch.nn.MSELoss and torch.nn.L1Loss. The default is 'sum'.
+ Source code in odak/learn/wave/models.py +
class holobeam_multiholo(torch.nn.Module):
+    """
+    The learned holography model used in the paper, Akşit, Kaan, and Yuta Itoh. "HoloBeam: Paper-Thin Near-Eye Displays." In 2023 IEEE Conference Virtual Reality and 3D User Interfaces (VR), pp. 581-591. IEEE, 2023.
+
+
+    Parameters
+    ----------
+    n_input           : int
+                        Number of channels in the input.
+    n_hidden          : int
+                        Number of channels in the hidden layers.
+    n_output          : int
+                        Number of channels in the output layer.
+    device            : torch.device
+                        Default device is CPU.
+    reduction         : str
+                        Reduction used for torch.nn.MSELoss and torch.nn.L1Loss. The default is 'sum'.
+    """
+    def __init__(
+                 self,
+                 n_input = 1,
+                 n_hidden = 16,
+                 n_output = 2,
+                 device = torch.device('cpu'),
+                 reduction = 'sum'
+                ):
+        super(holobeam_multiholo, self).__init__()
+        torch.random.seed()
+        self.device = device
+        self.reduction = reduction
+        self.l2 = torch.nn.MSELoss(reduction = self.reduction)
+        self.l1 = torch.nn.L1Loss(reduction = self.reduction)
+        self.n_input = n_input
+        self.n_hidden = n_hidden
+        self.n_output = n_output
+        self.network = unet(
+                            dimensions = self.n_hidden,
+                            input_channels = self.n_input,
+                            output_channels = self.n_output
+                           ).to(self.device)
+
+
+    def forward(self, x, test = False):
+        """
+        Internal function representing the forward model.
+        """
+        if test:
+            torch.no_grad()  # note: has no effect as a bare call; gradients stay enabled unless used as a context manager
+        y = self.network.forward(x) 
+        phase_low = y[:, 0].unsqueeze(1)
+        phase_high = y[:, 1].unsqueeze(1)
+        phase_only = torch.zeros_like(phase_low)
+        phase_only[:, :, 0::2, 0::2] = phase_low[:, :,  0::2, 0::2]
+        phase_only[:, :, 1::2, 1::2] = phase_low[:, :, 1::2, 1::2]
+        phase_only[:, :, 0::2, 1::2] = phase_high[:, :, 0::2, 1::2]
+        phase_only[:, :, 1::2, 0::2] = phase_high[:, :, 1::2, 0::2]
+        return phase_only
+
+
+    def evaluate(self, input_data, ground_truth, weights = [1., 0.1]):
+        """
+        Internal function for evaluating.
+        """
+        loss = weights[0] * self.l2(input_data, ground_truth) + weights[1] * self.l1(input_data, ground_truth)
+        return loss
+
+
+    def fit(self, dataloader, number_of_epochs = 100, learning_rate = 1e-5, directory = './output', save_at_every = 100):
+        """
+        Function to train the weights of the multi layer perceptron.
+
+        Parameters
+        ----------
+        dataloader       : torch.utils.data.DataLoader
+                           Data loader.
+        number_of_epochs : int
+                           Number of epochs.
+        learning_rate    : float
+                           Learning rate of the optimizer.
+        directory        : str
+                           Output directory.
+        save_at_every    : int
+                           Save the model at every given epoch count.
+        """
+        t_epoch = tqdm(range(number_of_epochs), leave=False, dynamic_ncols = True)
+        self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
+        for i in t_epoch:
+            epoch_loss = 0.
+            t_data = tqdm(dataloader, leave=False, dynamic_ncols = True)
+            for j, data in enumerate(t_data):
+                self.optimizer.zero_grad()
+                images, holograms = data
+                estimates = self.forward(images)
+                loss = self.evaluate(estimates, holograms)
+                loss.backward(retain_graph=True)
+                self.optimizer.step()
+                description = 'Loss:{:.4f}'.format(loss.item())
+                t_data.set_description(description)
+                epoch_loss += float(loss.item()) / dataloader.__len__()
+            description = 'Epoch Loss:{:.4f}'.format(epoch_loss)
+            t_epoch.set_description(description)
+            if i % save_at_every == 0:
+                self.save_weights(filename='{}/weights_{:04d}.pt'.format(directory, i))
+        self.save_weights(filename='{}/weights.pt'.format(directory))
+        print(description)
+
+
+    def save_weights(self, filename = './weights.pt'):
+        """
+        Function to save the current weights of the multi layer perceptron to a file.
+        Parameters
+        ----------
+        filename        : str
+                          Filename.
+        """
+        torch.save(self.network.state_dict(), os.path.expanduser(filename))
+
+
+    def load_weights(self, filename = './weights.pt'):
+        """
+        Function to load weights for this multi layer perceptron from a file.
+        Parameters
+        ----------
+        filename        : str
+                          Filename.
+        """
+        self.network.load_state_dict(torch.load(os.path.expanduser(filename)))
+        self.network.eval()
+
+
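A minimal inference sketch, assuming the class is importable from odak.learn.wave.models as the source path above suggests; the resolution is illustrative and should be compatible with the internal U-Net's downsampling:

import torch
from odak.learn.wave.models import holobeam_multiholo

model = holobeam_multiholo(n_input = 1, n_hidden = 16, n_output = 2)
images = torch.randn(1, 1, 256, 256)
with torch.no_grad():
    phase_only = model(images)  # low and high phases interleaved in a checkerboard pattern
print(phase_only.shape)         # torch.Size([1, 1, 256, 256])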

+ evaluate(input_data, ground_truth, weights=[1.0, 0.1]) + +

+ + +
+ +

Internal function for evaluating.

+ +
+ Source code in odak/learn/wave/models.py +
def evaluate(self, input_data, ground_truth, weights = [1., 0.1]):
+    """
+    Internal function for evaluating.
+    """
+    loss = weights[0] * self.l2(input_data, ground_truth) + weights[1] * self.l1(input_data, ground_truth)
+    return loss
+
+
+
+ +
+ +
+ + +

+ fit(dataloader, number_of_epochs=100, learning_rate=1e-05, directory='./output', save_at_every=100) + +

+ + +
+ +

Function to train the weights of the multi layer perceptron.

+ + +

Parameters:

  • dataloader – Data loader.
  • number_of_epochs (int, default: 100) – Number of epochs.
  • learning_rate – Learning rate of the optimizer.
  • directory – Output directory.
  • save_at_every – Save the model at every given epoch count.
+ Source code in odak/learn/wave/models.py +
def fit(self, dataloader, number_of_epochs = 100, learning_rate = 1e-5, directory = './output', save_at_every = 100):
+    """
+    Function to train the weights of the multi layer perceptron.
+
+    Parameters
+    ----------
+    dataloader       : torch.utils.data.DataLoader
+                       Data loader.
+    number_of_epochs : int
+                       Number of epochs.
+    learning_rate    : float
+                       Learning rate of the optimizer.
+    directory        : str
+                       Output directory.
+    save_at_every    : int
+                       Save the model at every given epoch count.
+    """
+    t_epoch = tqdm(range(number_of_epochs), leave=False, dynamic_ncols = True)
+    self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
+    for i in t_epoch:
+        epoch_loss = 0.
+        t_data = tqdm(dataloader, leave=False, dynamic_ncols = True)
+        for j, data in enumerate(t_data):
+            self.optimizer.zero_grad()
+            images, holograms = data
+            estimates = self.forward(images)
+            loss = self.evaluate(estimates, holograms)
+            loss.backward(retain_graph=True)
+            self.optimizer.step()
+            description = 'Loss:{:.4f}'.format(loss.item())
+            t_data.set_description(description)
+            epoch_loss += float(loss.item()) / dataloader.__len__()
+        description = 'Epoch Loss:{:.4f}'.format(epoch_loss)
+        t_epoch.set_description(description)
+        if i % save_at_every == 0:
+            self.save_weights(filename='{}/weights_{:04d}.pt'.format(directory, i))
+    self.save_weights(filename='{}/weights.pt'.format(directory))
+    print(description)
+
+
+
+ +
+ +
+ + +

+ forward(x, test=False) + +

+ + +
+ +

Internal function representing the forward model.

+ +
+ Source code in odak/learn/wave/models.py +
def forward(self, x, test = False):
+    """
+    Internal function representing the forward model.
+    """
+    if test:
+        torch.no_grad()  # note: has no effect as a bare call; gradients stay enabled unless used as a context manager
+    y = self.network.forward(x) 
+    phase_low = y[:, 0].unsqueeze(1)
+    phase_high = y[:, 1].unsqueeze(1)
+    phase_only = torch.zeros_like(phase_low)
+    phase_only[:, :, 0::2, 0::2] = phase_low[:, :,  0::2, 0::2]
+    phase_only[:, :, 1::2, 1::2] = phase_low[:, :, 1::2, 1::2]
+    phase_only[:, :, 0::2, 1::2] = phase_high[:, :, 0::2, 1::2]
+    phase_only[:, :, 1::2, 0::2] = phase_high[:, :, 1::2, 0::2]
+    return phase_only
+
+
+
+ +
+ +
+ + +

+ load_weights(filename='./weights.pt') + +

+ + +
+ +

Function to load weights for this multi layer perceptron from a file.

+ + +

Parameters:

  • filename – Filename.
+ Source code in odak/learn/wave/models.py +
def load_weights(self, filename = './weights.pt'):
+    """
+    Function to load weights for this multi layer perceptron from a file.
+    Parameters
+    ----------
+    filename        : str
+                      Filename.
+    """
+    self.network.load_state_dict(torch.load(os.path.expanduser(filename)))
+    self.network.eval()
+
+
+
+ +
+ +
+ + +

+ save_weights(filename='./weights.pt') + +

+ + +
+ +

Function to save the current weights of the multi layer perceptron to a file.

+ + +

Parameters:

  • filename – Filename.
+ Source code in odak/learn/wave/models.py +
def save_weights(self, filename = './weights.pt'):
+    """
+    Function to save the current weights of the multi layer perceptron to a file.
+    Parameters
+    ----------
+    filename        : str
+                      Filename.
+    """
+    torch.save(self.network.state_dict(), os.path.expanduser(filename))
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ multi_layer_perceptron + + +

+ + +
+

+ Bases: Module

+ + +

A multi-layer perceptron model.

+ + + + + + +
+ Source code in odak/learn/models/models.py +
class multi_layer_perceptron(torch.nn.Module):
+    """
+    A multi-layer perceptron model.
+    """
+
+    def __init__(self,
+                 dimensions,
+                 activation = torch.nn.ReLU(),
+                 bias = False,
+                 model_type = 'conventional',
+                 siren_multiplier = 1.,
+                 input_multiplier = None
+                ):
+        """
+        Parameters
+        ----------
+        dimensions        : list
+                            List of integers representing the dimensions of each layer (e.g., [2, 10, 1], where the first layer has two channels and last one has one channel.).
+        activation        : torch.nn
+                            Nonlinear activation function.
+                            Default is `torch.nn.ReLU()`.
+        bias              : bool
+                            If set to True, linear layers will include biases.
+        siren_multiplier  : float
+                            When using `SIREN` model type, this parameter functions as a hyperparameter.
+                            The original SIREN work uses 30.
+                            You can bypass this parameter by providing inputs that are not normalized and larger than one.
+        input_multiplier  : float
+                            Initial value of the input multiplier before the very first layer.
+        model_type        : str
+                            Model type: `conventional`, `swish`, `SIREN`, `FILM SIREN`, `Gaussian`.
+                            `conventional` refers to a standard multi layer perceptron.
+                            For `SIREN,` see: Sitzmann, Vincent, et al. "Implicit neural representations with periodic activation functions." Advances in neural information processing systems 33 (2020): 7462-7473.
+                            For `Swish,` see: Ramachandran, Prajit, Barret Zoph, and Quoc V. Le. "Searching for activation functions." arXiv preprint arXiv:1710.05941 (2017). 
+                            For `FILM SIREN,` see: Chan, Eric R., et al. "pi-gan: Periodic implicit generative adversarial networks for 3d-aware image synthesis." Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2021.
+                            For `Gaussian,` see: Ramasinghe, Sameera, and Simon Lucey. "Beyond periodicity: Towards a unifying framework for activations in coordinate-mlps." In European Conference on Computer Vision, pp. 142-158. Cham: Springer Nature Switzerland, 2022.
+        """
+        super(multi_layer_perceptron, self).__init__()
+        self.activation = activation
+        self.bias = bias
+        self.model_type = model_type
+        self.layers = torch.nn.ModuleList()
+        self.siren_multiplier = siren_multiplier
+        self.dimensions = dimensions
+        for i in range(len(self.dimensions) - 1):
+            self.layers.append(torch.nn.Linear(self.dimensions[i], self.dimensions[i + 1], bias = self.bias))
+        if not isinstance(input_multiplier, type(None)):
+            self.input_multiplier = torch.nn.ParameterList()
+            self.input_multiplier.append(torch.nn.Parameter(torch.ones(1, self.dimensions[0]) * input_multiplier))
+        if self.model_type == 'FILM SIREN':
+            self.alpha = torch.nn.ParameterList()
+            for j in self.dimensions[1:-1]:
+                self.alpha.append(torch.nn.Parameter(torch.randn(2, 1, j)))
+        if self.model_type == 'Gaussian':
+            self.alpha = torch.nn.ParameterList()
+            for j in self.dimensions[1:-1]:
+                self.alpha.append(torch.nn.Parameter(torch.randn(1, 1, j)))
+
+
+    def forward(self, x):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x             : torch.tensor
+                        Input data.
+
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Estimated output.      
+        """
+        if hasattr(self, 'input_multiplier'):
+            result = x * self.input_multiplier[0]
+        else:
+            result = x
+        for layer_id, layer in enumerate(self.layers[:-1]):
+            result = layer(result)
+            if self.model_type == 'conventional':
+                result = self.activation(result)
+            elif self.model_type == 'swish':
+                result = swish(result)
+            elif self.model_type == 'SIREN':
+                result = torch.sin(result * self.siren_multiplier)
+            elif self.model_type == 'FILM SIREN':
+                result = torch.sin(self.alpha[layer_id][0] * result + self.alpha[layer_id][1])
+            elif self.model_type == 'Gaussian': 
+                result = gaussian(result, self.alpha[layer_id][0])
+        result = self.layers[-1](result)
+        return result
+
+
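A minimal coordinate-network sketch, assuming the class is importable from odak.learn.models.models as the source path above suggests; the layer sizes and SIREN settings are illustrative:

import torch
from odak.learn.models.models import multi_layer_perceptron

model = multi_layer_perceptron(
                               dimensions = [2, 64, 64, 1],
                               model_type = 'SIREN',
                               siren_multiplier = 30.
                              )
coordinates = torch.rand(1000, 2)   # e.g., normalized (x, y) positions
values = model(coordinates)
print(values.shape)                 # torch.Size([1000, 1])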

+ __init__(dimensions, activation=torch.nn.ReLU(), bias=False, model_type='conventional', siren_multiplier=1.0, input_multiplier=None) + +

+ + +
+ + + +

Parameters:

  • dimensions – List of integers representing the dimensions of each layer (e.g., [2, 10, 1], where the first layer has two channels and the last one has one channel).
  • activation – Nonlinear activation function. Default is `torch.nn.ReLU()`.
  • bias – If set to True, linear layers will include biases.
  • siren_multiplier – When using the `SIREN` model type, this parameter functions as a hyperparameter. The original SIREN work uses 30. You can bypass this parameter by providing inputs that are not normalized and larger than one.
  • input_multiplier – Initial value of the input multiplier before the very first layer.
  • model_type – Model type: `conventional`, `swish`, `SIREN`, `FILM SIREN`, `Gaussian`. `conventional` refers to a standard multi layer perceptron. For `SIREN`, see: Sitzmann, Vincent, et al. "Implicit neural representations with periodic activation functions." Advances in neural information processing systems 33 (2020): 7462-7473. For `Swish`, see: Ramachandran, Prajit, Barret Zoph, and Quoc V. Le. "Searching for activation functions." arXiv preprint arXiv:1710.05941 (2017). For `FILM SIREN`, see: Chan, Eric R., et al. "pi-gan: Periodic implicit generative adversarial networks for 3d-aware image synthesis." Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2021. For `Gaussian`, see: Ramasinghe, Sameera, and Simon Lucey. "Beyond periodicity: Towards a unifying framework for activations in coordinate-mlps." In European Conference on Computer Vision, pp. 142-158. Cham: Springer Nature Switzerland, 2022.
+ Source code in odak/learn/models/models.py +
def __init__(self,
+             dimensions,
+             activation = torch.nn.ReLU(),
+             bias = False,
+             model_type = 'conventional',
+             siren_multiplier = 1.,
+             input_multiplier = None
+            ):
+    """
+    Parameters
+    ----------
+    dimensions        : list
+                        List of integers representing the dimensions of each layer (e.g., [2, 10, 1], where the first layer has two channels and last one has one channel.).
+    activation        : torch.nn
+                        Nonlinear activation function.
+                        Default is `torch.nn.ReLU()`.
+    bias              : bool
+                        If set to True, linear layers will include biases.
+    siren_multiplier  : float
+                        When using `SIREN` model type, this parameter functions as a hyperparameter.
+                        The original SIREN work uses 30.
+                        You can bypass this parameter by providing inputs that are not normalized and larger than one.
+    input_multiplier  : float
+                        Initial value of the input multiplier before the very first layer.
+    model_type        : str
+                        Model type: `conventional`, `swish`, `SIREN`, `FILM SIREN`, `Gaussian`.
+                        `conventional` refers to a standard multi layer perceptron.
+                        For `SIREN,` see: Sitzmann, Vincent, et al. "Implicit neural representations with periodic activation functions." Advances in neural information processing systems 33 (2020): 7462-7473.
+                        For `Swish,` see: Ramachandran, Prajit, Barret Zoph, and Quoc V. Le. "Searching for activation functions." arXiv preprint arXiv:1710.05941 (2017). 
+                        For `FILM SIREN,` see: Chan, Eric R., et al. "pi-gan: Periodic implicit generative adversarial networks for 3d-aware image synthesis." Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2021.
+                        For `Gaussian,` see: Ramasinghe, Sameera, and Simon Lucey. "Beyond periodicity: Towards a unifying framework for activations in coordinate-mlps." In European Conference on Computer Vision, pp. 142-158. Cham: Springer Nature Switzerland, 2022.
+    """
+    super(multi_layer_perceptron, self).__init__()
+    self.activation = activation
+    self.bias = bias
+    self.model_type = model_type
+    self.layers = torch.nn.ModuleList()
+    self.siren_multiplier = siren_multiplier
+    self.dimensions = dimensions
+    for i in range(len(self.dimensions) - 1):
+        self.layers.append(torch.nn.Linear(self.dimensions[i], self.dimensions[i + 1], bias = self.bias))
+    if not isinstance(input_multiplier, type(None)):
+        self.input_multiplier = torch.nn.ParameterList()
+        self.input_multiplier.append(torch.nn.Parameter(torch.ones(1, self.dimensions[0]) * input_multiplier))
+    if self.model_type == 'FILM SIREN':
+        self.alpha = torch.nn.ParameterList()
+        for j in self.dimensions[1:-1]:
+            self.alpha.append(torch.nn.Parameter(torch.randn(2, 1, j)))
+    if self.model_type == 'Gaussian':
+        self.alpha = torch.nn.ParameterList()
+        for j in self.dimensions[1:-1]:
+            self.alpha.append(torch.nn.Parameter(torch.randn(1, 1, j)))
+
+
+
+ +
+ +
+ + +

+ forward(x) + +

+ + +
+ +

Forward model.

+ + +

Parameters:

  • x – Input data.

Returns:

  • result (tensor) – Estimated output.
+ Source code in odak/learn/models/models.py +
def forward(self, x):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x             : torch.tensor
+                    Input data.
+
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Estimated output.      
+    """
+    if hasattr(self, 'input_multiplier'):
+        result = x * self.input_multiplier[0]
+    else:
+        result = x
+    for layer_id, layer in enumerate(self.layers[:-1]):
+        result = layer(result)
+        if self.model_type == 'conventional':
+            result = self.activation(result)
+        elif self.model_type == 'swish':
+            result = swish(result)
+        elif self.model_type == 'SIREN':
+            result = torch.sin(result * self.siren_multiplier)
+        elif self.model_type == 'FILM SIREN':
+            result = torch.sin(self.alpha[layer_id][0] * result + self.alpha[layer_id][1])
+        elif self.model_type == 'Gaussian': 
+            result = gaussian(result, self.alpha[layer_id][0])
+    result = self.layers[-1](result)
+    return result
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ non_local_layer + + +

+ + +
+

+ Bases: Module

+ + +

Self-Attention Layer [zi = Wzyi + xi] (non-local block : ref https://arxiv.org/abs/1711.07971)

+ + + + + + +
+ Source code in odak/learn/models/components.py +
class non_local_layer(torch.nn.Module):
+    """
+    Self-Attention Layer [zi = Wzyi + xi] (non-local block : ref https://arxiv.org/abs/1711.07971)
+    """
+    def __init__(
+                 self,
+                 input_channels = 1024,
+                 bottleneck_channels = 512,
+                 kernel_size = 1,
+                 bias = False,
+                ):
+        """
+
+        Parameters
+        ----------
+        input_channels      : int
+                              Number of input channels.
+        bottleneck_channels : int
+                              Number of middle channels.
+        kernel_size         : int
+                              Kernel size.
+        bias                : bool 
+                              Set to True to let convolutional layers have bias term.
+        """
+        super(non_local_layer, self).__init__()
+        self.input_channels = input_channels
+        self.bottleneck_channels = bottleneck_channels
+        self.g = torch.nn.Conv2d(
+                                 self.input_channels, 
+                                 self.bottleneck_channels,
+                                 kernel_size = kernel_size,
+                                 padding = kernel_size // 2,
+                                 bias = bias
+                                )
+        self.W_z = torch.nn.Sequential(
+                                       torch.nn.Conv2d(
+                                                       self.bottleneck_channels,
+                                                       self.input_channels, 
+                                                       kernel_size = kernel_size,
+                                                       bias = bias,
+                                                       padding = kernel_size // 2
+                                                      ),
+                                       torch.nn.BatchNorm2d(self.input_channels)
+                                      )
+        torch.nn.init.constant_(self.W_z[1].weight, 0)   
+        torch.nn.init.constant_(self.W_z[1].bias, 0)
+
+
+    def forward(self, x):
+        """
+        Forward model [zi = Wzyi + xi]
+
+        Parameters
+        ----------
+        x               : torch.tensor
+                          First input data.                       
+
+
+        Returns
+        ----------
+        z               : torch.tensor
+                          Estimated output.
+        """
+        batch_size, channels, height, width = x.size()
+        theta = x.view(batch_size, channels, -1).permute(0, 2, 1)
+        phi = x.view(batch_size, channels, -1).permute(0, 2, 1)
+        g = self.g(x).view(batch_size, self.bottleneck_channels, -1).permute(0, 2, 1)
+        attn = torch.bmm(theta, phi.transpose(1, 2)) / (height * width)
+        attn = torch.nn.functional.softmax(attn, dim=-1)
+        y = torch.bmm(attn, g).permute(0, 2, 1).contiguous().view(batch_size, self.bottleneck_channels, height, width)
+        W_y = self.W_z(y)
+        z = W_y + x
+        return z
+
+
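A minimal sketch, assuming the class is importable from odak.learn.models.components; channel counts and resolution are illustrative (the attention map grows with height × width, so small feature maps are advisable):

import torch
from odak.learn.models.components import non_local_layer

layer = non_local_layer(input_channels = 64, bottleneck_channels = 32, kernel_size = 1)
x = torch.randn(1, 64, 16, 16)
z = layer(x)            # residual output, same shape as the input
print(z.shape)          # torch.Size([1, 64, 16, 16])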

+ __init__(input_channels=1024, bottleneck_channels=512, kernel_size=1, bias=False) + +

+ + +
+ + + +

Parameters:

  • input_channels – Number of input channels.
  • bottleneck_channels (int, default: 512) – Number of middle channels.
  • kernel_size – Kernel size.
  • bias – Set to True to let convolutional layers have bias term.
+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             input_channels = 1024,
+             bottleneck_channels = 512,
+             kernel_size = 1,
+             bias = False,
+            ):
+    """
+
+    Parameters
+    ----------
+    input_channels      : int
+                          Number of input channels.
+    bottleneck_channels : int
+                          Number of middle channels.
+    kernel_size         : int
+                          Kernel size.
+    bias                : bool 
+                          Set to True to let convolutional layers have bias term.
+    """
+    super(non_local_layer, self).__init__()
+    self.input_channels = input_channels
+    self.bottleneck_channels = bottleneck_channels
+    self.g = torch.nn.Conv2d(
+                             self.input_channels, 
+                             self.bottleneck_channels,
+                             kernel_size = kernel_size,
+                             padding = kernel_size // 2,
+                             bias = bias
+                            )
+    self.W_z = torch.nn.Sequential(
+                                   torch.nn.Conv2d(
+                                                   self.bottleneck_channels,
+                                                   self.input_channels, 
+                                                   kernel_size = kernel_size,
+                                                   bias = bias,
+                                                   padding = kernel_size // 2
+                                                  ),
+                                   torch.nn.BatchNorm2d(self.input_channels)
+                                  )
+    torch.nn.init.constant_(self.W_z[1].weight, 0)   
+    torch.nn.init.constant_(self.W_z[1].bias, 0)
+
+
+
+ +
+ +
+ + +

+ forward(x) + +

+ + +
+ +

Forward model [zi = Wzyi + xi]

+ + +

Parameters:

  • x – First input data.

Returns:

  • z (tensor) – Estimated output.
+ Source code in odak/learn/models/components.py +
def forward(self, x):
+    """
+    Forward model [zi = Wzyi + xi]
+
+    Parameters
+    ----------
+    x               : torch.tensor
+                      First input data.                       
+
+
+    Returns
+    ----------
+    z               : torch.tensor
+                      Estimated output.
+    """
+    batch_size, channels, height, width = x.size()
+    theta = x.view(batch_size, channels, -1).permute(0, 2, 1)
+    phi = x.view(batch_size, channels, -1).permute(0, 2, 1)
+    g = self.g(x).view(batch_size, self.bottleneck_channels, -1).permute(0, 2, 1)
+    attn = torch.bmm(theta, phi.transpose(1, 2)) / (height * width)
+    attn = torch.nn.functional.softmax(attn, dim=-1)
+    y = torch.bmm(attn, g).permute(0, 2, 1).contiguous().view(batch_size, self.bottleneck_channels, height, width)
+    W_y = self.W_z(y)
+    z = W_y + x
+    return z
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ normalization + + +

+ + +
+

+ Bases: Module

+ + +

A normalization layer.

+ + + + + + +
+ Source code in odak/learn/models/components.py +
class normalization(torch.nn.Module):
+    """
+    A normalization layer.
+    """
+    def __init__(
+                 self,
+                 dim = 1,
+                ):
+        """
+        Normalization layer.
+
+
+        Parameters
+        ----------
+        dim             : int
+                          Dimension (axis) to normalize.
+        """
+        super().__init__()
+        self.k = torch.nn.Parameter(torch.ones(1, dim, 1, 1))
+
+
+    def forward(self, x):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x             : torch.tensor
+                        Input data.
+
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Estimated output.      
+        """
+        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
+        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
+        mean = torch.mean(x, dim = 1, keepdim = True)
+        result =  (x - mean) * (var + eps).rsqrt() * self.k
+        return result 
+
+
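A minimal sketch, assuming the class is importable from odak.learn.models.components; dim should match the channel count of the input:

import torch
from odak.learn.models.components import normalization

layer = normalization(dim = 8)
x = torch.randn(1, 8, 32, 32)
y = layer(x)            # channel statistics removed per pixel, learnable gain applied
print(y.shape)          # torch.Size([1, 8, 32, 32])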

+ __init__(dim=1) + +

+ + +
+ +

Normalization layer.

+ + +

Parameters:

  • dim – Dimension (axis) to normalize.
+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             dim = 1,
+            ):
+    """
+    Normalization layer.
+
+
+    Parameters
+    ----------
+    dim             : int
+                      Dimension (axis) to normalize.
+    """
+    super().__init__()
+    self.k = torch.nn.Parameter(torch.ones(1, dim, 1, 1))
+
+
+
+ +
+ +
+ + +

+ forward(x) + +

+ + +
+ +

Forward model.

+ + +

Parameters:

  • x – Input data.

Returns:

  • result (tensor) – Estimated output.
+ Source code in odak/learn/models/components.py +
def forward(self, x):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x             : torch.tensor
+                    Input data.
+
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Estimated output.      
+    """
+    eps = 1e-5 if x.dtype == torch.float32 else 1e-3
+    var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
+    mean = torch.mean(x, dim = 1, keepdim = True)
+    result =  (x - mean) * (var + eps).rsqrt() * self.k
+    return result 
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ positional_encoder + + +

+ + +
+

+ Bases: Module

+ + +

A positional encoder module.

+ + + + + + +
+ Source code in odak/learn/models/components.py +
class positional_encoder(torch.nn.Module):
+    """
+    A positional encoder module.
+    """
+
+    def __init__(self, L):
+        """
+        A positional encoder module.
+
+        Parameters
+        ----------
+        L                   : int
+                              Positional encoding level.
+        """
+        super(positional_encoder, self).__init__()
+        self.L = L
+
+
+    def forward(self, x):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x               : torch.tensor
+                          Input data.
+
+        Returns
+        ----------
+        result          : torch.tensor
+                          Result of the forward operation
+        """
+        B, C = x.shape
+        x = x.view(B, C, 1)
+        results = [x]
+        for i in range(1, self.L + 1):
+            freq = (2 ** i) * math.pi
+            cos_x = torch.cos(freq * x)
+            sin_x = torch.sin(freq * x)
+            results.append(cos_x)
+            results.append(sin_x)
+        results = torch.cat(results, dim=2)
+        results = results.permute(0, 2, 1)
+        results = results.reshape(B, -1)
+        return results
+
+
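A minimal sketch, assuming the class is importable from odak.learn.models.components; each input feature is expanded into the original value plus L pairs of cosine and sine terms, so an input of shape (B, C) becomes (B, C * (2L + 1)):

import torch
from odak.learn.models.components import positional_encoder

encoder = positional_encoder(L = 4)
x = torch.rand(10, 2)               # e.g., normalized coordinates
encoded = encoder(x)
print(encoded.shape)                # torch.Size([10, 18]), i.e. 2 * (2 * 4 + 1)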

+ __init__(L) + +

+ + +
+ +

A positional encoder module.

+ + +

Parameters:

  • L – Positional encoding level.
+ Source code in odak/learn/models/components.py +
def __init__(self, L):
+    """
+    A positional encoder module.
+
+    Parameters
+    ----------
+    L                   : int
+                          Positional encoding level.
+    """
+    super(positional_encoder, self).__init__()
+    self.L = L
+
+
+
+ +
+ +
+ + +

+ forward(x) + +

+ + +
+ +

Forward model.

+ + +

Parameters:

  • x – Input data.

Returns:

  • result (tensor) – Result of the forward operation.
+ Source code in odak/learn/models/components.py +
def forward(self, x):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x               : torch.tensor
+                      Input data.
+
+    Returns
+    ----------
+    result          : torch.tensor
+                      Result of the forward operation
+    """
+    B, C = x.shape
+    x = x.view(B, C, 1)
+    results = [x]
+    for i in range(1, self.L + 1):
+        freq = (2 ** i) * math.pi
+        cos_x = torch.cos(freq * x)
+        sin_x = torch.sin(freq * x)
+        results.append(cos_x)
+        results.append(sin_x)
+    results = torch.cat(results, dim=2)
+    results = results.permute(0, 2, 1)
+    results = results.reshape(B, -1)
+    return results
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ residual_attention_layer + + +

+ + +
+

+ Bases: Module

+ + +

A residual block with an attention layer.

+ + + + + + +
+ Source code in odak/learn/models/components.py +
class residual_attention_layer(torch.nn.Module):
+    """
+    A residual block with an attention layer.
+    """
+    def __init__(
+                 self,
+                 input_channels = 2,
+                 output_channels = 2,
+                 kernel_size = 1,
+                 bias = False,
+                 activation = torch.nn.ReLU()
+                ):
+        """
+        An attention layer class.
+
+
+        Parameters
+        ----------
+        input_channels  : int or optional
+                          Number of input channels.
+        output_channels : int or optional
+                          Number of middle channels.
+        kernel_size     : int or optional
+                          Kernel size.
+        bias            : bool or optional
+                          Set to True to let convolutional layers have bias term.
+        activation      : torch.nn or optional
+                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+        """
+        super().__init__()
+        self.activation = activation
+        self.convolution0 = torch.nn.Sequential(
+                                                torch.nn.Conv2d(
+                                                                input_channels,
+                                                                output_channels,
+                                                                kernel_size = kernel_size,
+                                                                padding = kernel_size // 2,
+                                                                bias = bias
+                                                               ),
+                                                torch.nn.BatchNorm2d(output_channels)
+                                               )
+        self.convolution1 = torch.nn.Sequential(
+                                                torch.nn.Conv2d(
+                                                                input_channels,
+                                                                output_channels,
+                                                                kernel_size = kernel_size,
+                                                                padding = kernel_size // 2,
+                                                                bias = bias
+                                                               ),
+                                                torch.nn.BatchNorm2d(output_channels)
+                                               )
+        self.final_layer = torch.nn.Sequential(
+                                               self.activation,
+                                               torch.nn.Conv2d(
+                                                               output_channels,
+                                                               output_channels,
+                                                               kernel_size = kernel_size,
+                                                               padding = kernel_size // 2,
+                                                               bias = bias
+                                                              )
+                                              )
+
+
+    def forward(self, x0, x1):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x0             : torch.tensor
+                         First input data.
+
+        x1             : torch.tensor
+                         Second input data.
+
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Estimated output.      
+        """
+        y0 = self.convolution0(x0)
+        y1 = self.convolution1(x1)
+        y2 = torch.add(y0, y1)
+        result = self.final_layer(y2) * x0
+        return result
+
+
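A minimal sketch, assuming the class is importable from odak.learn.models.components and using equal input and output channel counts so the final gating broadcasts against x0:

import torch
from odak.learn.models.components import residual_attention_layer

layer = residual_attention_layer(input_channels = 8, output_channels = 8, kernel_size = 3)
x0 = torch.randn(1, 8, 32, 32)
x1 = torch.randn(1, 8, 32, 32)
y = layer(x0, x1)       # attention mask computed from both inputs, applied to x0
print(y.shape)          # torch.Size([1, 8, 32, 32])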

+ __init__(input_channels=2, output_channels=2, kernel_size=1, bias=False, activation=torch.nn.ReLU()) + +

+ + +
+ +

An attention layer class.

+ + +

Parameters:

  • input_channels – Number of input channels.
  • output_channels (int, default: 2) – Number of middle channels.
  • kernel_size – Kernel size.
  • bias – Set to True to let convolutional layers have bias term.
  • activation – Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             input_channels = 2,
+             output_channels = 2,
+             kernel_size = 1,
+             bias = False,
+             activation = torch.nn.ReLU()
+            ):
+    """
+    An attention layer class.
+
+
+    Parameters
+    ----------
+    input_channels  : int or optional
+                      Number of input channels.
+    output_channels : int or optional
+                      Number of middle channels.
+    kernel_size     : int or optional
+                      Kernel size.
+    bias            : bool or optional
+                      Set to True to let convolutional layers have bias term.
+    activation      : torch.nn or optional
+                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+    """
+    super().__init__()
+    self.activation = activation
+    self.convolution0 = torch.nn.Sequential(
+                                            torch.nn.Conv2d(
+                                                            input_channels,
+                                                            output_channels,
+                                                            kernel_size = kernel_size,
+                                                            padding = kernel_size // 2,
+                                                            bias = bias
+                                                           ),
+                                            torch.nn.BatchNorm2d(output_channels)
+                                           )
+    self.convolution1 = torch.nn.Sequential(
+                                            torch.nn.Conv2d(
+                                                            input_channels,
+                                                            output_channels,
+                                                            kernel_size = kernel_size,
+                                                            padding = kernel_size // 2,
+                                                            bias = bias
+                                                           ),
+                                            torch.nn.BatchNorm2d(output_channels)
+                                           )
+    self.final_layer = torch.nn.Sequential(
+                                           self.activation,
+                                           torch.nn.Conv2d(
+                                                           output_channels,
+                                                           output_channels,
+                                                           kernel_size = kernel_size,
+                                                           padding = kernel_size // 2,
+                                                           bias = bias
+                                                          )
+                                          )
+
+
+
+ +
+ +
+ + +

+ forward(x0, x1) + +

+ + +
+ +

Forward model.

+ + +

Parameters:

  • x0 – First input data.
  • x1 – Second input data.

Returns:

  • result (tensor) – Estimated output.
+ Source code in odak/learn/models/components.py +
def forward(self, x0, x1):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x0             : torch.tensor
+                     First input data.
+
+    x1             : torch.tensor
+                     Second input data.
+
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Estimated output.      
+    """
+    y0 = self.convolution0(x0)
+    y1 = self.convolution1(x1)
+    y2 = torch.add(y0, y1)
+    result = self.final_layer(y2) * x0
+    return result
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ residual_layer + + +

+ + +
+

+ Bases: Module

+ + +

A residual layer.

+ + + + + + +
+ Source code in odak/learn/models/components.py +
class residual_layer(torch.nn.Module):
+    """
+    A residual layer.
+    """
+    def __init__(
+                 self,
+                 input_channels = 2,
+                 mid_channels = 16,
+                 kernel_size = 3,
+                 bias = False,
+                 activation = torch.nn.ReLU()
+                ):
+        """
+        A convolutional layer class.
+
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        mid_channels    : int
+                          Number of middle channels.
+        kernel_size     : int
+                          Kernel size.
+        bias            : bool 
+                          Set to True to let convolutional layers have bias term.
+        activation      : torch.nn
+                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+        """
+        super().__init__()
+        self.activation = activation
+        self.convolution = double_convolution(
+                                              input_channels,
+                                              mid_channels = mid_channels,
+                                              output_channels = input_channels,
+                                              kernel_size = kernel_size,
+                                              bias = bias,
+                                              activation = activation
+                                             )
+
+
+    def forward(self, x):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x             : torch.tensor
+                        Input data.
+
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Estimated output.      
+        """
+        x0 = self.convolution(x)
+        return x + x0
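A minimal usage sketch (assuming odak is installed; the class lives in odak/learn/models/components.py as noted above). The residual connection preserves the input shape:

import torch
from odak.learn.models.components import residual_layer

layer = residual_layer(input_channels = 2, mid_channels = 16, kernel_size = 3)
x = torch.randn(1, 2, 64, 64)              # batch of one, two channels
y = layer(x)                               # y = x + double_convolution(x)
assert y.shape == x.shape                  # channels and resolution are preserved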
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(input_channels=2, mid_channels=16, kernel_size=3, bias=False, activation=torch.nn.ReLU()) + +

+ + +
+ +

A convolutional layer class.

+ + +

Parameters:

  • input_channels – Number of input channels.
  • mid_channels – Number of middle channels.
  • kernel_size – Kernel size.
  • bias – Set to True to let convolutional layers have bias term.
  • activation – Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+ +
+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             input_channels = 2,
+             mid_channels = 16,
+             kernel_size = 3,
+             bias = False,
+             activation = torch.nn.ReLU()
+            ):
+    """
+    A convolutional layer class.
+
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    mid_channels    : int
+                      Number of middle channels.
+    kernel_size     : int
+                      Kernel size.
+    bias            : bool 
+                      Set to True to let convolutional layers have bias term.
+    activation      : torch.nn
+                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+    """
+    super().__init__()
+    self.activation = activation
+    self.convolution = double_convolution(
+                                          input_channels,
+                                          mid_channels = mid_channels,
+                                          output_channels = input_channels,
+                                          kernel_size = kernel_size,
+                                          bias = bias,
+                                          activation = activation
+                                         )
+
+
+
+ +
+ +
+ + +

+ forward(x) + +

+ + +
+ +

Forward model.

+ + +

Parameters:

  • x – Input data.

Returns:

  • result ( tensor ) – Estimated output.
+ +
+ Source code in odak/learn/models/components.py +
def forward(self, x):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x             : torch.tensor
+                    Input data.
+
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Estimated output.      
+    """
+    x0 = self.convolution(x)
+    return x + x0
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ spatial_gate + + +

+ + +
+

+ Bases: Module

+ + +

Spatial attention module that applies a convolution layer after channel pooling.
This class is heavily inspired by https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py.

+ + + + + + +
+ Source code in odak/learn/models/components.py +
class spatial_gate(torch.nn.Module):
+    """
+    Spatial attention module that applies a convolution layer after channel pooling.
+    This class is heavily inspired by https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py.
+    """
+    def __init__(self):
+        """
+        Initializes the spatial gate module.
+        """
+        super().__init__()
+        kernel_size = 7
+        self.spatial = convolution_layer(2, 1, kernel_size, bias = False, activation = torch.nn.Identity())
+
+
+    def channel_pool(self, x):
+        """
+        Applies max and average pooling on the channels.
+
+        Parameters
+        ----------
+        x             : torch.tensor
+                        Input tensor.
+
+        Returns
+        -------
+        output        : torch.tensor
+                        Output tensor.
+        """
+        max_pool = torch.max(x, 1)[0].unsqueeze(1)
+        avg_pool = torch.mean(x, 1).unsqueeze(1)
+        output = torch.cat((max_pool, avg_pool), dim=1)
+        return output
+
+
+    def forward(self, x):
+        """
+        Forward pass of the SpatialGate module.
+
+        Applies spatial attention to the input tensor.
+
+        Parameters
+        ----------
+        x            : torch.tensor
+                       Input tensor to the SpatialGate module.
+
+        Returns
+        -------
+        scaled_x     : torch.tensor
+                       Output tensor after applying spatial attention.
+        """
+        x_compress = self.channel_pool(x)
+        x_out = self.spatial(x_compress)
+        scale = torch.sigmoid(x_out)
+        scaled_x = x * scale
+        return scaled_x
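A brief usage sketch under the same import assumption; the gate returns a tensor of the same shape as its input, rescaled by a per-pixel attention map in [0, 1]:

import torch
from odak.learn.models.components import spatial_gate

gate = spatial_gate()
x = torch.randn(1, 8, 32, 32)
scaled_x = gate(x)                         # x weighted by sigmoid(conv(channel_pool(x)))
assert scaled_x.shape == x.shape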
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ __init__() + +

+ + +
+ +

Initializes the spatial gate module.

+ +
+ Source code in odak/learn/models/components.py +
def __init__(self):
+    """
+    Initializes the spatial gate module.
+    """
+    super().__init__()
+    kernel_size = 7
+    self.spatial = convolution_layer(2, 1, kernel_size, bias = False, activation = torch.nn.Identity())
+
+
+
+ +
+ +
+ + +

+ channel_pool(x) + +

+ + +
+ +

Applies max and average pooling on the channels.

+ + +

Parameters:

  • x – Input tensor.

Returns:

  • output ( tensor ) – Output tensor.
+ +
+ Source code in odak/learn/models/components.py +
def channel_pool(self, x):
+    """
+    Applies max and average pooling on the channels.
+
+    Parameters
+    ----------
+    x             : torch.tensor
+                    Input tensor.
+
+    Returns
+    -------
+    output        : torch.tensor
+                    Output tensor.
+    """
+    max_pool = torch.max(x, 1)[0].unsqueeze(1)
+    avg_pool = torch.mean(x, 1).unsqueeze(1)
+    output = torch.cat((max_pool, avg_pool), dim=1)
+    return output
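The pooling step can be reproduced with plain torch calls: for an input of shape (N, C, H, W), the per-pixel channel maximum and channel mean are stacked into a two-channel map. A small sketch with random data:

import torch

x = torch.randn(1, 8, 32, 32)
max_pool = torch.max(x, 1)[0].unsqueeze(1)       # (1, 1, 32, 32)
avg_pool = torch.mean(x, 1).unsqueeze(1)         # (1, 1, 32, 32)
output = torch.cat((max_pool, avg_pool), dim = 1)
assert output.shape == (1, 2, 32, 32)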
+
+
+
+ +
+ +
+ + +

+ forward(x) + +

+ + +
+ +

Forward pass of the SpatialGate module.

+

Applies spatial attention to the input tensor.

+ + +

Parameters:

  • x – Input tensor to the SpatialGate module.

Returns:

  • scaled_x ( tensor ) – Output tensor after applying spatial attention.
+ +
+ Source code in odak/learn/models/components.py +
def forward(self, x):
+    """
+    Forward pass of the SpatialGate module.
+
+    Applies spatial attention to the input tensor.
+
+    Parameters
+    ----------
+    x            : torch.tensor
+                   Input tensor to the SpatialGate module.
+
+    Returns
+    -------
+    scaled_x     : torch.tensor
+                   Output tensor after applying spatial attention.
+    """
+    x_compress = self.channel_pool(x)
+    x_out = self.spatial(x_compress)
+    scale = torch.sigmoid(x_out)
+    scaled_x = x * scale
+    return scaled_x
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ spatially_adaptive_convolution + + +

+ + +
+

+ Bases: Module

+ + +

A spatially adaptive convolution layer.

+ + +
+ References +

C. Zheng et al. "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions."
C. Xu et al. "Squeezesegv3: Spatially-adaptive Convolution for Efficient Point-Cloud Segmentation."
C. Zheng et al. "Windowing Decomposition Convolutional Neural Network for Image Enhancement."

+
+ + + + + +
+ Source code in odak/learn/models/components.py +
class spatially_adaptive_convolution(torch.nn.Module):
+    """
+    A spatially adaptive convolution layer.
+
+    References
+    ----------
+
+    C. Zheng et al. "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions."
+    C. Xu et al. "Squeezesegv3: Spatially-adaptive Convolution for Efficient Point-Cloud Segmentation."
+    C. Zheng et al. "Windowing Decomposition Convolutional Neural Network for Image Enhancement."
+    """
+    def __init__(
+                 self,
+                 input_channels = 2,
+                 output_channels = 2,
+                 kernel_size = 3,
+                 stride = 1,
+                 padding = 1,
+                 bias = False,
+                 activation = torch.nn.LeakyReLU(0.2, inplace = True)
+                ):
+        """
+        Initializes a spatially adaptive convolution layer.
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        output_channels : int
+                          Number of output channels.
+        kernel_size     : int
+                          Size of the convolution kernel.
+        stride          : int
+                          Stride of the convolution.
+        padding         : int
+                          Padding added to both sides of the input.
+        bias            : bool
+                          If True, includes a bias term in the convolution.
+        activation      : torch.nn.Module
+                          Activation function to apply. If None, no activation is applied.
+        """
+        super(spatially_adaptive_convolution, self).__init__()
+        self.kernel_size = kernel_size
+        self.input_channels = input_channels
+        self.output_channels = output_channels
+        self.stride = stride
+        self.padding = padding
+        self.standard_convolution = torch.nn.Conv2d(
+                                                    in_channels = input_channels,
+                                                    out_channels = self.output_channels,
+                                                    kernel_size = kernel_size,
+                                                    stride = stride,
+                                                    padding = padding,
+                                                    bias = bias
+                                                   )
+        self.weight = torch.nn.Parameter(data = self.standard_convolution.weight, requires_grad = True)
+        self.activation = activation
+
+
+    def forward(self, x, sv_kernel_feature):
+        """
+        Forward pass for the spatially adaptive convolution layer.
+
+        Parameters
+        ----------
+        x                  : torch.tensor
+                            Input data tensor.
+                            Dimension: (1, C, H, W)
+        sv_kernel_feature   : torch.tensor
+                            Spatially varying kernel features.
+                            Dimension: (1, C_i * kernel_size * kernel_size, H, W)
+
+        Returns
+        -------
+        sa_output          : torch.tensor
+                            Estimated output tensor.
+                            Dimension: (1, output_channels, H_out, W_out)
+        """
+        # Pad input and sv_kernel_feature if necessary
+        if sv_kernel_feature.size(-1) * self.stride != x.size(-1) or sv_kernel_feature.size(
+                -2) * self.stride != x.size(-2):
+            diffY = sv_kernel_feature.size(-2) % self.stride
+            diffX = sv_kernel_feature.size(-1) % self.stride
+            sv_kernel_feature = torch.nn.functional.pad(sv_kernel_feature, (diffX // 2, diffX - diffX // 2,
+                                                                            diffY // 2, diffY - diffY // 2))
+            diffY = x.size(-2) % self.stride
+            diffX = x.size(-1) % self.stride
+            x = torch.nn.functional.pad(x, (diffX // 2, diffX - diffX // 2,
+                                            diffY // 2, diffY - diffY // 2))
+
+        # Unfold the input tensor for matrix multiplication
+        input_feature = torch.nn.functional.unfold(
+                                                   x,
+                                                   kernel_size = (self.kernel_size, self.kernel_size),
+                                                   stride = self.stride,
+                                                   padding = self.padding
+                                                  )
+
+        # Resize sv_kernel_feature to match the input feature
+        sv_kernel = sv_kernel_feature.reshape(
+                                              1,
+                                              self.input_channels * self.kernel_size * self.kernel_size,
+                                              (x.size(-2) // self.stride) * (x.size(-1) // self.stride)
+                                             )
+
+        # Resize weight to match the input channels and kernel size
+        si_kernel = self.weight.reshape(
+                                        self.output_channels,
+                                        self.input_channels * self.kernel_size * self.kernel_size
+                                       )
+
+        # Apply spatially varying kernels
+        sv_feature = input_feature * sv_kernel
+
+        # Perform matrix multiplication
+        sa_output = torch.matmul(si_kernel, sv_feature).reshape(
+                                                                1, self.output_channels,
+                                                                (x.size(-2) // self.stride),
+                                                                (x.size(-1) // self.stride)
+                                                               )
+        return sa_output
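A hedged usage sketch (class and defaults as listed above): the spatially varying kernel tensor must carry input_channels * kernel_size * kernel_size values per pixel, and with stride 1 it shares the spatial size of the input:

import torch
from odak.learn.models.components import spatially_adaptive_convolution

layer = spatially_adaptive_convolution(input_channels = 2, output_channels = 2, kernel_size = 3)
x = torch.randn(1, 2, 32, 32)
sv_kernel_feature = torch.randn(1, 2 * 3 * 3, 32, 32)   # per-pixel 3 x 3 kernels for 2 channels
sa_output = layer(x, sv_kernel_feature)
print(sa_output.shape)                     # expected: torch.Size([1, 2, 32, 32])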
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(input_channels=2, output_channels=2, kernel_size=3, stride=1, padding=1, bias=False, activation=torch.nn.LeakyReLU(0.2, inplace=True)) + +

+ + +
+ +

Initializes a spatially adaptive convolution layer.

+ + +

Parameters:

  • input_channels – Number of input channels.
  • output_channels ( int, default: 2 ) – Number of output channels.
  • kernel_size – Size of the convolution kernel.
  • stride – Stride of the convolution.
  • padding – Padding added to both sides of the input.
  • bias – If True, includes a bias term in the convolution.
  • activation – Activation function to apply. If None, no activation is applied.
+ +
+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             input_channels = 2,
+             output_channels = 2,
+             kernel_size = 3,
+             stride = 1,
+             padding = 1,
+             bias = False,
+             activation = torch.nn.LeakyReLU(0.2, inplace = True)
+            ):
+    """
+    Initializes a spatially adaptive convolution layer.
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    output_channels : int
+                      Number of output channels.
+    kernel_size     : int
+                      Size of the convolution kernel.
+    stride          : int
+                      Stride of the convolution.
+    padding         : int
+                      Padding added to both sides of the input.
+    bias            : bool
+                      If True, includes a bias term in the convolution.
+    activation      : torch.nn.Module
+                      Activation function to apply. If None, no activation is applied.
+    """
+    super(spatially_adaptive_convolution, self).__init__()
+    self.kernel_size = kernel_size
+    self.input_channels = input_channels
+    self.output_channels = output_channels
+    self.stride = stride
+    self.padding = padding
+    self.standard_convolution = torch.nn.Conv2d(
+                                                in_channels = input_channels,
+                                                out_channels = self.output_channels,
+                                                kernel_size = kernel_size,
+                                                stride = stride,
+                                                padding = padding,
+                                                bias = bias
+                                               )
+    self.weight = torch.nn.Parameter(data = self.standard_convolution.weight, requires_grad = True)
+    self.activation = activation
+
+
+
+ +
+ +
+ + +

+ forward(x, sv_kernel_feature) + +

+ + +
+ +

Forward pass for the spatially adaptive convolution layer.

+ + +

Parameters:

  • x – Input data tensor. Dimension: (1, C, H, W)
  • sv_kernel_feature – Spatially varying kernel features. Dimension: (1, C_i * kernel_size * kernel_size, H, W)

Returns:

  • sa_output ( tensor ) – Estimated output tensor. Dimension: (1, output_channels, H_out, W_out)
+ +
+ Source code in odak/learn/models/components.py +
def forward(self, x, sv_kernel_feature):
+    """
+    Forward pass for the spatially adaptive convolution layer.
+
+    Parameters
+    ----------
+    x                  : torch.tensor
+                        Input data tensor.
+                        Dimension: (1, C, H, W)
+    sv_kernel_feature   : torch.tensor
+                        Spatially varying kernel features.
+                        Dimension: (1, C_i * kernel_size * kernel_size, H, W)
+
+    Returns
+    -------
+    sa_output          : torch.tensor
+                        Estimated output tensor.
+                        Dimension: (1, output_channels, H_out, W_out)
+    """
+    # Pad input and sv_kernel_feature if necessary
+    if sv_kernel_feature.size(-1) * self.stride != x.size(-1) or sv_kernel_feature.size(
+            -2) * self.stride != x.size(-2):
+        diffY = sv_kernel_feature.size(-2) % self.stride
+        diffX = sv_kernel_feature.size(-1) % self.stride
+        sv_kernel_feature = torch.nn.functional.pad(sv_kernel_feature, (diffX // 2, diffX - diffX // 2,
+                                                                        diffY // 2, diffY - diffY // 2))
+        diffY = x.size(-2) % self.stride
+        diffX = x.size(-1) % self.stride
+        x = torch.nn.functional.pad(x, (diffX // 2, diffX - diffX // 2,
+                                        diffY // 2, diffY - diffY // 2))
+
+    # Unfold the input tensor for matrix multiplication
+    input_feature = torch.nn.functional.unfold(
+                                               x,
+                                               kernel_size = (self.kernel_size, self.kernel_size),
+                                               stride = self.stride,
+                                               padding = self.padding
+                                              )
+
+    # Resize sv_kernel_feature to match the input feature
+    sv_kernel = sv_kernel_feature.reshape(
+                                          1,
+                                          self.input_channels * self.kernel_size * self.kernel_size,
+                                          (x.size(-2) // self.stride) * (x.size(-1) // self.stride)
+                                         )
+
+    # Resize weight to match the input channels and kernel size
+    si_kernel = self.weight.reshape(
+                                    self.output_channels,
+                                    self.input_channels * self.kernel_size * self.kernel_size
+                                   )
+
+    # Apply spatially varying kernels
+    sv_feature = input_feature * sv_kernel
+
+    # Perform matrix multiplication
+    sa_output = torch.matmul(si_kernel, sv_feature).reshape(
+                                                            1, self.output_channels,
+                                                            (x.size(-2) // self.stride),
+                                                            (x.size(-1) // self.stride)
+                                                           )
+    return sa_output
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ spatially_adaptive_module + + +

+ + +
+

+ Bases: Module

+ + +

A spatially adaptive module that combines learned spatially adaptive convolutions.

+ + +
+ References +

Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Akşit, "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions," SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24), December, 2024.

+
+ + + + + +
+ Source code in odak/learn/models/components.py +
class spatially_adaptive_module(torch.nn.Module):
+    """
+    A spatially adaptive module that combines learned spatially adaptive convolutions.
+
+    References
+    ----------
+
+    Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Akşit, "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions," SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24), December, 2024.
+    """
+    def __init__(
+                 self,
+                 input_channels = 2,
+                 output_channels = 2,
+                 kernel_size = 3,
+                 stride = 1,
+                 padding = 1,
+                 bias = False,
+                 activation = torch.nn.LeakyReLU(0.2, inplace = True)
+                ):
+        """
+        Initializes a spatially adaptive module.
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        output_channels : int
+                          Number of output channels.
+        kernel_size     : int
+                          Size of the convolution kernel.
+        stride          : int
+                          Stride of the convolution.
+        padding         : int
+                          Padding added to both sides of the input.
+        bias            : bool
+                          If True, includes a bias term in the convolution.
+        activation      : torch.nn
+                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+        """
+        super(spatially_adaptive_module, self).__init__()
+        self.kernel_size = kernel_size
+        self.input_channels = input_channels
+        self.output_channels = output_channels
+        self.stride = stride
+        self.padding = padding
+        self.weight_output_channels = self.output_channels - 1
+        self.standard_convolution = torch.nn.Conv2d(
+                                                    in_channels = input_channels,
+                                                    out_channels = self.weight_output_channels,
+                                                    kernel_size = kernel_size,
+                                                    stride = stride,
+                                                    padding = padding,
+                                                    bias = bias
+                                                   )
+        self.weight = torch.nn.Parameter(data = self.standard_convolution.weight, requires_grad = True)
+        self.activation = activation
+
+
+    def forward(self, x, sv_kernel_feature):
+        """
+        Forward pass for the spatially adaptive module.
+
+        Parameters
+        ----------
+        x                  : torch.tensor
+                            Input data tensor.
+                            Dimension: (1, C, H, W)
+        sv_kernel_feature   : torch.tensor
+                            Spatially varying kernel features.
+                            Dimension: (1, C_i * kernel_size * kernel_size, H, W)
+
+        Returns
+        -------
+        output             : torch.tensor
+                            Combined output tensor from standard and spatially adaptive convolutions.
+                            Dimension: (1, output_channels, H_out, W_out)
+        """
+        # Pad input and sv_kernel_feature if necessary
+        if sv_kernel_feature.size(-1) * self.stride != x.size(-1) or sv_kernel_feature.size(
+                -2) * self.stride != x.size(-2):
+            diffY = sv_kernel_feature.size(-2) % self.stride
+            diffX = sv_kernel_feature.size(-1) % self.stride
+            sv_kernel_feature = torch.nn.functional.pad(sv_kernel_feature, (diffX // 2, diffX - diffX // 2,
+                                                                            diffY // 2, diffY - diffY // 2))
+            diffY = x.size(-2) % self.stride
+            diffX = x.size(-1) % self.stride
+            x = torch.nn.functional.pad(x, (diffX // 2, diffX - diffX // 2,
+                                            diffY // 2, diffY - diffY // 2))
+
+        # Unfold the input tensor for matrix multiplication
+        input_feature = torch.nn.functional.unfold(
+                                                   x,
+                                                   kernel_size = (self.kernel_size, self.kernel_size),
+                                                   stride = self.stride,
+                                                   padding = self.padding
+                                                  )
+
+        # Resize sv_kernel_feature to match the input feature
+        sv_kernel = sv_kernel_feature.reshape(
+                                              1,
+                                              self.input_channels * self.kernel_size * self.kernel_size,
+                                              (x.size(-2) // self.stride) * (x.size(-1) // self.stride)
+                                             )
+
+        # Apply sv_kernel to the input_feature
+        sv_feature = input_feature * sv_kernel
+
+        # Original spatially varying convolution output
+        sv_output = torch.sum(sv_feature, dim = 1).reshape(
+                                                           1,
+                                                            1,
+                                                            (x.size(-2) // self.stride),
+                                                            (x.size(-1) // self.stride)
+                                                           )
+
+        # Reshape weight for spatially adaptive convolution
+        si_kernel = self.weight.reshape(
+                                        self.weight_output_channels,
+                                        self.input_channels * self.kernel_size * self.kernel_size
+                                       )
+
+        # Apply si_kernel on sv convolution output
+        sa_output = torch.matmul(si_kernel, sv_feature).reshape(
+                                                                1, self.weight_output_channels,
+                                                                (x.size(-2) // self.stride),
+                                                                (x.size(-1) // self.stride)
+                                                               )
+
+        # Combine the outputs and apply activation function
+        output = self.activation(torch.cat((sv_output, sa_output), dim = 1))
+        return output
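A usage sketch under the same assumptions: the module concatenates a one-channel spatially varying response with output_channels - 1 spatially adaptive channels, so the combined result carries output_channels channels:

import torch
from odak.learn.models.components import spatially_adaptive_module

module = spatially_adaptive_module(input_channels = 2, output_channels = 2, kernel_size = 3)
x = torch.randn(1, 2, 32, 32)
sv_kernel_feature = torch.randn(1, 2 * 3 * 3, 32, 32)
output = module(x, sv_kernel_feature)
print(output.shape)                        # expected: torch.Size([1, 2, 32, 32])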
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(input_channels=2, output_channels=2, kernel_size=3, stride=1, padding=1, bias=False, activation=torch.nn.LeakyReLU(0.2, inplace=True)) + +

+ + +
+ +

Initializes a spatially adaptive module.

+ + +

Parameters:

  • input_channels – Number of input channels.
  • output_channels ( int, default: 2 ) – Number of output channels.
  • kernel_size – Size of the convolution kernel.
  • stride – Stride of the convolution.
  • padding – Padding added to both sides of the input.
  • bias – If True, includes a bias term in the convolution.
  • activation – Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+ +
+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             input_channels = 2,
+             output_channels = 2,
+             kernel_size = 3,
+             stride = 1,
+             padding = 1,
+             bias = False,
+             activation = torch.nn.LeakyReLU(0.2, inplace = True)
+            ):
+    """
+    Initializes a spatially adaptive module.
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    output_channels : int
+                      Number of output channels.
+    kernel_size     : int
+                      Size of the convolution kernel.
+    stride          : int
+                      Stride of the convolution.
+    padding         : int
+                      Padding added to both sides of the input.
+    bias            : bool
+                      If True, includes a bias term in the convolution.
+    activation      : torch.nn
+                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+    """
+    super(spatially_adaptive_module, self).__init__()
+    self.kernel_size = kernel_size
+    self.input_channels = input_channels
+    self.output_channels = output_channels
+    self.stride = stride
+    self.padding = padding
+    self.weight_output_channels = self.output_channels - 1
+    self.standard_convolution = torch.nn.Conv2d(
+                                                in_channels = input_channels,
+                                                out_channels = self.weight_output_channels,
+                                                kernel_size = kernel_size,
+                                                stride = stride,
+                                                padding = padding,
+                                                bias = bias
+                                               )
+    self.weight = torch.nn.Parameter(data = self.standard_convolution.weight, requires_grad = True)
+    self.activation = activation
+
+
+
+ +
+ +
+ + +

+ forward(x, sv_kernel_feature) + +

+ + +
+ +

Forward pass for the spatially adaptive module.

+ + +

Parameters:

  • x – Input data tensor. Dimension: (1, C, H, W)
  • sv_kernel_feature – Spatially varying kernel features. Dimension: (1, C_i * kernel_size * kernel_size, H, W)

Returns:

  • output ( tensor ) – Combined output tensor from standard and spatially adaptive convolutions. Dimension: (1, output_channels, H_out, W_out)
+ +
+ Source code in odak/learn/models/components.py +
def forward(self, x, sv_kernel_feature):
+    """
+    Forward pass for the spatially adaptive module.
+
+    Parameters
+    ----------
+    x                  : torch.tensor
+                        Input data tensor.
+                        Dimension: (1, C, H, W)
+    sv_kernel_feature   : torch.tensor
+                        Spatially varying kernel features.
+                        Dimension: (1, C_i * kernel_size * kernel_size, H, W)
+
+    Returns
+    -------
+    output             : torch.tensor
+                        Combined output tensor from standard and spatially adaptive convolutions.
+                        Dimension: (1, output_channels, H_out, W_out)
+    """
+    # Pad input and sv_kernel_feature if necessary
+    if sv_kernel_feature.size(-1) * self.stride != x.size(-1) or sv_kernel_feature.size(
+            -2) * self.stride != x.size(-2):
+        diffY = sv_kernel_feature.size(-2) % self.stride
+        diffX = sv_kernel_feature.size(-1) % self.stride
+        sv_kernel_feature = torch.nn.functional.pad(sv_kernel_feature, (diffX // 2, diffX - diffX // 2,
+                                                                        diffY // 2, diffY - diffY // 2))
+        diffY = x.size(-2) % self.stride
+        diffX = x.size(-1) % self.stride
+        x = torch.nn.functional.pad(x, (diffX // 2, diffX - diffX // 2,
+                                        diffY // 2, diffY - diffY // 2))
+
+    # Unfold the input tensor for matrix multiplication
+    input_feature = torch.nn.functional.unfold(
+                                               x,
+                                               kernel_size = (self.kernel_size, self.kernel_size),
+                                               stride = self.stride,
+                                               padding = self.padding
+                                              )
+
+    # Resize sv_kernel_feature to match the input feature
+    sv_kernel = sv_kernel_feature.reshape(
+                                          1,
+                                          self.input_channels * self.kernel_size * self.kernel_size,
+                                          (x.size(-2) // self.stride) * (x.size(-1) // self.stride)
+                                         )
+
+    # Apply sv_kernel to the input_feature
+    sv_feature = input_feature * sv_kernel
+
+    # Original spatially varying convolution output
+    sv_output = torch.sum(sv_feature, dim = 1).reshape(
+                                                       1,
+                                                        1,
+                                                        (x.size(-2) // self.stride),
+                                                        (x.size(-1) // self.stride)
+                                                       )
+
+    # Reshape weight for spatially adaptive convolution
+    si_kernel = self.weight.reshape(
+                                    self.weight_output_channels,
+                                    self.input_channels * self.kernel_size * self.kernel_size
+                                   )
+
+    # Apply si_kernel on sv convolution output
+    sa_output = torch.matmul(si_kernel, sv_feature).reshape(
+                                                            1, self.weight_output_channels,
+                                                            (x.size(-2) // self.stride),
+                                                            (x.size(-1) // self.stride)
+                                                           )
+
+    # Combine the outputs and apply activation function
+    output = self.activation(torch.cat((sv_output, sa_output), dim = 1))
+    return output
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ spatially_adaptive_unet + + +

+ + +
+

+ Bases: Module

+ + +

Spatially varying U-Net model based on spatially adaptive convolution.

+ + +
+ References +

Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Akşit, "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions," SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24), December, 2024.

+
+ + + + + +
+ Source code in odak/learn/models/models.py +
class spatially_adaptive_unet(torch.nn.Module):
+    """
+    Spatially varying U-Net model based on spatially adaptive convolution.
+
+    References
+    ----------
+
+    Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Akşit, "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions," SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24), December, 2024.
+    """
+    def __init__(
+                 self,
+                 depth=3,
+                 dimensions=8,
+                 input_channels=6,
+                 out_channels=6,
+                 kernel_size=3,
+                 bias=True,
+                 normalization=False,
+                 activation=torch.nn.LeakyReLU(0.2, inplace=True)
+                ):
+        """
+        U-Net model.
+
+        Parameters
+        ----------
+        depth          : int
+                         Number of upsampling and downsampling layers.
+        dimensions     : int
+                         Number of dimensions.
+        input_channels : int
+                         Number of input channels.
+        out_channels   : int
+                         Number of output channels.
+        bias           : bool
+                         Set to True to let convolutional layers learn a bias term.
+        normalization  : bool
+                         If True, adds a Batch Normalization layer after the convolutional layer.
+        activation     : torch.nn
+                         Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
+        """
+        super().__init__()
+        self.depth = depth
+        self.out_channels = out_channels
+        self.inc = convolution_layer(
+                                     input_channels=input_channels,
+                                     output_channels=dimensions,
+                                     kernel_size=kernel_size,
+                                     bias=bias,
+                                     normalization=normalization,
+                                     activation=activation
+                                    )
+
+        self.encoder = torch.nn.ModuleList()
+        for i in range(self.depth + 1):  # Downsampling layers
+            down_in_channels = dimensions * (2 ** i)
+            down_out_channels = 2 * down_in_channels
+            pooling_layer = torch.nn.AvgPool2d(2)
+            double_convolution_layer = double_convolution(
+                                                          input_channels=down_in_channels,
+                                                          mid_channels=down_in_channels,
+                                                          output_channels=down_in_channels,
+                                                          kernel_size=kernel_size,
+                                                          bias=bias,
+                                                          normalization=normalization,
+                                                          activation=activation
+                                                         )
+            sam = spatially_adaptive_module(
+                                            input_channels=down_in_channels,
+                                            output_channels=down_out_channels,
+                                            kernel_size=kernel_size,
+                                            bias=bias,
+                                            activation=activation
+                                           )
+            self.encoder.append(torch.nn.ModuleList([pooling_layer, double_convolution_layer, sam]))
+        self.global_feature_module = torch.nn.ModuleList()
+        double_convolution_layer = double_convolution(
+                                                      input_channels=dimensions * (2 ** (depth + 1)),
+                                                      mid_channels=dimensions * (2 ** (depth + 1)),
+                                                      output_channels=dimensions * (2 ** (depth + 1)),
+                                                      kernel_size=kernel_size,
+                                                      bias=bias,
+                                                      normalization=normalization,
+                                                      activation=activation
+                                                     )
+        global_feature_layer = global_feature_module(
+                                                     input_channels=dimensions * (2 ** (depth + 1)),
+                                                     mid_channels=dimensions * (2 ** (depth + 1)),
+                                                     output_channels=dimensions * (2 ** (depth + 1)),
+                                                     kernel_size=kernel_size,
+                                                     bias=bias,
+                                                     activation=torch.nn.LeakyReLU(0.2, inplace=True)
+                                                    )
+        self.global_feature_module.append(torch.nn.ModuleList([double_convolution_layer, global_feature_layer]))
+        self.decoder = torch.nn.ModuleList()
+        for i in range(depth, -1, -1):
+            up_in_channels = dimensions * (2 ** (i + 1))
+            up_mid_channels = up_in_channels // 2
+            if i == 0:
+                up_out_channels = self.out_channels
+                upsample_layer = upsample_convtranspose2d_layer(
+                                                                input_channels=up_in_channels,
+                                                                output_channels=up_mid_channels,
+                                                                kernel_size=2,
+                                                                stride=2,
+                                                                bias=bias,
+                                                               )
+                conv_layer = torch.nn.Sequential(
+                    convolution_layer(
+                                      input_channels=up_mid_channels,
+                                      output_channels=up_mid_channels,
+                                      kernel_size=kernel_size,
+                                      bias=bias,
+                                      normalization=normalization,
+                                      activation=activation,
+                                     ),
+                    convolution_layer(
+                                      input_channels=up_mid_channels,
+                                      output_channels=up_out_channels,
+                                      kernel_size=1,
+                                      bias=bias,
+                                      normalization=normalization,
+                                      activation=None,
+                                     )
+                )
+                self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))
+            else:
+                up_out_channels = up_in_channels // 2
+                upsample_layer = upsample_convtranspose2d_layer(
+                                                                input_channels=up_in_channels,
+                                                                output_channels=up_mid_channels,
+                                                                kernel_size=2,
+                                                                stride=2,
+                                                                bias=bias,
+                                                               )
+                conv_layer = double_convolution(
+                                                input_channels=up_mid_channels,
+                                                mid_channels=up_mid_channels,
+                                                output_channels=up_out_channels,
+                                                kernel_size=kernel_size,
+                                                bias=bias,
+                                                normalization=normalization,
+                                                activation=activation,
+                                               )
+                self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))
+
+
+    def forward(self, sv_kernel, field):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        sv_kernel : list of torch.tensor
+                    Learned spatially varying kernels.
+                    Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i),
+                    where C_i, H_i, and W_i represent the channel, height, and width
+                    of each feature at a certain scale.
+
+        field     : torch.tensor
+                    Input field data.
+                    Dimension: (1, 6, H, W)
+
+        Returns
+        -------
+        target_field : torch.tensor
+                       Estimated output.
+                       Dimension: (1, 6, H, W)
+        """
+        x = self.inc(field)
+        downsampling_outputs = [x]
+        for i, down_layer in enumerate(self.encoder):
+            x_down = down_layer[0](downsampling_outputs[-1])
+            downsampling_outputs.append(x_down)
+            sam_output = down_layer[2](x_down + down_layer[1](x_down), sv_kernel[self.depth - i])
+            downsampling_outputs.append(sam_output)
+        global_feature = self.global_feature_module[0][0](downsampling_outputs[-1])
+        global_feature = self.global_feature_module[0][1](downsampling_outputs[-1], global_feature)
+        downsampling_outputs.append(global_feature)
+        x_up = downsampling_outputs[-1]
+        for i, up_layer in enumerate(self.decoder):
+            x_up = up_layer[0](x_up, downsampling_outputs[2 * (self.depth - i)])
+            x_up = up_layer[1](x_up)
+        result = x_up
+        return result
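An end-to-end shape sketch (module path odak/learn/models/models.py as noted above). It assumes the field's height and width are divisible by 2 ** (depth + 1) and that sv_kernel is ordered from the coarsest scale (index 0) to the finest scale, matching the sv_kernel[self.depth - i] indexing in the encoder loop:

import torch
from odak.learn.models.models import spatially_adaptive_unet

model = spatially_adaptive_unet(depth = 3, dimensions = 8, input_channels = 6, out_channels = 6)
field = torch.randn(1, 6, 64, 64)
sv_kernel = [
             torch.randn(1, 64 * 3 * 3,  4,  4),   # dimensions * 2 ** 3 channels at H / 16
             torch.randn(1, 32 * 3 * 3,  8,  8),   # dimensions * 2 ** 2 channels at H / 8
             torch.randn(1, 16 * 3 * 3, 16, 16),   # dimensions * 2 ** 1 channels at H / 4
             torch.randn(1,  8 * 3 * 3, 32, 32)    # dimensions * 2 ** 0 channels at H / 2
            ]
target_field = model(sv_kernel, field)
print(target_field.shape)                  # expected: torch.Size([1, 6, 64, 64])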
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(depth=3, dimensions=8, input_channels=6, out_channels=6, kernel_size=3, bias=True, normalization=False, activation=torch.nn.LeakyReLU(0.2, inplace=True)) + +

+ + +
+ +

U-Net model.

+ + +

Parameters:

  • depth – Number of upsampling and downsampling layers.
  • dimensions – Number of dimensions.
  • input_channels ( int, default: 6 ) – Number of input channels.
  • out_channels – Number of output channels.
  • bias – Set to True to let convolutional layers learn a bias term.
  • normalization – If True, adds a Batch Normalization layer after the convolutional layer.
  • activation – Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
+ +
+ Source code in odak/learn/models/models.py +
def __init__(
+             self,
+             depth=3,
+             dimensions=8,
+             input_channels=6,
+             out_channels=6,
+             kernel_size=3,
+             bias=True,
+             normalization=False,
+             activation=torch.nn.LeakyReLU(0.2, inplace=True)
+            ):
+    """
+    U-Net model.
+
+    Parameters
+    ----------
+    depth          : int
+                     Number of upsampling and downsampling layers.
+    dimensions     : int
+                     Number of dimensions.
+    input_channels : int
+                     Number of input channels.
+    out_channels   : int
+                     Number of output channels.
+    bias           : bool
+                     Set to True to let convolutional layers learn a bias term.
+    normalization  : bool
+                     If True, adds a Batch Normalization layer after the convolutional layer.
+    activation     : torch.nn
+                     Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
+    """
+    super().__init__()
+    self.depth = depth
+    self.out_channels = out_channels
+    self.inc = convolution_layer(
+                                 input_channels=input_channels,
+                                 output_channels=dimensions,
+                                 kernel_size=kernel_size,
+                                 bias=bias,
+                                 normalization=normalization,
+                                 activation=activation
+                                )
+
+    self.encoder = torch.nn.ModuleList()
+    for i in range(self.depth + 1):  # Downsampling layers
+        down_in_channels = dimensions * (2 ** i)
+        down_out_channels = 2 * down_in_channels
+        pooling_layer = torch.nn.AvgPool2d(2)
+        double_convolution_layer = double_convolution(
+                                                      input_channels=down_in_channels,
+                                                      mid_channels=down_in_channels,
+                                                      output_channels=down_in_channels,
+                                                      kernel_size=kernel_size,
+                                                      bias=bias,
+                                                      normalization=normalization,
+                                                      activation=activation
+                                                     )
+        sam = spatially_adaptive_module(
+                                        input_channels=down_in_channels,
+                                        output_channels=down_out_channels,
+                                        kernel_size=kernel_size,
+                                        bias=bias,
+                                        activation=activation
+                                       )
+        self.encoder.append(torch.nn.ModuleList([pooling_layer, double_convolution_layer, sam]))
+    self.global_feature_module = torch.nn.ModuleList()
+    double_convolution_layer = double_convolution(
+                                                  input_channels=dimensions * (2 ** (depth + 1)),
+                                                  mid_channels=dimensions * (2 ** (depth + 1)),
+                                                  output_channels=dimensions * (2 ** (depth + 1)),
+                                                  kernel_size=kernel_size,
+                                                  bias=bias,
+                                                  normalization=normalization,
+                                                  activation=activation
+                                                 )
+    global_feature_layer = global_feature_module(
+                                                 input_channels=dimensions * (2 ** (depth + 1)),
+                                                 mid_channels=dimensions * (2 ** (depth + 1)),
+                                                 output_channels=dimensions * (2 ** (depth + 1)),
+                                                 kernel_size=kernel_size,
+                                                 bias=bias,
+                                                 activation=torch.nn.LeakyReLU(0.2, inplace=True)
+                                                )
+    self.global_feature_module.append(torch.nn.ModuleList([double_convolution_layer, global_feature_layer]))
+    self.decoder = torch.nn.ModuleList()
+    for i in range(depth, -1, -1):
+        up_in_channels = dimensions * (2 ** (i + 1))
+        up_mid_channels = up_in_channels // 2
+        if i == 0:
+            up_out_channels = self.out_channels
+            upsample_layer = upsample_convtranspose2d_layer(
+                                                            input_channels=up_in_channels,
+                                                            output_channels=up_mid_channels,
+                                                            kernel_size=2,
+                                                            stride=2,
+                                                            bias=bias,
+                                                           )
+            conv_layer = torch.nn.Sequential(
+                convolution_layer(
+                                  input_channels=up_mid_channels,
+                                  output_channels=up_mid_channels,
+                                  kernel_size=kernel_size,
+                                  bias=bias,
+                                  normalization=normalization,
+                                  activation=activation,
+                                 ),
+                convolution_layer(
+                                  input_channels=up_mid_channels,
+                                  output_channels=up_out_channels,
+                                  kernel_size=1,
+                                  bias=bias,
+                                  normalization=normalization,
+                                  activation=None,
+                                 )
+            )
+            self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))
+        else:
+            up_out_channels = up_in_channels // 2
+            upsample_layer = upsample_convtranspose2d_layer(
+                                                            input_channels=up_in_channels,
+                                                            output_channels=up_mid_channels,
+                                                            kernel_size=2,
+                                                            stride=2,
+                                                            bias=bias,
+                                                           )
+            conv_layer = double_convolution(
+                                            input_channels=up_mid_channels,
+                                            mid_channels=up_mid_channels,
+                                            output_channels=up_out_channels,
+                                            kernel_size=kernel_size,
+                                            bias=bias,
+                                            normalization=normalization,
+                                            activation=activation,
+                                           )
+            self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))
+
+
+
+ +
+ +
+ + +

+ forward(sv_kernel, field) + +

+ + +
+ +

Forward model.

+ + +

Parameters:

  • sv_kernel ( list of torch.tensor ) – Learned spatially varying kernels. Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i), where C_i, H_i, and W_i represent the channel, height, and width of each feature at a certain scale.
  • field – Input field data. Dimension: (1, 6, H, W)

Returns:

  • target_field ( tensor ) – Estimated output. Dimension: (1, 6, H, W)
+ +
+ Source code in odak/learn/models/models.py +
def forward(self, sv_kernel, field):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    sv_kernel : list of torch.tensor
+                Learned spatially varying kernels.
+                Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i),
+                where C_i, H_i, and W_i represent the channel, height, and width
+                of each feature at a certain scale.
+
+    field     : torch.tensor
+                Input field data.
+                Dimension: (1, 6, H, W)
+
+    Returns
+    -------
+    target_field : torch.tensor
+                   Estimated output.
+                   Dimension: (1, 6, H, W)
+    """
+    x = self.inc(field)
+    downsampling_outputs = [x]
+    for i, down_layer in enumerate(self.encoder):
+        x_down = down_layer[0](downsampling_outputs[-1])
+        downsampling_outputs.append(x_down)
+        sam_output = down_layer[2](x_down + down_layer[1](x_down), sv_kernel[self.depth - i])
+        downsampling_outputs.append(sam_output)
+    global_feature = self.global_feature_module[0][0](downsampling_outputs[-1])
+    global_feature = self.global_feature_module[0][1](downsampling_outputs[-1], global_feature)
+    downsampling_outputs.append(global_feature)
+    x_up = downsampling_outputs[-1]
+    for i, up_layer in enumerate(self.decoder):
+        x_up = up_layer[0](x_up, downsampling_outputs[2 * (self.depth - i)])
+        x_up = up_layer[1](x_up)
+    result = x_up
+    return result
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ spatially_varying_kernel_generation_model + + +

+ + +
+

+ Bases: Module

+ + +

A spatially varying kernel generation model revised from RSGUnet: +https://github.com/MTLab/rsgunet_image_enhance.

+

Refer to: +J. Huang, P. Zhu, M. Geng et al. Range Scaling Global U-Net for Perceptual Image Enhancement on Mobile Devices.

+ + + + + + +
+ Source code in odak/learn/models/models.py +
class spatially_varying_kernel_generation_model(torch.nn.Module):
+    """
+    A spatially varying kernel generation model revised from RSGUnet:
+    https://github.com/MTLab/rsgunet_image_enhance.
+
+    Refer to:
+    J. Huang, P. Zhu, M. Geng et al. Range Scaling Global U-Net for Perceptual Image Enhancement on Mobile Devices.
+    """
+
+    def __init__(
+                 self,
+                 depth = 3,
+                 dimensions = 8,
+                 input_channels = 7,
+                 kernel_size = 3,
+                 bias = True,
+                 normalization = False,
+                 activation = torch.nn.LeakyReLU(0.2, inplace = True)
+                ):
+        """
+        U-Net model.
+
+        Parameters
+        ----------
+        depth          : int
+                         Number of upsampling and downsampling layers.
+        dimensions     : int
+                         Number of dimensions.
+        input_channels : int
+                         Number of input channels.
+        bias           : bool
+                         Set to True to let convolutional layers learn a bias term.
+        normalization  : bool
+                         If True, adds a Batch Normalization layer after the convolutional layer.
+        activation     : torch.nn
+                         Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
+        """
+        super().__init__()
+        self.depth = depth
+        self.inc = convolution_layer(
+                                     input_channels = input_channels,
+                                     output_channels = dimensions,
+                                     kernel_size = kernel_size,
+                                     bias = bias,
+                                     normalization = normalization,
+                                     activation = activation
+                                    )
+        self.encoder = torch.nn.ModuleList()
+        for i in range(depth + 1):  # downsampling layers
+            if i == 0:
+                in_channels = dimensions * (2 ** i)
+                out_channels = dimensions * (2 ** i)
+            elif i == depth:
+                in_channels = dimensions * (2 ** (i - 1))
+                out_channels = dimensions * (2 ** (i - 1))
+            else:
+                in_channels = dimensions * (2 ** (i - 1))
+                out_channels = 2 * in_channels
+            pooling_layer = torch.nn.AvgPool2d(2)
+            double_convolution_layer = double_convolution(
+                                                          input_channels = in_channels,
+                                                          mid_channels = in_channels,
+                                                          output_channels = out_channels,
+                                                          kernel_size = kernel_size,
+                                                          bias = bias,
+                                                          normalization = normalization,
+                                                          activation = activation
+                                                         )
+            self.encoder.append(pooling_layer)
+            self.encoder.append(double_convolution_layer)
+        self.spatially_varying_feature = torch.nn.ModuleList()  # for kernel generation
+        for i in range(depth, -1, -1):
+            if i == 1:
+                svf_in_channels = dimensions + 2 ** (self.depth + i) + 1
+            else:
+                svf_in_channels = 2 ** (self.depth + i) + 1
+            svf_out_channels = (2 ** (self.depth + i)) * (kernel_size * kernel_size)
+            svf_mid_channels = dimensions * (2 ** (self.depth - 1))
+            spatially_varying_kernel_generation = torch.nn.ModuleList()
+            for j in range(i, -1, -1):
+                pooling_layer = torch.nn.AvgPool2d(2 ** (j + 1))
+                spatially_varying_kernel_generation.append(pooling_layer)
+            kernel_generation_block = torch.nn.Sequential(
+                torch.nn.Conv2d(
+                                in_channels = svf_in_channels,
+                                out_channels = svf_mid_channels,
+                                kernel_size = kernel_size,
+                                padding = kernel_size // 2,
+                                bias = bias
+                               ),
+                activation,
+                torch.nn.Conv2d(
+                                in_channels = svf_mid_channels,
+                                out_channels = svf_mid_channels,
+                                kernel_size = kernel_size,
+                                padding = kernel_size // 2,
+                                bias = bias
+                               ),
+                activation,
+                torch.nn.Conv2d(
+                                in_channels = svf_mid_channels,
+                                out_channels = svf_out_channels,
+                                kernel_size = kernel_size,
+                                padding = kernel_size // 2,
+                                bias = bias
+                               ),
+            )
+            spatially_varying_kernel_generation.append(kernel_generation_block)
+            self.spatially_varying_feature.append(spatially_varying_kernel_generation)
+        self.decoder = torch.nn.ModuleList()
+        global_feature_layer = global_feature_module(  # global feature layer
+                                                     input_channels = dimensions * (2 ** (depth - 1)),
+                                                     mid_channels = dimensions * (2 ** (depth - 1)),
+                                                     output_channels = dimensions * (2 ** (depth - 1)),
+                                                     kernel_size = kernel_size,
+                                                     bias = bias,
+                                                     activation = torch.nn.LeakyReLU(0.2, inplace = True)
+                                                    )
+        self.decoder.append(global_feature_layer)
+        for i in range(depth, 0, -1):
+            if i == 2:
+                up_in_channels = (dimensions // 2) * (2 ** i)
+                up_out_channels = up_in_channels
+                up_mid_channels = up_in_channels
+            elif i == 1:
+                up_in_channels = dimensions * 2
+                up_out_channels = dimensions
+                up_mid_channels = up_out_channels
+            else:
+                up_in_channels = (dimensions // 2) * (2 ** i)
+                up_out_channels = up_in_channels // 2
+                up_mid_channels = up_in_channels
+            upsample_layer = upsample_convtranspose2d_layer(
+                                                            input_channels = up_in_channels,
+                                                            output_channels = up_mid_channels,
+                                                            kernel_size = 2,
+                                                            stride = 2,
+                                                            bias = bias,
+                                                           )
+            conv_layer = double_convolution(
+                                            input_channels = up_mid_channels,
+                                            output_channels = up_out_channels,
+                                            kernel_size = kernel_size,
+                                            bias = bias,
+                                            normalization = normalization,
+                                            activation = activation,
+                                           )
+            self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))
+
+
+    def forward(self, focal_surface, field):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        focal_surface : torch.tensor
+                        Input focal surface data.
+                        Dimension: (1, 1, H, W)
+
+        field         : torch.tensor
+                        Input field data.
+                        Dimension: (1, 6, H, W)
+
+        Returns
+        -------
+        sv_kernel : list of torch.tensor
+                    Learned spatially varying kernels.
+                    Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i),
+                    where C_i, H_i, and W_i represent the channel, height, and width
+                    of each feature at a certain scale.
+        """
+        x = self.inc(torch.cat((focal_surface, field), dim = 1))
+        downsampling_outputs = [focal_surface]
+        downsampling_outputs.append(x)
+        for i, down_layer in enumerate(self.encoder):
+            x_down = down_layer(downsampling_outputs[-1])
+            downsampling_outputs.append(x_down)
+        sv_kernels = []
+        for i, (up_layer, svf_layer) in enumerate(zip(self.decoder, self.spatially_varying_feature)):
+            if i == 0:
+                global_feature = up_layer(downsampling_outputs[-2], downsampling_outputs[-1])
+                downsampling_outputs[-1] = global_feature
+                sv_feature = [global_feature, downsampling_outputs[0]]
+                for j in range(self.depth - i + 1):
+                    sv_feature[1] = svf_layer[self.depth - i](sv_feature[1])
+                    if j > 0:
+                        sv_feature.append(svf_layer[j](downsampling_outputs[2 * j]))
+                sv_feature = [sv_feature[0], sv_feature[1], sv_feature[4], sv_feature[2],
+                              sv_feature[3]]
+                sv_kernel = svf_layer[-1](torch.cat(sv_feature, dim = 1))
+                sv_kernels.append(sv_kernel)
+            else:
+                x_up = up_layer[0](downsampling_outputs[-1],
+                                   downsampling_outputs[2 * (self.depth + 1 - i) + 1])
+                x_up = up_layer[1](x_up)
+                downsampling_outputs[-1] = x_up
+                sv_feature = [x_up, downsampling_outputs[0]]
+                for j in range(self.depth - i + 1):
+                    sv_feature[1] = svf_layer[self.depth - i](sv_feature[1])
+                    if j > 0:
+                        sv_feature.append(svf_layer[j](downsampling_outputs[2 * j]))
+                if i == 1:
+                    sv_feature = [sv_feature[0], sv_feature[1], sv_feature[3], sv_feature[2]]
+                sv_kernel = svf_layer[-1](torch.cat(sv_feature, dim = 1))
+                sv_kernels.append(sv_kernel)
+        return sv_kernels
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(depth=3, dimensions=8, input_channels=7, kernel_size=3, bias=True, normalization=False, activation=torch.nn.LeakyReLU(0.2, inplace=True)) + +

+ + +
+ +

U-Net model.

+ + +

Parameters:

+
    +
  • + depth + – +
    +
             Number of upsampling and downsampling layers.
    +
    +
    +
  • +
  • + dimensions + – +
    +
             Number of dimensions.
    +
    +
    +
  • +
  • + input_channels + (int, default: + 7 +) + – +
    +
             Number of input channels.
    +
    +
    +
  • +
  • + bias + – +
    +
             Set to True to let convolutional layers learn a bias term.
    +
    +
    +
  • +
  • + normalization + – +
    +
             If True, adds a Batch Normalization layer after the convolutional layer.
    +
    +
    +
  • +
  • + activation + – +
    +
             Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
    +
    +
    +
  • +
+ +
+ Source code in odak/learn/models/models.py +
def __init__(
+             self,
+             depth = 3,
+             dimensions = 8,
+             input_channels = 7,
+             kernel_size = 3,
+             bias = True,
+             normalization = False,
+             activation = torch.nn.LeakyReLU(0.2, inplace = True)
+            ):
+    """
+    U-Net model.
+
+    Parameters
+    ----------
+    depth          : int
+                     Number of upsampling and downsampling layers.
+    dimensions     : int
+                     Number of dimensions.
+    input_channels : int
+                     Number of input channels.
+    bias           : bool
+                     Set to True to let convolutional layers learn a bias term.
+    normalization  : bool
+                     If True, adds a Batch Normalization layer after the convolutional layer.
+    activation     : torch.nn
+                     Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
+    """
+    super().__init__()
+    self.depth = depth
+    self.inc = convolution_layer(
+                                 input_channels = input_channels,
+                                 output_channels = dimensions,
+                                 kernel_size = kernel_size,
+                                 bias = bias,
+                                 normalization = normalization,
+                                 activation = activation
+                                )
+    self.encoder = torch.nn.ModuleList()
+    for i in range(depth + 1):  # downsampling layers
+        if i == 0:
+            in_channels = dimensions * (2 ** i)
+            out_channels = dimensions * (2 ** i)
+        elif i == depth:
+            in_channels = dimensions * (2 ** (i - 1))
+            out_channels = dimensions * (2 ** (i - 1))
+        else:
+            in_channels = dimensions * (2 ** (i - 1))
+            out_channels = 2 * in_channels
+        pooling_layer = torch.nn.AvgPool2d(2)
+        double_convolution_layer = double_convolution(
+                                                      input_channels = in_channels,
+                                                      mid_channels = in_channels,
+                                                      output_channels = out_channels,
+                                                      kernel_size = kernel_size,
+                                                      bias = bias,
+                                                      normalization = normalization,
+                                                      activation = activation
+                                                     )
+        self.encoder.append(pooling_layer)
+        self.encoder.append(double_convolution_layer)
+    self.spatially_varying_feature = torch.nn.ModuleList()  # for kernel generation
+    for i in range(depth, -1, -1):
+        if i == 1:
+            svf_in_channels = dimensions + 2 ** (self.depth + i) + 1
+        else:
+            svf_in_channels = 2 ** (self.depth + i) + 1
+        svf_out_channels = (2 ** (self.depth + i)) * (kernel_size * kernel_size)
+        svf_mid_channels = dimensions * (2 ** (self.depth - 1))
+        spatially_varying_kernel_generation = torch.nn.ModuleList()
+        for j in range(i, -1, -1):
+            pooling_layer = torch.nn.AvgPool2d(2 ** (j + 1))
+            spatially_varying_kernel_generation.append(pooling_layer)
+        kernel_generation_block = torch.nn.Sequential(
+            torch.nn.Conv2d(
+                            in_channels = svf_in_channels,
+                            out_channels = svf_mid_channels,
+                            kernel_size = kernel_size,
+                            padding = kernel_size // 2,
+                            bias = bias
+                           ),
+            activation,
+            torch.nn.Conv2d(
+                            in_channels = svf_mid_channels,
+                            out_channels = svf_mid_channels,
+                            kernel_size = kernel_size,
+                            padding = kernel_size // 2,
+                            bias = bias
+                           ),
+            activation,
+            torch.nn.Conv2d(
+                            in_channels = svf_mid_channels,
+                            out_channels = svf_out_channels,
+                            kernel_size = kernel_size,
+                            padding = kernel_size // 2,
+                            bias = bias
+                           ),
+        )
+        spatially_varying_kernel_generation.append(kernel_generation_block)
+        self.spatially_varying_feature.append(spatially_varying_kernel_generation)
+    self.decoder = torch.nn.ModuleList()
+    global_feature_layer = global_feature_module(  # global feature layer
+                                                 input_channels = dimensions * (2 ** (depth - 1)),
+                                                 mid_channels = dimensions * (2 ** (depth - 1)),
+                                                 output_channels = dimensions * (2 ** (depth - 1)),
+                                                 kernel_size = kernel_size,
+                                                 bias = bias,
+                                                 activation = torch.nn.LeakyReLU(0.2, inplace = True)
+                                                )
+    self.decoder.append(global_feature_layer)
+    for i in range(depth, 0, -1):
+        if i == 2:
+            up_in_channels = (dimensions // 2) * (2 ** i)
+            up_out_channels = up_in_channels
+            up_mid_channels = up_in_channels
+        elif i == 1:
+            up_in_channels = dimensions * 2
+            up_out_channels = dimensions
+            up_mid_channels = up_out_channels
+        else:
+            up_in_channels = (dimensions // 2) * (2 ** i)
+            up_out_channels = up_in_channels // 2
+            up_mid_channels = up_in_channels
+        upsample_layer = upsample_convtranspose2d_layer(
+                                                        input_channels = up_in_channels,
+                                                        output_channels = up_mid_channels,
+                                                        kernel_size = 2,
+                                                        stride = 2,
+                                                        bias = bias,
+                                                       )
+        conv_layer = double_convolution(
+                                        input_channels = up_mid_channels,
+                                        output_channels = up_out_channels,
+                                        kernel_size = kernel_size,
+                                        bias = bias,
+                                        normalization = normalization,
+                                        activation = activation,
+                                       )
+        self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))
+
+
+
+ +
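A minimal construction sketch follows; the import path odak.learn.models.models is an assumption inferred from the source path shown above, and the arguments are simply the documented defaults:
+import torch
+from odak.learn.models.models import spatially_varying_kernel_generation_model  # assumed import path
+
+# input_channels = 7 because the 1-channel focal surface is concatenated
+# with the 6-channel field inside forward().
+kernel_model = spatially_varying_kernel_generation_model(
+                                                          depth = 3,
+                                                          dimensions = 8,
+                                                          input_channels = 7,
+                                                          kernel_size = 3,
+                                                          bias = True,
+                                                          normalization = False,
+                                                          activation = torch.nn.LeakyReLU(0.2, inplace = True)
+                                                         )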
+ +
+ + +

+ forward(focal_surface, field) + +

+ + +
+ +

Forward model.

+ + +

Parameters:

+
    +
  • + focal_surface + (tensor) + – +
    +
            Input focal surface data.
    +        Dimension: (1, 1, H, W)
    +
    +
    +
  • +
  • + field + – +
    +
            Input field data.
    +        Dimension: (1, 6, H, W)
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +sv_kernel ( list of torch.tensor +) – +
    +

    Learned spatially varying kernels. +Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i), +where C_i, H_i, and W_i represent the channel, height, and width +of each feature at a certain scale.

    +
    +
  • +
+ +
+ Source code in odak/learn/models/models.py +
def forward(self, focal_surface, field):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    focal_surface : torch.tensor
+                    Input focal surface data.
+                    Dimension: (1, 1, H, W)
+
+    field         : torch.tensor
+                    Input field data.
+                    Dimension: (1, 6, H, W)
+
+    Returns
+    -------
+    sv_kernel : list of torch.tensor
+                Learned spatially varying kernels.
+                Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i),
+                where C_i, H_i, and W_i represent the channel, height, and width
+                of each feature at a certain scale.
+    """
+    x = self.inc(torch.cat((focal_surface, field), dim = 1))
+    downsampling_outputs = [focal_surface]
+    downsampling_outputs.append(x)
+    for i, down_layer in enumerate(self.encoder):
+        x_down = down_layer(downsampling_outputs[-1])
+        downsampling_outputs.append(x_down)
+    sv_kernels = []
+    for i, (up_layer, svf_layer) in enumerate(zip(self.decoder, self.spatially_varying_feature)):
+        if i == 0:
+            global_feature = up_layer(downsampling_outputs[-2], downsampling_outputs[-1])
+            downsampling_outputs[-1] = global_feature
+            sv_feature = [global_feature, downsampling_outputs[0]]
+            for j in range(self.depth - i + 1):
+                sv_feature[1] = svf_layer[self.depth - i](sv_feature[1])
+                if j > 0:
+                    sv_feature.append(svf_layer[j](downsampling_outputs[2 * j]))
+            sv_feature = [sv_feature[0], sv_feature[1], sv_feature[4], sv_feature[2],
+                          sv_feature[3]]
+            sv_kernel = svf_layer[-1](torch.cat(sv_feature, dim = 1))
+            sv_kernels.append(sv_kernel)
+        else:
+            x_up = up_layer[0](downsampling_outputs[-1],
+                               downsampling_outputs[2 * (self.depth + 1 - i) + 1])
+            x_up = up_layer[1](x_up)
+            downsampling_outputs[-1] = x_up
+            sv_feature = [x_up, downsampling_outputs[0]]
+            for j in range(self.depth - i + 1):
+                sv_feature[1] = svf_layer[self.depth - i](sv_feature[1])
+                if j > 0:
+                    sv_feature.append(svf_layer[j](downsampling_outputs[2 * j]))
+            if i == 1:
+                sv_feature = [sv_feature[0], sv_feature[1], sv_feature[3], sv_feature[2]]
+            sv_kernel = svf_layer[-1](torch.cat(sv_feature, dim = 1))
+            sv_kernels.append(sv_kernel)
+    return sv_kernels
+
+
+
+ +
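A hedged end-to-end sketch of this forward pass; the 512 x 512 resolution is an illustrative assumption (the spatial size should be divisible by 2^(depth + 1), i.e. 16 with the defaults, so the internal pooling stages line up), and the import path mirrors the source path shown above:
+import torch
+from odak.learn.models.models import spatially_varying_kernel_generation_model  # assumed import path
+
+kernel_model = spatially_varying_kernel_generation_model()  # defaults: depth = 3, dimensions = 8, input_channels = 7
+focal_surface = torch.randn(1, 1, 512, 512)                 # (1, 1, H, W)
+field = torch.randn(1, 6, 512, 512)                         # (1, 6, H, W)
+with torch.no_grad():
+    sv_kernel = kernel_model(focal_surface, field)
+for kernel in sv_kernel:
+    print(kernel.shape)  # each entry: (1, C_i * kernel_size * kernel_size, H_i, W_i)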
+ + + +
+ +
+ +
+ +
+ + + +

+ unet + + +

+ + +
+

+ Bases: Module

+ + +

A U-Net model, heavily inspired by https://github.com/milesial/Pytorch-UNet/tree/master/unet; for more details, see Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional networks for biomedical image segmentation." Medical Image Computing and Computer-Assisted Intervention–MICCAI 2015: 18th International Conference, Munich, Germany, October 5-9, 2015, Proceedings, Part III 18. Springer International Publishing, 2015.

+ + + + + + +
+ Source code in odak/learn/models/models.py +
class unet(torch.nn.Module):
+    """
+    A U-Net model, heavily inspired by `https://github.com/milesial/Pytorch-UNet/tree/master/unet`; for more details, see Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional networks for biomedical image segmentation." Medical Image Computing and Computer-Assisted Intervention–MICCAI 2015: 18th International Conference, Munich, Germany, October 5-9, 2015, Proceedings, Part III 18. Springer International Publishing, 2015.
+    """
+
+    def __init__(
+                 self, 
+                 depth = 4,
+                 dimensions = 64, 
+                 input_channels = 2, 
+                 output_channels = 1, 
+                 bilinear = False,
+                 kernel_size = 3,
+                 bias = False,
+                 activation = torch.nn.ReLU(inplace = True),
+                ):
+        """
+        U-Net model.
+
+        Parameters
+        ----------
+        depth             : int
+                            Number of upsampling and downsampling layers.
+        dimensions        : int
+                            Number of dimensions.
+        input_channels    : int
+                            Number of input channels.
+        output_channels   : int
+                            Number of output channels.
+        bilinear          : bool
+                            Uses bilinear upsampling in upsampling layers when set True.
+        bias              : bool
+                            Set True to let convolutional layers learn a bias term.
+        activation        : torch.nn
+                            Non-linear activation layer to be used (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
+        """
+        super(unet, self).__init__()
+        self.inc = double_convolution(
+                                      input_channels = input_channels,
+                                      mid_channels = dimensions,
+                                      output_channels = dimensions,
+                                      kernel_size = kernel_size,
+                                      bias = bias,
+                                      activation = activation
+                                     )      
+
+        self.downsampling_layers = torch.nn.ModuleList()
+        self.upsampling_layers = torch.nn.ModuleList()
+        for i in range(depth): # downsampling layers
+            in_channels = dimensions * (2 ** i)
+            out_channels = dimensions * (2 ** (i + 1))
+            down_layer = downsample_layer(in_channels,
+                                            out_channels,
+                                            kernel_size=kernel_size,
+                                            bias=bias,
+                                            activation=activation
+                                            )
+            self.downsampling_layers.append(down_layer)      
+
+        for i in range(depth - 1, -1, -1):  # upsampling layers
+            up_in_channels = dimensions * (2 ** (i + 1))  
+            up_out_channels = dimensions * (2 ** i) 
+            up_layer = upsample_layer(up_in_channels, up_out_channels, kernel_size=kernel_size, bias=bias, activation=activation, bilinear=bilinear)
+            self.upsampling_layers.append(up_layer)
+        self.outc = torch.nn.Conv2d(
+                                    dimensions, 
+                                    output_channels,
+                                    kernel_size = kernel_size,
+                                    padding = kernel_size // 2,
+                                    bias = bias
+                                   )
+
+
+    def forward(self, x):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x             : torch.tensor
+                        Input data.
+
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Estimated output.      
+        """
+        downsampling_outputs = [self.inc(x)]
+        for down_layer in self.downsampling_layers:
+            x_down = down_layer(downsampling_outputs[-1])
+            downsampling_outputs.append(x_down)
+        x_up = downsampling_outputs[-1]
+        for i, up_layer in enumerate((self.upsampling_layers)):
+            x_up = up_layer(x_up, downsampling_outputs[-(i + 2)])       
+        result = self.outc(x_up)
+        return result
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(depth=4, dimensions=64, input_channels=2, output_channels=1, bilinear=False, kernel_size=3, bias=False, activation=torch.nn.ReLU(inplace=True)) + +

+ + +
+ +

U-Net model.

+ + +

Parameters:

+
    +
  • + depth + – +
    +
                Number of upsampling and downsampling layers.
    +
    +
    +
  • +
  • + dimensions + – +
    +
                Number of dimensions.
    +
    +
    +
  • +
  • + input_channels + – +
    +
                Number of input channels.
    +
    +
    +
  • +
  • + output_channels + – +
    +
                Number of output channels.
    +
    +
    +
  • +
  • + bilinear + – +
    +
                Uses bilinear upsampling in upsampling layers when set True.
    +
    +
    +
  • +
  • + bias + – +
    +
                Set True to let convolutional layers learn a bias term.
    +
    +
    +
  • +
  • + activation + – +
    +
                Non-linear activation layer to be used (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
    +
    +
    +
  • +
+ +
+ Source code in odak/learn/models/models.py +
def __init__(
+             self, 
+             depth = 4,
+             dimensions = 64, 
+             input_channels = 2, 
+             output_channels = 1, 
+             bilinear = False,
+             kernel_size = 3,
+             bias = False,
+             activation = torch.nn.ReLU(inplace = True),
+            ):
+    """
+    U-Net model.
+
+    Parameters
+    ----------
+    depth             : int
+                        Number of upsampling and downsampling layers.
+    dimensions        : int
+                        Number of dimensions.
+    input_channels    : int
+                        Number of input channels.
+    output_channels   : int
+                        Number of output channels.
+    bilinear          : bool
+                        Uses bilinear upsampling in upsampling layers when set True.
+    bias              : bool
+                        Set True to let convolutional layers learn a bias term.
+    activation        : torch.nn
+                        Non-linear activation layer to be used (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
+    """
+    super(unet, self).__init__()
+    self.inc = double_convolution(
+                                  input_channels = input_channels,
+                                  mid_channels = dimensions,
+                                  output_channels = dimensions,
+                                  kernel_size = kernel_size,
+                                  bias = bias,
+                                  activation = activation
+                                 )      
+
+    self.downsampling_layers = torch.nn.ModuleList()
+    self.upsampling_layers = torch.nn.ModuleList()
+    for i in range(depth): # downsampling layers
+        in_channels = dimensions * (2 ** i)
+        out_channels = dimensions * (2 ** (i + 1))
+        down_layer = downsample_layer(in_channels,
+                                        out_channels,
+                                        kernel_size=kernel_size,
+                                        bias=bias,
+                                        activation=activation
+                                        )
+        self.downsampling_layers.append(down_layer)      
+
+    for i in range(depth - 1, -1, -1):  # upsampling layers
+        up_in_channels = dimensions * (2 ** (i + 1))  
+        up_out_channels = dimensions * (2 ** i) 
+        up_layer = upsample_layer(up_in_channels, up_out_channels, kernel_size=kernel_size, bias=bias, activation=activation, bilinear=bilinear)
+        self.upsampling_layers.append(up_layer)
+    self.outc = torch.nn.Conv2d(
+                                dimensions, 
+                                output_channels,
+                                kernel_size = kernel_size,
+                                padding = kernel_size // 2,
+                                bias = bias
+                               )
+
+
+
+ +
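A hedged instantiation sketch (the import path odak.learn.models.models is assumed from the source path shown above); with these defaults the channel count doubles at each of the four downsampling stages, giving a 1024-channel bottleneck:
+import torch
+from odak.learn.models.models import unet  # assumed import path
+
+model = unet(
+             depth = 4,
+             dimensions = 64,
+             input_channels = 2,
+             output_channels = 1,
+             bilinear = False,
+             kernel_size = 3,
+             bias = False,
+             activation = torch.nn.ReLU(inplace = True)
+            )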
+ +
+ + +

+ forward(x) + +

+ + +
+ +

Forward model.

+ + +

Parameters:

+
    +
  • + x + – +
    +
            Input data.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( tensor +) – +
    +

    Estimated output.

    +
    +
  • +
+ +
+ Source code in odak/learn/models/models.py +
def forward(self, x):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x             : torch.tensor
+                    Input data.
+
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Estimated output.      
+    """
+    downsampling_outputs = [self.inc(x)]
+    for down_layer in self.downsampling_layers:
+        x_down = down_layer(downsampling_outputs[-1])
+        downsampling_outputs.append(x_down)
+    x_up = downsampling_outputs[-1]
+    for i, up_layer in enumerate((self.upsampling_layers)):
+        x_up = up_layer(x_up, downsampling_outputs[-(i + 2)])       
+    result = self.outc(x_up)
+    return result
+
+
+
+ +
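A minimal forward-pass sketch under the same assumptions; the 256 x 256 input size is illustrative and should be divisible by 2^depth so the skip connections match:
+import torch
+from odak.learn.models.models import unet  # assumed import path
+
+model = unet(depth = 4, dimensions = 64, input_channels = 2, output_channels = 1)
+x = torch.randn(1, 2, 256, 256)  # (batch, input_channels, H, W)
+with torch.no_grad():
+    y = model(x)
+print(y.shape)                   # expected: torch.Size([1, 1, 256, 256])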
+ + + +
+ +
+ +
+ +
+ + + +

+ upsample_convtranspose2d_layer + + +

+ + +
+

+ Bases: Module

+ + +

An upsampling convtranspose2d layer.

+ + + + + + +
+ Source code in odak/learn/models/components.py +
class upsample_convtranspose2d_layer(torch.nn.Module):
+    """
+    An upsampling convtranspose2d layer.
+    """
+    def __init__(
+                 self,
+                 input_channels,
+                 output_channels,
+                 kernel_size = 2,
+                 stride = 2,
+                 bias = False,
+                ):
+        """
+        An upsampling component based on a transposed convolution.
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        output_channels : int
+                          Number of output channels.
+        kernel_size     : int
+                          Kernel size.
+        bias            : bool
+                          Set to True to let convolutional layers have a bias term.
+        """
+        super().__init__()
+        self.up = torch.nn.ConvTranspose2d(
+                                           in_channels = input_channels,
+                                           out_channels = output_channels,
+                                           bias = bias,
+                                           kernel_size = kernel_size,
+                                           stride = stride
+                                          )
+
+    def forward(self, x1, x2):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x1             : torch.tensor
+                         First input data.
+        x2             : torch.tensor
+                         Second input data.
+
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Result of the forward operation
+        """
+        x1 = self.up(x1)
+        diffY = x2.size()[2] - x1.size()[2]
+        diffX = x2.size()[3] - x1.size()[3]
+        x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,
+                                          diffY // 2, diffY - diffY // 2])
+        result = x1 + x2
+        return result
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(input_channels, output_channels, kernel_size=2, stride=2, bias=False) + +

+ + +
+ +

An upsampling component based on a transposed convolution.

+ + +

Parameters:

+
    +
  • + input_channels + – +
    +
              Number of input channels.
    +
    +
    +
  • +
  • + output_channels + (int) + – +
    +
              Number of output channels.
    +
    +
    +
  • +
  • + kernel_size + – +
    +
              Kernel size.
    +
    +
    +
  • +
  • + bias + – +
    +
          Set to True to let convolutional layers have a bias term.
    +
    +
    +
  • +
+ +
+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             input_channels,
+             output_channels,
+             kernel_size = 2,
+             stride = 2,
+             bias = False,
+            ):
+    """
+    An upsampling component based on a transposed convolution.
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    output_channels : int
+                      Number of output channels.
+    kernel_size     : int
+                      Kernel size.
+    bias            : bool
+                      Set to True to let convolutional layers have a bias term.
+    """
+    super().__init__()
+    self.up = torch.nn.ConvTranspose2d(
+                                       in_channels = input_channels,
+                                       out_channels = output_channels,
+                                       bias = bias,
+                                       kernel_size = kernel_size,
+                                       stride = stride
+                                      )
+
+
+
+ +
+ +
+ + +

+ forward(x1, x2) + +

+ + +
+ +

Forward model.

+ + +

Parameters:

+
    +
  • + x1 + – +
    +
             First input data.
    +
    +
    +
  • +
  • + x2 + – +
    +
             Second input data.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( tensor +) – +
    +

    Result of the forward operation

    +
    +
  • +
+ +
+ Source code in odak/learn/models/components.py +
def forward(self, x1, x2):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x1             : torch.tensor
+                     First input data.
+    x2             : torch.tensor
+                     Second input data.
+
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Result of the forward operation
+    """
+    x1 = self.up(x1)
+    diffY = x2.size()[2] - x1.size()[2]
+    diffX = x2.size()[3] - x1.size()[3]
+    x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,
+                                      diffY // 2, diffY - diffY // 2])
+    result = x1 + x2
+    return result
+
+
+
+ +
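A small usage sketch (the import path odak.learn.models.components is assumed from the source path shown above): the first input is upsampled by the transposed convolution, padded to the second input's size, and the two tensors are summed:
+import torch
+from odak.learn.models.components import upsample_convtranspose2d_layer  # assumed import path
+
+layer = upsample_convtranspose2d_layer(input_channels = 32, output_channels = 16)
+x1 = torch.randn(1, 32, 64, 64)    # low-resolution features
+x2 = torch.randn(1, 16, 128, 128)  # skip connection at the target resolution
+with torch.no_grad():
+    result = layer(x1, x2)
+print(result.shape)                # expected: torch.Size([1, 16, 128, 128])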
+ + + +
+ +
+ +
+ +
+ + + +

+ upsample_layer + + +

+ + +
+

+ Bases: Module

+ + +

An upsampling convolutional layer.

+ + + + + + +
+ Source code in odak/learn/models/components.py +
class upsample_layer(torch.nn.Module):
+    """
+    An upsampling convolutional layer.
+    """
+    def __init__(
+                 self,
+                 input_channels,
+                 output_channels,
+                 kernel_size = 3,
+                 bias = False,
+                 activation = torch.nn.ReLU(),
+                 bilinear = True
+                ):
+        """
+        An upscaling component followed by a double convolution.
+
+        Parameters
+        ----------
+        input_channels  : int
+                          Number of input channels.
+        output_channels : int
+                          Number of output channels.
+        kernel_size     : int
+                          Kernel size.
+        bias            : bool 
+                          Set to True to let convolutional layers have a bias term.
+        activation      : torch.nn
+                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+        bilinear        : bool
+                          If set to True, bilinear sampling is used.
+        """
+        super(upsample_layer, self).__init__()
+        if bilinear:
+            self.up = torch.nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = True)
+            self.conv = double_convolution(
+                                           input_channels = input_channels + output_channels,
+                                           mid_channels = input_channels // 2,
+                                           output_channels = output_channels,
+                                           kernel_size = kernel_size,
+                                           bias = bias,
+                                           activation = activation
+                                          )
+        else:
+            self.up = torch.nn.ConvTranspose2d(input_channels , input_channels // 2, kernel_size = 2, stride = 2)
+            self.conv = double_convolution(
+                                           input_channels = input_channels,
+                                           mid_channels = output_channels,
+                                           output_channels = output_channels,
+                                           kernel_size = kernel_size,
+                                           bias = bias,
+                                           activation = activation
+                                          )
+
+
+    def forward(self, x1, x2):
+        """
+        Forward model.
+
+        Parameters
+        ----------
+        x1             : torch.tensor
+                         First input data.
+        x2             : torch.tensor
+                         Second input data.
+
+
+        Returns
+        ----------
+        result        : torch.tensor
+                        Result of the forward operation
+        """ 
+        x1 = self.up(x1)
+        diffY = x2.size()[2] - x1.size()[2]
+        diffX = x2.size()[3] - x1.size()[3]
+        x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,
+                                          diffY // 2, diffY - diffY // 2])
+        x = torch.cat([x2, x1], dim = 1)
+        result = self.conv(x)
+        return result
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(input_channels, output_channels, kernel_size=3, bias=False, activation=torch.nn.ReLU(), bilinear=True) + +

+ + +
+ +

An upscaling component followed by a double convolution.

+ + +

Parameters:

+
    +
  • + input_channels + – +
    +
              Number of input channels.
    +
    +
    +
  • +
  • + output_channels + (int) + – +
    +
              Number of output channels.
    +
    +
    +
  • +
  • + kernel_size + – +
    +
              Kernel size.
    +
    +
    +
  • +
  • + bias + – +
    +
          Set to True to let convolutional layers have a bias term.
    +
    +
    +
  • +
  • + activation + – +
    +
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
    +
    +
    +
  • +
  • + bilinear + – +
    +
              If set to True, bilinear sampling is used.
    +
    +
    +
  • +
+ +
+ Source code in odak/learn/models/components.py +
def __init__(
+             self,
+             input_channels,
+             output_channels,
+             kernel_size = 3,
+             bias = False,
+             activation = torch.nn.ReLU(),
+             bilinear = True
+            ):
+    """
+    An upscaling component followed by a double convolution.
+
+    Parameters
+    ----------
+    input_channels  : int
+                      Number of input channels.
+    output_channels : int
+                      Number of output channels.
+    kernel_size     : int
+                      Kernel size.
+    bias            : bool 
+                      Set to True to let convolutional layers have a bias term.
+    activation      : torch.nn
+                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
+    bilinear        : bool
+                      If set to True, bilinear sampling is used.
+    """
+    super(upsample_layer, self).__init__()
+    if bilinear:
+        self.up = torch.nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = True)
+        self.conv = double_convolution(
+                                       input_channels = input_channels + output_channels,
+                                       mid_channels = input_channels // 2,
+                                       output_channels = output_channels,
+                                       kernel_size = kernel_size,
+                                       bias = bias,
+                                       activation = activation
+                                      )
+    else:
+        self.up = torch.nn.ConvTranspose2d(input_channels , input_channels // 2, kernel_size = 2, stride = 2)
+        self.conv = double_convolution(
+                                       input_channels = input_channels,
+                                       mid_channels = output_channels,
+                                       output_channels = output_channels,
+                                       kernel_size = kernel_size,
+                                       bias = bias,
+                                       activation = activation
+                                      )
+
+
+
+ +
+ +
+ + +

+ forward(x1, x2) + +

+ + +
+ +

Forward model.

+ + +

Parameters:

+
    +
  • + x1 + – +
    +
             First input data.
    +
    +
    +
  • +
  • + x2 + – +
    +
             Second input data.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( tensor +) – +
    +

    Result of the forward operation

    +
    +
  • +
+ +
+ Source code in odak/learn/models/components.py +
def forward(self, x1, x2):
+    """
+    Forward model.
+
+    Parameters
+    ----------
+    x1             : torch.tensor
+                     First input data.
+    x2             : torch.tensor
+                     Second input data.
+
+
+    Returns
+    ----------
+    result        : torch.tensor
+                    Result of the forward operation
+    """ 
+    x1 = self.up(x1)
+    diffY = x2.size()[2] - x1.size()[2]
+    diffX = x2.size()[3] - x1.size()[3]
+    x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,
+                                      diffY // 2, diffY - diffY // 2])
+    x = torch.cat([x2, x1], dim = 1)
+    result = self.conv(x)
+    return result
+
+
+
+ +
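A small usage sketch (the import path odak.learn.models.components is assumed): with bilinear set to True the first input keeps its channel count through the upsampler, is padded to match the skip connection, and the concatenation goes through the double convolution:
+import torch
+from odak.learn.models.components import upsample_layer  # assumed import path
+
+layer = upsample_layer(input_channels = 64, output_channels = 32, bilinear = True)
+x1 = torch.randn(1, 64, 32, 32)  # low-resolution features
+x2 = torch.randn(1, 32, 64, 64)  # skip connection
+with torch.no_grad():
+    result = layer(x1, x2)
+print(result.shape)              # expected: torch.Size([1, 32, 64, 64])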
+ + + +
+ +
+ +
+ + +
+ + +

+ gaussian(x, multiplier=1.0) + +

+ + +
+ +

A Gaussian non-linear activation. +For more details: Ramasinghe, Sameera, and Simon Lucey. "Beyond periodicity: Towards a unifying framework for activations in coordinate-mlps." In European Conference on Computer Vision, pp. 142-158. Cham: Springer Nature Switzerland, 2022.

+ + +

Parameters:

+
    +
  • + x + – +
    +
           Input data.
    +
    +
    +
  • +
  • + multiplier + – +
    +
           Multiplier.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( float or tensor +) – +
    +

    Output data.

    +
    +
  • +
+ +
+ Source code in odak/learn/models/components.py +
def gaussian(x, multiplier = 1.):
+    """
+    A Gaussian non-linear activation.
+    For more details: Ramasinghe, Sameera, and Simon Lucey. "Beyond periodicity: Towards a unifying framework for activations in coordinate-mlps." In European Conference on Computer Vision, pp. 142-158. Cham: Springer Nature Switzerland, 2022.
+
+    Parameters
+    ----------
+    x            : float or torch.tensor
+                   Input data.
+    multiplier   : float or torch.tensor
+                   Multiplier.
+
+    Returns
+    -------
+    result       : float or torch.tensor
+                   Output data.
+    """
+    result = torch.exp(- (multiplier * x) ** 2)
+    return result
+
+
+
+ +
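A small usage sketch (the import path odak.learn.models.components is assumed from the source path shown above); the multiplier controls how quickly the response decays away from zero:
+import torch
+from odak.learn.models.components import gaussian  # assumed import path
+
+x = torch.linspace(-3., 3., steps = 7)
+print(gaussian(x))                   # exp(-x^2), peaks at 1 for x = 0
+print(gaussian(x, multiplier = 2.))  # exp(-(2x)^2), a narrower bump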
+ +
+ + +

+ swish(x) + +

+ + +
+ +

A swish non-linear activation. +For more details: https://en.wikipedia.org/wiki/Swish_function

+ + +

Parameters:

+
    +
  • + x + – +
    +
             Input.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +out ( float or tensor +) – +
    +

    Output.

    +
    +
  • +
+ +
+ Source code in odak/learn/models/components.py +
def swish(x):
+    """
+    A swish non-linear activation.
+    For more details: https://en.wikipedia.org/wiki/Swish_function
+
+    Parameters
+    -----------
+    x              : float or torch.tensor
+                     Input.
+
+    Returns
+    -------
+    out            : float or torch.tensor
+                     Output.
+    """
+    out = x * torch.sigmoid(x)
+    return out
+
+
+
+ +
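A small usage sketch (the import path odak.learn.models.components is assumed): swish is x * sigmoid(x), so it behaves almost linearly for large positive inputs and approaches zero for large negative inputs:
+import torch
+from odak.learn.models.components import swish  # assumed import path
+
+x = torch.linspace(-3., 3., steps = 7)
+print(swish(x))  # x * sigmoid(x)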
+ + + +
+ +
+ +
+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ multi_color_hologram_optimizer + + +

+ + +
+ + +

A class for optimizing single- or multi-color holograms. +For more details, see Kavaklı et al., SIGGRAPH ASIA 2023, Multi-color Holograms Improve Brightness in Holographic Displays.

+ + + + + + +
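Before the full listing below, here is a hedged constructor sketch; the import path odak.learn.wave.optimizers is assumed from the source path shown below, and stub_propagator is a hypothetical stand-in that only exposes the attributes the constructor touches (a real odak propagator would be used in practice):
+import torch
+from odak.learn.wave.optimizers import multi_color_hologram_optimizer  # assumed import path
+
+
+class stub_propagator():
+    """
+    Hypothetical stand-in propagator; only provides what the constructor uses.
+    """
+    propagation_type = 'Bandlimited Angular Spectrum'
+    def set_laser_powers(self, channel_power):
+        self.channel_power = channel_power
+
+
+optimizer = multi_color_hologram_optimizer(
+                                            wavelengths = [639e-9, 515e-9, 473e-9],
+                                            resolution = [1080, 1920],
+                                            targets = torch.rand(1, 3, 1080, 1920),
+                                            propagator = stub_propagator(),
+                                            number_of_frames = 3,
+                                            method = 'multi-color'
+                                           )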
+ Source code in odak/learn/wave/optimizers.py +
class multi_color_hologram_optimizer():
+    """
+    A class for optimizing single- or multi-color holograms.
+    For more details, see Kavaklı et al., SIGGRAPH ASIA 2023, Multi-color Holograms Improve Brightness in Holographic Displays.
+    """
+    def __init__(self,
+                 wavelengths,
+                 resolution,
+                 targets,
+                 propagator,
+                 number_of_frames = 3,
+                 number_of_depth_layers = 1,
+                 learning_rate = 2e-2,
+                 learning_rate_floor = 5e-3,
+                 double_phase = True,
+                 scale_factor = 1,
+                 method = 'multi-color',
+                 channel_power_filename = '',
+                 device = None,
+                 loss_function = None,
+                 peak_amplitude = 1.0,
+                 optimize_peak_amplitude = False,
+                 img_loss_thres = 2e-3,
+                 reduction = 'sum'
+                ):
+        self.device = device
+        if isinstance(self.device, type(None)):
+            self.device = torch.device("cpu")
+        torch.cuda.empty_cache()
+        torch.random.seed()
+        self.wavelengths = wavelengths
+        self.resolution = resolution
+        self.targets = targets
+        if propagator.propagation_type != 'Impulse Response Fresnel':
+            scale_factor = 1
+        self.scale_factor = scale_factor
+        self.propagator = propagator
+        self.learning_rate = learning_rate
+        self.learning_rate_floor = learning_rate_floor
+        self.number_of_channels = len(self.wavelengths)
+        self.number_of_frames = number_of_frames
+        self.number_of_depth_layers = number_of_depth_layers
+        self.double_phase = double_phase
+        self.channel_power_filename = channel_power_filename
+        self.method = method
+        if self.method != 'conventional' and self.method != 'multi-color':
+           logging.warning('Unknown optimization method. Options are conventional or multi-color.')
+           import sys
+           sys.exit()
+        self.peak_amplitude = peak_amplitude
+        self.optimize_peak_amplitude = optimize_peak_amplitude
+        if self.optimize_peak_amplitude:
+            self.init_peak_amplitude_scale()
+        self.img_loss_thres = img_loss_thres
+        self.kernels = []
+        self.init_phase()
+        self.init_channel_power()
+        self.init_loss_function(loss_function, reduction = reduction)
+        self.init_amplitude()
+        self.init_phase_scale()
+
+
+    def init_peak_amplitude_scale(self):
+        """
+        Internal function to set the peak amplitude scale.
+        """
+        self.peak_amplitude = torch.tensor(
+                                           self.peak_amplitude,
+                                           requires_grad = True,
+                                           device=self.device
+                                          )
+
+
+    def init_phase_scale(self):
+        """
+        Internal function to set the phase scale.
+        """
+        if self.method == 'conventional':
+            self.phase_scale = torch.tensor(
+                                            [
+                                             1.,
+                                             1.,
+                                             1.
+                                            ],
+                                            requires_grad = False,
+                                            device = self.device
+                                           )
+        if self.method == 'multi-color':
+            self.phase_scale = torch.tensor(
+                                            [
+                                             1.,
+                                             1.,
+                                             1.
+                                            ],
+                                            requires_grad = False,
+                                            device = self.device
+                                           )
+
+
+    def init_amplitude(self):
+        """
+        Internal function to set the amplitude of the illumination source.
+        """
+        self.amplitude = torch.zeros(
+                                     self.resolution[0] * self.scale_factor,
+                                     self.resolution[1] * self.scale_factor,
+                                     requires_grad = False,
+                                     device = self.device
+                                    )
+        self.amplitude[::self.scale_factor, ::self.scale_factor] = 1.
+
+
+    def init_phase(self):
+        """
+        Internal function to set the starting phase of the phase-only hologram.
+        """
+        self.phase = torch.zeros(
+                                 self.number_of_frames,
+                                 self.resolution[0],
+                                 self.resolution[1],
+                                 device = self.device,
+                                 requires_grad = True
+                                )
+        self.offset = torch.rand_like(self.phase, requires_grad = True, device = self.device)
+
+
+    def init_channel_power(self):
+        """
+        Internal function to set the channel powers of the light source.
+        """
+        if self.method == 'conventional':
+            logging.warning('Scheme: Conventional')
+            self.channel_power = torch.eye(
+                                           self.number_of_frames,
+                                           self.number_of_channels,
+                                           device = self.device,
+                                           requires_grad = False
+                                          )
+
+        elif self.method == 'multi-color':
+            logging.warning('Scheme: Multi-color')
+            self.channel_power = torch.ones(
+                                            self.number_of_frames,
+                                            self.number_of_channels,
+                                            device = self.device,
+                                            requires_grad = True
+                                           )
+        if self.channel_power_filename != '':
+            self.channel_power = torch_load(self.channel_power_filename).to(self.device)
+            self.channel_power.requires_grad = False
+            self.channel_power[self.channel_power < 0.] = 0.
+            self.channel_power[self.channel_power > 1.] = 1.
+            if self.method == 'multi-color':
+                self.channel_power.requires_grad = True
+            if self.method == 'conventional':
+                self.channel_power = torch.abs(torch.cos(self.channel_power))
+            logging.warning('Channel powers:')
+            logging.warning(self.channel_power)
+            logging.warning('Channel powers loaded from {}.'.format(self.channel_power_filename))
+        self.propagator.set_laser_powers(self.channel_power)
+
+
+
+    def init_optimizer(self):
+        """
+        Internal function to set the optimizer.
+        """
+        optimization_variables = [self.phase, self.offset]
+        if self.optimize_peak_amplitude:
+            optimization_variables.append(self.peak_amplitude)
+        if self.method == 'multi-color':
+            optimization_variables.append(self.propagator.channel_power)
+        self.optimizer = torch.optim.Adam(optimization_variables, lr=self.learning_rate)
+
+
+    def init_loss_function(self, loss_function, reduction = 'sum'):
+        """
+        Internal function to set the loss function.
+        """
+        self.l2_loss = torch.nn.MSELoss(reduction = reduction)
+        self.loss_type = 'custom'
+        self.loss_function = loss_function
+        if isinstance(self.loss_function, type(None)):
+            self.loss_type = 'conventional'
+            self.loss_function = torch.nn.MSELoss(reduction = reduction)
+
+
+
+    def evaluate(self, input_image, target_image, plane_id = 0):
+        """
+        Internal function to evaluate the loss.
+        """
+        if self.loss_type == 'conventional':
+            loss = self.loss_function(input_image, target_image)
+        elif self.loss_type == 'custom':
+            loss = 0
+            for i in range(len(self.wavelengths)):
+                loss += self.loss_function(
+                                           input_image[i],
+                                           target_image[i],
+                                           plane_id = plane_id
+                                          )
+        return loss
+
+
+    def double_phase_constrain(self, phase, phase_offset):
+        """
+        Internal function to constrain a given phase similarly to double phase encoding.
+
+        Parameters
+        ----------
+        phase                      : torch.tensor
+                                     Input phase values to be constrained.
+        phase_offset               : torch.tensor
+                                     Input phase offset value.
+
+        Returns
+        -------
+        phase_only                 : torch.tensor
+                                     Constrained output phase.
+        """
+        phase_zero_mean = phase - torch.mean(phase)
+        phase_low = torch.nan_to_num(phase_zero_mean - phase_offset, nan = 2 * np.pi)
+        phase_high = torch.nan_to_num(phase_zero_mean + phase_offset, nan = 2 * np.pi)
+        loss = multi_scale_total_variation_loss(phase_low, levels = 6)
+        loss += multi_scale_total_variation_loss(phase_high, levels = 6)
+        loss += torch.std(phase_low)
+        loss += torch.std(phase_high)
+        phase_only = torch.zeros_like(phase)
+        phase_only[0::2, 0::2] = phase_low[0::2, 0::2]
+        phase_only[0::2, 1::2] = phase_high[0::2, 1::2]
+        phase_only[1::2, 0::2] = phase_high[1::2, 0::2]
+        phase_only[1::2, 1::2] = phase_low[1::2, 1::2]
+        return phase_only, loss
+
+
+    def direct_phase_constrain(self, phase, phase_offset):
+        """
+        Internal function to constrain a given phase.
+
+        Parameters
+        ----------
+        phase                      : torch.tensor
+                                     Input phase values to be constrained.
+        phase_offset               : torch.tensor
+                                     Input phase offset value.
+
+        Returns
+        -------
+        phase_only                 : torch.tensor
+                                     Constrained output phase.
+        """
+        phase_only = torch.nan_to_num(phase - phase_offset, nan = 2 * np.pi)
+        loss = multi_scale_total_variation_loss(phase, levels = 6)
+        loss += multi_scale_total_variation_loss(phase_offset, levels = 6)
+        return phase_only, loss
+
+
+    def gradient_descent(self, number_of_iterations=100, weights=[1., 1., 0., 0.]):
+        """
+        Function to optimize multiplane phase-only holograms using stochastic gradient descent.
+
+        Parameters
+        ----------
+        number_of_iterations       : int
+                                     Number of iterations.
+        weights                    : list
+                                     Weights used in the loss function.
+
+        Returns
+        -------
+        hologram                   : torch.tensor
+                                     Optimised hologram.
+        """
+        hologram_phases = torch.zeros(
+                                      self.number_of_frames,
+                                      self.resolution[0],
+                                      self.resolution[1],
+                                      device = self.device
+                                     )
+        t = tqdm(range(number_of_iterations), leave = False, dynamic_ncols = True)
+        if self.optimize_peak_amplitude:
+            peak_amp_cache = self.peak_amplitude.item()
+        for step in t:
+            for g in self.optimizer.param_groups:
+                g['lr'] -= (self.learning_rate - self.learning_rate_floor) / number_of_iterations
+                if g['lr'] < self.learning_rate_floor:
+                    g['lr'] = self.learning_rate_floor
+                learning_rate = g['lr']
+            total_loss = 0
+            t_depth = tqdm(range(self.targets.shape[0]), leave = False, dynamic_ncols = True)
+            for depth_id in t_depth:
+                self.optimizer.zero_grad()
+                depth_target = self.targets[depth_id]
+                reconstruction_intensities = torch.zeros(
+                                                         self.number_of_frames,
+                                                         self.number_of_channels,
+                                                         self.resolution[0] * self.scale_factor,
+                                                         self.resolution[1] * self.scale_factor,
+                                                         device = self.device
+                                                        )
+                loss_variation_hologram = 0
+                laser_powers = self.propagator.get_laser_powers()
+                for frame_id in range(self.number_of_frames):
+                    if self.double_phase:
+                        phase, loss_phase = self.double_phase_constrain(
+                                                                        self.phase[frame_id],
+                                                                        self.offset[frame_id]
+                                                                       )
+                    else:
+                        phase, loss_phase = self.direct_phase_constrain(
+                                                                        self.phase[frame_id],
+                                                                        self.offset[frame_id]
+                                                                       )
+                    loss_variation_hologram += loss_phase
+                    for channel_id in range(self.number_of_channels):
+                        phase_scaled = torch.zeros_like(self.amplitude)
+                        phase_scaled[::self.scale_factor, ::self.scale_factor] = phase
+                        laser_power = laser_powers[frame_id][channel_id]
+                        hologram = generate_complex_field(
+                                                          laser_power * self.amplitude,
+                                                          phase_scaled * self.phase_scale[channel_id]
+                                                         )
+                        reconstruction_field = self.propagator(hologram, channel_id, depth_id)
+                        intensity = calculate_amplitude(reconstruction_field) ** 2
+                        reconstruction_intensities[frame_id, channel_id] += intensity
+                    hologram_phases[frame_id] = phase.detach().clone()
+                loss_laser = self.l2_loss(
+                                          torch.amax(depth_target, dim = (1, 2)) * self.peak_amplitude,
+                                          torch.sum(laser_powers, dim = 0)
+                                         )
+                loss_laser += self.l2_loss(
+                                           torch.tensor([self.number_of_frames * self.peak_amplitude]).to(self.device),
+                                           torch.sum(laser_powers).view(1,)
+                                          )
+                loss_laser += torch.cos(torch.min(torch.sum(laser_powers, dim = 1)))
+                reconstruction_intensity = torch.sum(reconstruction_intensities, dim=0)
+                loss_image = self.evaluate(
+                                           reconstruction_intensity,
+                                           depth_target * self.peak_amplitude,
+                                           plane_id = depth_id
+                                          )
+                loss = weights[0] * loss_image
+                loss += weights[1] * loss_laser
+                loss += weights[2] * loss_variation_hologram
+                include_pa_loss_flag = self.optimize_peak_amplitude and loss_image < self.img_loss_thres
+                if include_pa_loss_flag:
+                    loss -= self.peak_amplitude * 1.
+                if self.method == 'conventional':
+                    loss.backward()
+                else:
+                    loss.backward(retain_graph = True)
+                self.optimizer.step()
+                if include_pa_loss_flag:
+                    peak_amp_cache = self.peak_amplitude.item()
+                else:
+                    with torch.no_grad():
+                        if self.optimize_peak_amplitude:
+                            self.peak_amplitude.view([1])[0] = peak_amp_cache
+                total_loss += loss.detach().item()
+                loss_image = loss_image.detach()
+                del loss_laser
+                del loss_variation_hologram
+                del loss
+            description = "Loss:{:.3f} Loss Image:{:.3f} Peak Amp:{:.1f} Learning rate:{:.4f}".format(total_loss, loss_image.item(), self.peak_amplitude, learning_rate)
+            t.set_description(description)
+            del total_loss
+            del loss_image
+            del reconstruction_field
+            del reconstruction_intensities
+            del intensity
+            del phase
+            del hologram
+        logging.warning(description)
+        return hologram_phases.detach()
+
+
+    def optimize(self, number_of_iterations=100, weights=[1., 1., 1.], bits = 8):
+        """
+        Function to optimize multiplane phase-only holograms.
+
+        Parameters
+        ----------
+        number_of_iterations       : int
+                                     Number of iterations.
+        weights                    : list
+                                     Loss weights.
+        bits                       : int
+                                     Number of bits used to quantize the hologram before reconstruction.
+
+        Returns
+        -------
+        hologram_phases            : torch.tensor
+                                     Phases of the optimized phase-only hologram.
+        reconstruction_intensities : torch.tensor
+                                     Intensities of the images reconstructed at each plane with the optimized phase-only hologram.
+        """
+        self.init_optimizer()
+        hologram_phases = self.gradient_descent(
+                                                number_of_iterations=number_of_iterations,
+                                                weights=weights
+                                               )
+        hologram_phases = quantize(hologram_phases % (2 * np.pi), bits = bits, limits = [0., 2 * np.pi]) / 2 ** bits * 2 * np.pi
+        with torch.no_grad():
+            reconstruction_intensities = self.propagator.reconstruct(hologram_phases)
+        laser_powers = self.propagator.get_laser_powers()
+        channel_powers = self.propagator.channel_power
+        logging.warning("Final peak amplitude: {}".format(self.peak_amplitude))
+        logging.warning('Laser powers: {}'.format(laser_powers))
+        return hologram_phases, reconstruction_intensities, laser_powers, channel_powers, float(self.peak_amplitude)
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ direct_phase_constrain(phase, phase_offset) + +

+ + +
+ +

Internal function to constrain a given phase.

+ + +

Parameters:

  • phase – Input phase values to be constrained.
  • phase_offset – Input phase offset value.
+ + +

Returns:

  • phase_only (tensor) – Constrained output phase.
+ +
+ Source code in odak/learn/wave/optimizers.py +
def direct_phase_constrain(self, phase, phase_offset):
+    """
+    Internal function to constrain a given phase.
+
+    Parameters
+    ----------
+    phase                      : torch.tensor
+                                 Input phase values to be constrained.
+    phase_offset               : torch.tensor
+                                 Input phase offset value.
+
+    Returns
+    -------
+    phase_only                 : torch.tensor
+                                 Constrained output phase.
+    """
+    phase_only = torch.nan_to_num(phase - phase_offset, nan = 2 * np.pi)
+    loss = multi_scale_total_variation_loss(phase, levels = 6)
+    loss += multi_scale_total_variation_loss(phase_offset, levels = 6)
+    return phase_only, loss
+
+
+
+ +
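Both constraint functions penalize high-frequency structure through multi_scale_total_variation_loss, whose definition is not reproduced on this page. Below is a minimal standalone sketch of such a penalty, assuming average-pool downsampling between scales (odak's actual implementation may differ):

import torch
import torch.nn.functional as F

def total_variation(x):
    # Sum of absolute differences between horizontal and vertical neighbours.
    return (x[:, 1:] - x[:, :-1]).abs().sum() + (x[1:, :] - x[:-1, :]).abs().sum()

def multi_scale_tv(x, levels = 6):
    # Accumulate total variation over progressively average-pooled copies of the input.
    loss = 0.
    current = x.unsqueeze(0).unsqueeze(0)
    for _ in range(levels):
        loss = loss + total_variation(current[0, 0])
        current = F.avg_pool2d(current, kernel_size = 2)
    return loss

phase = torch.rand(64, 64, requires_grad = True)
print(multi_scale_tv(phase, levels = 6))

Summing total variation over several scales discourages both pixel-level noise and larger-scale ripples in the constrained phase.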
+ +
+ + +

+ double_phase_constrain(phase, phase_offset) + +

+ + +
+ +

Internal function to constrain a given phase similarly to double phase encoding.

+ + +

Parameters:

  • phase – Input phase values to be constrained.
  • phase_offset – Input phase offset value.
+ + +

Returns:

  • phase_only (tensor) – Constrained output phase.
+ +
+ Source code in odak/learn/wave/optimizers.py +
def double_phase_constrain(self, phase, phase_offset):
+    """
+    Internal function to constrain a given phase similarly to double phase encoding.
+
+    Parameters
+    ----------
+    phase                      : torch.tensor
+                                 Input phase values to be constrained.
+    phase_offset               : torch.tensor
+                                 Input phase offset value.
+
+    Returns
+    -------
+    phase_only                 : torch.tensor
+                                 Constrained output phase.
+    """
+    phase_zero_mean = phase - torch.mean(phase)
+    phase_low = torch.nan_to_num(phase_zero_mean - phase_offset, nan = 2 * np.pi)
+    phase_high = torch.nan_to_num(phase_zero_mean + phase_offset, nan = 2 * np.pi)
+    loss = multi_scale_total_variation_loss(phase_low, levels = 6)
+    loss += multi_scale_total_variation_loss(phase_high, levels = 6)
+    loss += torch.std(phase_low)
+    loss += torch.std(phase_high)
+    phase_only = torch.zeros_like(phase)
+    phase_only[0::2, 0::2] = phase_low[0::2, 0::2]
+    phase_only[0::2, 1::2] = phase_high[0::2, 1::2]
+    phase_only[1::2, 0::2] = phase_high[1::2, 0::2]
+    phase_only[1::2, 1::2] = phase_low[1::2, 1::2]
+    return phase_only, loss
+
+
+
+ +
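To make the checkerboard interleaving above concrete, here is a small standalone illustration (not odak's implementation) of how the two derived maps, phase - offset and phase + offset, are woven into a single phase-only pattern:

import torch

phase = torch.rand(4, 4) * 6.28
offset = torch.rand(4, 4) * 3.14
phase_low = phase - offset
phase_high = phase + offset
# Interleave the two maps on a 2 x 2 checkerboard, as in double phase encoding.
phase_only = torch.zeros_like(phase)
phase_only[0::2, 0::2] = phase_low[0::2, 0::2]
phase_only[0::2, 1::2] = phase_high[0::2, 1::2]
phase_only[1::2, 0::2] = phase_high[1::2, 0::2]
phase_only[1::2, 1::2] = phase_low[1::2, 1::2]
print(phase_only)

Neighbouring pixels then carry the low and high variants alternately, which is the essence of double phase encoding.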
+ +
+ + +

+ evaluate(input_image, target_image, plane_id=0) + +

+ + +
+ +

Internal function to evaluate the loss.

+ +
+ Source code in odak/learn/wave/optimizers.py +
def evaluate(self, input_image, target_image, plane_id = 0):
+    """
+    Internal function to evaluate the loss.
+    """
+    if self.loss_type == 'conventional':
+        loss = self.loss_function(input_image, target_image)
+    elif self.loss_type == 'custom':
+        loss = 0
+        for i in range(len(self.wavelengths)):
+            loss += self.loss_function(
+                                       input_image[i],
+                                       target_image[i],
+                                       plane_id = plane_id
+                                      )
+    return loss
+
+
+
+ +
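The custom branch of evaluate() assumes that loss_function is a callable taking a per-channel input image, a target image, and a plane_id keyword argument. A hypothetical example of such a callable (the name my_loss is illustrative, not part of odak):

import torch

def my_loss(input_image, target_image, plane_id = 0):
    # Simple per-channel mean squared error; plane_id could be used to weight planes differently.
    return torch.mean((input_image - target_image) ** 2)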
+ +
+ + +

+ gradient_descent(number_of_iterations=100, weights=[1.0, 1.0, 0.0, 0.0]) + +

+ + +
+ +

Function to optimize multiplane phase-only holograms using stochastic gradient descent.

+ + +

Parameters:

  • number_of_iterations – Number of iterations.
  • weights – Weights used in the loss function.
+ + +

Returns:

  • hologram (tensor) – Optimised hologram.
+ +
+ Source code in odak/learn/wave/optimizers.py +
def gradient_descent(self, number_of_iterations=100, weights=[1., 1., 0., 0.]):
+    """
+    Function to optimize multiplane phase-only holograms using stochastic gradient descent.
+
+    Parameters
+    ----------
+    number_of_iterations       : int
+                                 Number of iterations.
+    weights                    : list
+                                 Weights used in the loss function.
+
+    Returns
+    -------
+    hologram                   : torch.tensor
+                                 Optimised hologram.
+    """
+    hologram_phases = torch.zeros(
+                                  self.number_of_frames,
+                                  self.resolution[0],
+                                  self.resolution[1],
+                                  device = self.device
+                                 )
+    t = tqdm(range(number_of_iterations), leave = False, dynamic_ncols = True)
+    if self.optimize_peak_amplitude:
+        peak_amp_cache = self.peak_amplitude.item()
+    for step in t:
+        for g in self.optimizer.param_groups:
+            g['lr'] -= (self.learning_rate - self.learning_rate_floor) / number_of_iterations
+            if g['lr'] < self.learning_rate_floor:
+                g['lr'] = self.learning_rate_floor
+            learning_rate = g['lr']
+        total_loss = 0
+        t_depth = tqdm(range(self.targets.shape[0]), leave = False, dynamic_ncols = True)
+        for depth_id in t_depth:
+            self.optimizer.zero_grad()
+            depth_target = self.targets[depth_id]
+            reconstruction_intensities = torch.zeros(
+                                                     self.number_of_frames,
+                                                     self.number_of_channels,
+                                                     self.resolution[0] * self.scale_factor,
+                                                     self.resolution[1] * self.scale_factor,
+                                                     device = self.device
+                                                    )
+            loss_variation_hologram = 0
+            laser_powers = self.propagator.get_laser_powers()
+            for frame_id in range(self.number_of_frames):
+                if self.double_phase:
+                    phase, loss_phase = self.double_phase_constrain(
+                                                                    self.phase[frame_id],
+                                                                    self.offset[frame_id]
+                                                                   )
+                else:
+                    phase, loss_phase = self.direct_phase_constrain(
+                                                                    self.phase[frame_id],
+                                                                    self.offset[frame_id]
+                                                                   )
+                loss_variation_hologram += loss_phase
+                for channel_id in range(self.number_of_channels):
+                    phase_scaled = torch.zeros_like(self.amplitude)
+                    phase_scaled[::self.scale_factor, ::self.scale_factor] = phase
+                    laser_power = laser_powers[frame_id][channel_id]
+                    hologram = generate_complex_field(
+                                                      laser_power * self.amplitude,
+                                                      phase_scaled * self.phase_scale[channel_id]
+                                                     )
+                    reconstruction_field = self.propagator(hologram, channel_id, depth_id)
+                    intensity = calculate_amplitude(reconstruction_field) ** 2
+                    reconstruction_intensities[frame_id, channel_id] += intensity
+                hologram_phases[frame_id] = phase.detach().clone()
+            loss_laser = self.l2_loss(
+                                      torch.amax(depth_target, dim = (1, 2)) * self.peak_amplitude,
+                                      torch.sum(laser_powers, dim = 0)
+                                     )
+            loss_laser += self.l2_loss(
+                                       torch.tensor([self.number_of_frames * self.peak_amplitude]).to(self.device),
+                                       torch.sum(laser_powers).view(1,)
+                                      )
+            loss_laser += torch.cos(torch.min(torch.sum(laser_powers, dim = 1)))
+            reconstruction_intensity = torch.sum(reconstruction_intensities, dim=0)
+            loss_image = self.evaluate(
+                                       reconstruction_intensity,
+                                       depth_target * self.peak_amplitude,
+                                       plane_id = depth_id
+                                      )
+            loss = weights[0] * loss_image
+            loss += weights[1] * loss_laser
+            loss += weights[2] * loss_variation_hologram
+            include_pa_loss_flag = self.optimize_peak_amplitude and loss_image < self.img_loss_thres
+            if include_pa_loss_flag:
+                loss -= self.peak_amplitude * 1.
+            if self.method == 'conventional':
+                loss.backward()
+            else:
+                loss.backward(retain_graph = True)
+            self.optimizer.step()
+            if include_pa_loss_flag:
+                peak_amp_cache = self.peak_amplitude.item()
+            else:
+                with torch.no_grad():
+                    if self.optimize_peak_amplitude:
+                        self.peak_amplitude.view([1])[0] = peak_amp_cache
+            total_loss += loss.detach().item()
+            loss_image = loss_image.detach()
+            del loss_laser
+            del loss_variation_hologram
+            del loss
+        description = "Loss:{:.3f} Loss Image:{:.3f} Peak Amp:{:.1f} Learning rate:{:.4f}".format(total_loss, loss_image.item(), self.peak_amplitude, learning_rate)
+        t.set_description(description)
+        del total_loss
+        del loss_image
+        del reconstruction_field
+        del reconstruction_intensities
+        del intensity
+        del phase
+        del hologram
+    logging.warning(description)
+    return hologram_phases.detach()
+
+
+
+ +
+ +
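Inside the iteration loop, the learning rate is reduced by a fixed amount per iteration and clamped at learning_rate_floor, i.e. a linear decay. The same schedule in isolation (starting value, floor, and iteration count are illustrative assumptions):

learning_rate = 2e-2
learning_rate_floor = 1e-3
number_of_iterations = 100
step_size = (learning_rate - learning_rate_floor) / number_of_iterations
for step in range(number_of_iterations):
    learning_rate = max(learning_rate - step_size, learning_rate_floor)
print(learning_rate)  # ends at the floor value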
+ + +

+ init_amplitude() + +

+ + +
+ +

Internal function to set the amplitude of the illumination source.

+ +
+ Source code in odak/learn/wave/optimizers.py +
def init_amplitude(self):
+    """
+    Internal function to set the amplitude of the illumination source.
+    """
+    self.amplitude = torch.zeros(
+                                 self.resolution[0] * self.scale_factor,
+                                 self.resolution[1] * self.scale_factor,
+                                 requires_grad = False,
+                                 device = self.device
+                                )
+    self.amplitude[::self.scale_factor, ::self.scale_factor] = 1.
+
+
+
+ +
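The illumination amplitude is non-zero only at every scale_factor-th sample, so the physical SLM pixels sit on a coarse lattice embedded in the finer simulation grid. A tiny illustration with arbitrary sizes:

import torch

scale_factor = 2
amplitude = torch.zeros(4 * scale_factor, 4 * scale_factor)
amplitude[::scale_factor, ::scale_factor] = 1.
print(amplitude)  # ones on every second row and column, zeros elsewhere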
+ +
+ + +

+ init_channel_power() + +

+ + +
+ +

Internal function to set the laser channel powers.

+ +
+ Source code in odak/learn/wave/optimizers.py +
def init_channel_power(self):
+    """
+    Internal function to set the laser channel powers.
+    """
+    if self.method == 'conventional':
+        logging.warning('Scheme: Conventional')
+        self.channel_power = torch.eye(
+                                       self.number_of_frames,
+                                       self.number_of_channels,
+                                       device = self.device,
+                                       requires_grad = False
+                                      )
+
+    elif self.method == 'multi-color':
+        logging.warning('Scheme: Multi-color')
+        self.channel_power = torch.ones(
+                                        self.number_of_frames,
+                                        self.number_of_channels,
+                                        device = self.device,
+                                        requires_grad = True
+                                       )
+    if self.channel_power_filename != '':
+        self.channel_power = torch_load(self.channel_power_filename).to(self.device)
+        self.channel_power.requires_grad = False
+        self.channel_power[self.channel_power < 0.] = 0.
+        self.channel_power[self.channel_power > 1.] = 1.
+        if self.method == 'multi-color':
+            self.channel_power.requires_grad = True
+        if self.method == 'conventional':
+            self.channel_power = torch.abs(torch.cos(self.channel_power))
+        logging.warning('Channel powers:')
+        logging.warning(self.channel_power)
+        logging.warning('Channel powers loaded from {}.'.format(self.channel_power_filename))
+    self.propagator.set_laser_powers(self.channel_power)
+
+
+
+ +
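The two schemes differ only in how the frame-by-channel power matrix starts out: the conventional scheme assigns one color primary per frame (an identity matrix, kept fixed), while the multi-color scheme lets every frame drive every primary and keeps the matrix learnable. A quick illustration with three frames and three channels:

import torch

number_of_frames = 3
number_of_channels = 3
conventional = torch.eye(number_of_frames, number_of_channels)                          # one primary per frame
multi_color = torch.ones(number_of_frames, number_of_channels, requires_grad = True)    # all primaries, learnable
print(conventional)
print(multi_color)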
+ +
+ + +

+ init_loss_function(loss_function, reduction='sum') + +

+ + +
+ +

Internal function to set the loss function.

+ +
+ Source code in odak/learn/wave/optimizers.py +
def init_loss_function(self, loss_function, reduction = 'sum'):
+    """
+    Internal function to set the loss function.
+    """
+    self.l2_loss = torch.nn.MSELoss(reduction = reduction)
+    self.loss_type = 'custom'
+    self.loss_function = loss_function
+    if isinstance(self.loss_function, type(None)):
+        self.loss_type = 'conventional'
+        self.loss_function = torch.nn.MSELoss(reduction = reduction)
+
+
+
+ +
+ +
+ + +

+ init_optimizer() + +

+ + +
+ +

Internal function to set the optimizer.

+ +
+ Source code in odak/learn/wave/optimizers.py +
def init_optimizer(self):
+    """
+    Internal function to set the optimizer.
+    """
+    optimization_variables = [self.phase, self.offset]
+    if self.optimize_peak_amplitude:
+        optimization_variables.append(self.peak_amplitude)
+    if self.method == 'multi-color':
+        optimization_variables.append(self.propagator.channel_power)
+    self.optimizer = torch.optim.Adam(optimization_variables, lr=self.learning_rate)
+
+
+
+ +
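Note that init_optimizer() hands Adam a plain list of leaf tensors rather than the parameters of an nn.Module; any tensor created with requires_grad = True can be optimized this way. A minimal, self-contained illustration:

import torch

phase = torch.zeros(4, 4, requires_grad = True)
offset = torch.rand(4, 4, requires_grad = True)
optimizer = torch.optim.Adam([phase, offset], lr = 2e-2)

loss = ((phase + offset) ** 2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()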
+ +
+ + +

+ init_peak_amplitude_scale() + +

+ + +
+ +

Internal function to set the peak amplitude as an optimizable tensor.

+ +
+ Source code in odak/learn/wave/optimizers.py +
def init_peak_amplitude_scale(self):
+    """
+    Internal function to set the peak amplitude as an optimizable tensor.
+    """
+    self.peak_amplitude = torch.tensor(
+                                       self.peak_amplitude,
+                                       requires_grad = True,
+                                       device=self.device
+                                      )
+
+
+
+ +
+ +
+ + +

+ init_phase() + +

+ + +
+ +

Internal function to set the starting phase of the phase-only hologram.

+ +
+ Source code in odak/learn/wave/optimizers.py +
def init_phase(self):
+    """
+    Internal function to set the starting phase of the phase-only hologram.
+    """
+    self.phase = torch.zeros(
+                             self.number_of_frames,
+                             self.resolution[0],
+                             self.resolution[1],
+                             device = self.device,
+                             requires_grad = True
+                            )
+    self.offset = torch.rand_like(self.phase, requires_grad = True, device = self.device)
+
+
+
+ +
+ +
+ + +

+ init_phase_scale() + +

+ + +
+ +

Internal function to set the phase scale.

+ +
+ Source code in odak/learn/wave/optimizers.py +
def init_phase_scale(self):
+    """
+    Internal function to set the phase scale.
+    """
+    if self.method == 'conventional':
+        self.phase_scale = torch.tensor(
+                                        [
+                                         1.,
+                                         1.,
+                                         1.
+                                        ],
+                                        requires_grad = False,
+                                        device = self.device
+                                       )
+    if self.method == 'multi-color':
+        self.phase_scale = torch.tensor(
+                                        [
+                                         1.,
+                                         1.,
+                                         1.
+                                        ],
+                                        requires_grad = False,
+                                        device = self.device
+                                       )
+
+
+
+ +
+ +
+ + +

+ optimize(number_of_iterations=100, weights=[1.0, 1.0, 1.0], bits=8) + +

+ + +
+ +

Function to optimize multiplane phase-only holograms.

+ + +

Parameters:

  • number_of_iterations – Number of iterations.
  • weights – Loss weights.
  • bits – Number of bits used to quantize the hologram before reconstruction.
+ + +

Returns:

  • hologram_phases (tensor) – Phases of the optimized phase-only hologram.
  • reconstruction_intensities (tensor) – Intensities of the images reconstructed at each plane with the optimized phase-only hologram.
+ +
+ Source code in odak/learn/wave/optimizers.py +
def optimize(self, number_of_iterations=100, weights=[1., 1., 1.], bits = 8):
+    """
+    Function to optimize multiplane phase-only holograms.
+
+    Parameters
+    ----------
+    number_of_iterations       : int
+                                 Number of iterations.
+    weights                    : list
+                                 Loss weights.
+    bits                       : int
+                                 Number of bits used to quantize the hologram before reconstruction.
+
+    Returns
+    -------
+    hologram_phases            : torch.tensor
+                                 Phases of the optimized phase-only hologram.
+    reconstruction_intensities : torch.tensor
+                                 Intensities of the images reconstructed at each plane with the optimized phase-only hologram.
+    """
+    self.init_optimizer()
+    hologram_phases = self.gradient_descent(
+                                            number_of_iterations=number_of_iterations,
+                                            weights=weights
+                                           )
+    hologram_phases = quantize(hologram_phases % (2 * np.pi), bits = bits, limits = [0., 2 * np.pi]) / 2 ** bits * 2 * np.pi
+    with torch.no_grad():
+        reconstruction_intensities = self.propagator.reconstruct(hologram_phases)
+    laser_powers = self.propagator.get_laser_powers()
+    channel_powers = self.propagator.channel_power
+    logging.warning("Final peak amplitude: {}".format(self.peak_amplitude))
+    logging.warning('Laser powers: {}'.format(laser_powers))
+    return hologram_phases, reconstruction_intensities, laser_powers, channel_powers, float(self.peak_amplitude)
+
+
+
+ +
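Before the final reconstruction, optimize() wraps the phases into [0, 2*pi) and quantizes them to 2 ** bits levels, mimicking the bit depth of a real spatial light modulator. The sketch below reproduces that step under the assumption that odak's quantize maps the given limits onto integer levels; it is not odak's routine itself:

import math
import torch

def quantize_phase(phase, bits = 8):
    levels = 2 ** bits
    wrapped = phase % (2. * math.pi)                           # wrap into [0, 2 * pi)
    indices = torch.floor(wrapped / (2. * math.pi) * levels)   # map onto discrete levels
    return indices / levels * 2. * math.pi                     # back to radians

phase = torch.rand(4, 4) * 10.
print(quantize_phase(phase, bits = 8))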
+ + + +
+ +
+ +
+ + + + +
+ +
+ +
+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ propagator + + +

+ + +
+ + +

A light propagation model that propagates light to a desired image plane with two separate propagations. We use this class in our various works, including Kavaklı et al., Realistic Defocus Blur for Multiplane Computer-Generated Holography.

+ + + + + + +
+ Source code in odak/learn/wave/propagators.py +
class propagator():
+    """
+    A light propagation model that propagates light to a desired image plane with two separate propagations.
+    We use this class in our various works, including `Kavaklı et al., Realistic Defocus Blur for Multiplane Computer-Generated Holography`.
+    """
+    def __init__(
+                 self,
+                 resolution = [1920, 1080],
+                 wavelengths = [515e-9,],
+                 pixel_pitch = 8e-6,
+                 resolution_factor = 1,
+                 number_of_frames = 1,
+                 number_of_depth_layers = 1,
+                 volume_depth = 1e-2,
+                 image_location_offset = 5e-3,
+                 propagation_type = 'Bandlimited Angular Spectrum',
+                 propagator_type = 'back and forth',
+                 back_and_forth_distance = 0.3,
+                 laser_channel_power = None,
+                 aperture = None,
+                 aperture_size = None,
+                 distances = None,
+                 aperture_samples = [20, 20, 5, 5],
+                 method = 'conventional',
+                 device = torch.device('cpu')
+                ):
+        """
+        Parameters
+        ----------
+        resolution              : list
+                                  Resolution.
+        wavelengths             : float
+                                  Wavelength of light in meters.
+        pixel_pitch             : float
+                                  Pixel pitch in meters.
+        resolution_factor       : int
+                                  Resolution factor for scaled simulations.
+        number_of_frames        : int
+                                  Number of hologram frames.
+                                  Typically, there are three frames, each one for a single color primary.
+        number_of_depth_layers  : int
+                                  Equidistant number of depth layers within the desired volume. If the `distances` parameter is passed, this value will automatically be set to the length of the `distances` tensor provided.
+        volume_depth            : float
+                                  Width of the volume along the propagation direction.
+        image_location_offset   : float
+                                  Center of the volume along the propagation direction.
+        propagation_type        : str
+                                  Propagation type. 
+                                  See propagate_beam() and odak.learn.wave.get_propagation_kernel() for more.
+        propagator_type         : str
+                                  Propagator type.
+                                  The options are `back and forth` and `forward` propagators.
+        back_and_forth_distance : float
+                                  Zero mode distance for `back and forth` propagator type.
+        laser_channel_power     : torch.tensor
+                                  Laser channel powers for given number of frames and number of wavelengths.
+        aperture                : torch.tensor
+                                  Aperture at the Fourier plane.
+        aperture_size           : float
+                                  Aperture width for a circular aperture.
+        aperture_samples        : list
+                                  When using `Impulse Response Fresnel` propagation, these sample counts along X and Y will be used to represent a rectangular aperture. The first two are for a hologram plane pixel and the last two are for an image plane pixel.
+        distances               : torch.tensor
+                                  Propagation distances in meters.
+        method                  : str
+                                  Hologram type, either `conventional` or `multi-color`.
+        device                  : torch.device
+                                  Device to be used for computation. For more see torch.device().
+        """
+        self.device = device
+        self.pixel_pitch = pixel_pitch
+        self.wavelengths = wavelengths
+        self.resolution = resolution
+        self.propagation_type = propagation_type
+        if self.propagation_type != 'Impulse Response Fresnel':
+            resolution_factor = 1
+        self.resolution_factor = resolution_factor
+        self.number_of_frames = number_of_frames
+        self.number_of_depth_layers = number_of_depth_layers
+        self.number_of_channels = len(self.wavelengths)
+        self.volume_depth = volume_depth
+        self.image_location_offset = image_location_offset
+        self.propagator_type = propagator_type
+        self.aperture_samples = aperture_samples
+        self.zero_mode_distance = torch.tensor(back_and_forth_distance, device = device)
+        self.method = method
+        self.aperture = aperture
+        self.init_distances(distances)
+        self.init_kernels()
+        self.init_channel_power(laser_channel_power)
+        self.init_phase_scale()
+        self.set_aperture(aperture, aperture_size)
+
+
+    def init_distances(self, distances):
+        """
+        Internal function to initialize distances.
+
+        Parameters
+        ----------
+        distances               : torch.tensor
+                                  Propagation distances.
+        """
+        if isinstance(distances, type(None)):
+            self.distances = torch.linspace(-self.volume_depth / 2., self.volume_depth / 2., self.number_of_depth_layers) + self.image_location_offset
+        else:
+            self.distances = torch.as_tensor(distances)
+            self.number_of_depth_layers = self.distances.shape[0]
+        logging.warning('Distances: {}'.format(self.distances))
+
+
+    def init_kernels(self):
+        """
+        Internal function to initialize kernels.
+        """
+        self.generated_kernels = torch.zeros(
+                                             self.number_of_depth_layers,
+                                             self.number_of_channels,
+                                             device = self.device
+                                            )
+        self.kernels = torch.zeros(
+                                   self.number_of_depth_layers,
+                                   self.number_of_channels,
+                                   self.resolution[0] * self.resolution_factor * 2,
+                                   self.resolution[1] * self.resolution_factor * 2,
+                                   dtype = torch.complex64,
+                                   device = self.device
+                                  )
+
+
+    def init_channel_power(self, channel_power):
+        """
+        Internal function to set the laser channel powers.
+        """
+        self.channel_power = channel_power
+        if isinstance(self.channel_power, type(None)):
+            self.channel_power = torch.eye(
+                                           self.number_of_frames,
+                                           self.number_of_channels,
+                                           device = self.device,
+                                           requires_grad = False
+                                          )
+
+
+    def init_phase_scale(self):
+        """
+        Internal function to set the phase scale.
+        In some cases, you may want to modify this initialization to scale the phase differently for each color primary, as an SLM is typically calibrated for a specific central wavelength.
+        """
+        self.phase_scale = torch.tensor(
+                                        [
+                                         1.,
+                                         1.,
+                                         1.
+                                        ],
+                                        requires_grad = False,
+                                        device = self.device
+                                       )
+
+
+    def set_aperture(self, aperture = None, aperture_size = None):
+        """
+        Set aperture in the Fourier plane.
+
+
+        Parameters
+        ----------
+        aperture        : torch.tensor
+                          Aperture at the original resolution of a hologram.
+                          If aperture is provided as None, a circular aperture sized by the longer of the two grid dimensions (width or height) will be assigned.
+        aperture_size   : int
+                          If no aperture is provided, this will determine the size of the circular aperture.
+        """
+        if isinstance(aperture, type(None)):
+            if isinstance(aperture_size, type(None)):
+                aperture_size = torch.max(
+                                          torch.tensor([
+                                                        self.resolution[0] * self.resolution_factor, 
+                                                        self.resolution[1] * self.resolution_factor
+                                                       ])
+                                         )
+            self.aperture = circular_binary_mask(
+                                                 self.resolution[0] * self.resolution_factor * 2,
+                                                 self.resolution[1] * self.resolution_factor * 2,
+                                                 aperture_size,
+                                                ).to(self.device) * 1.
+        else:
+            self.aperture = zero_pad(aperture).to(self.device) * 1.
+
+
+    def get_laser_powers(self):
+        """
+        Internal function to get the laser powers.
+
+        Returns
+        -------
+        laser_power      : torch.tensor
+                           Laser powers.
+        """
+        if self.method == 'conventional':
+            laser_power = self.channel_power
+        if self.method == 'multi-color':
+            laser_power = torch.abs(torch.cos(self.channel_power))
+        return laser_power
+
+
+    def set_laser_powers(self, laser_power):
+        """
+        Internal function to set the laser powers.
+
+        Parameters
+        -------
+        laser_power      : torch.tensor
+                           Laser powers.
+        """
+        self.channel_power = laser_power
+
+
+
+    def get_kernels(self):
+        """
+        Function to return the kernels used in the light transport.
+
+        Returns
+        -------
+        kernels           : torch.tensor
+                            Kernel amplitudes.
+        """
+        h = torch.fft.ifftshift(torch.fft.ifft2(torch.fft.ifftshift(self.kernels)))
+        kernels_amplitude = calculate_amplitude(h)
+        kernels_phase = calculate_phase(h)
+        return kernels_amplitude, kernels_phase
+
+
+    def __call__(self, input_field, channel_id, depth_id):
+        """
+        Function that represents the forward model in hologram optimization.
+
+        Parameters
+        ----------
+        input_field         : torch.tensor
+                              Input complex field.
+        channel_id          : int
+                              Identifying the color primary to be used.
+        depth_id            : int
+                              Identifying the depth layer to be used.
+
+        Returns
+        -------
+        output_field        : torch.tensor
+                              Propagated output complex field.
+        """
+        distance = self.distances[depth_id]
+        if not self.generated_kernels[depth_id, channel_id]:
+            if self.propagator_type == 'forward':
+                H = get_propagation_kernel(
+                                           nu = self.resolution[0] * 2,
+                                           nv = self.resolution[1] * 2,
+                                           dx = self.pixel_pitch,
+                                           wavelength = self.wavelengths[channel_id],
+                                           distance = distance,
+                                           device = self.device,
+                                           propagation_type = self.propagation_type,
+                                           samples = self.aperture_samples,
+                                           scale = self.resolution_factor
+                                          )
+            elif self.propagator_type == 'back and forth':
+                H_forward = get_propagation_kernel(
+                                                   nu = self.resolution[0] * 2,
+                                                   nv = self.resolution[1] * 2,
+                                                   dx = self.pixel_pitch,
+                                                   wavelength = self.wavelengths[channel_id],
+                                                   distance = self.zero_mode_distance,
+                                                   device = self.device,
+                                                   propagation_type = self.propagation_type,
+                                                   samples = self.aperture_samples,
+                                                   scale = self.resolution_factor
+                                                  )
+                distance_back = -(self.zero_mode_distance + self.image_location_offset - distance)
+                H_back = get_propagation_kernel(
+                                                nu = self.resolution[0] * 2,
+                                                nv = self.resolution[1] * 2,
+                                                dx = self.pixel_pitch,
+                                                wavelength = self.wavelengths[channel_id],
+                                                distance = distance_back,
+                                                device = self.device,
+                                                propagation_type = self.propagation_type,
+                                                samples = self.aperture_samples,
+                                                scale = self.resolution_factor
+                                               )
+                H = H_forward * H_back
+            self.kernels[depth_id, channel_id] = H
+            self.generated_kernels[depth_id, channel_id] = True
+        else:
+            H = self.kernels[depth_id, channel_id].detach().clone()
+        field_scale = input_field
+        field_scale_padded = zero_pad(field_scale)
+        output_field_padded = custom(field_scale_padded, H, aperture = self.aperture)
+        output_field = crop_center(output_field_padded)
+        return output_field
+
+
+    def reconstruct(self, hologram_phases, amplitude = None, no_grad = True, get_complex = False):
+        """
+        Internal function to reconstruct a given hologram.
+
+
+        Parameters
+        ----------
+        hologram_phases            : torch.tensor
+                                     Hologram phases [ch x m x n].
+        amplitude                  : torch.tensor
+                                     Amplitude profiles for each color primary [ch x m x n]
+        no_grad                    : bool
+                                     If set True, uses torch.no_grad in reconstruction.
+        get_complex                : bool
+                                     If set True, reconstructor returns the complex field but not the intensities.
+
+        Returns
+        -------
+        reconstructions            : torch.tensor
+                                     Reconstructed frames.
+        """
+        if no_grad:
+            # A bare torch.no_grad() call has no effect; enter the context manager and re-run without the flag.
+            with torch.no_grad():
+                return self.reconstruct(hologram_phases, amplitude = amplitude, no_grad = False, get_complex = get_complex)
+        if len(hologram_phases.shape) > 3:
+            hologram_phases = hologram_phases.squeeze(0)
+        if get_complex == True:
+            reconstruction_type = torch.complex64
+        else:
+            reconstruction_type = torch.float32
+        reconstructions = torch.zeros(
+                                      self.number_of_frames,
+                                      self.number_of_depth_layers,
+                                      self.number_of_channels,
+                                      self.resolution[0] * self.resolution_factor,
+                                      self.resolution[1] * self.resolution_factor,
+                                      dtype = reconstruction_type,
+                                      device = self.device
+                                     )
+        if isinstance(amplitude, type(None)):
+            amplitude = torch.zeros(
+                                    self.number_of_channels,
+                                    self.resolution[0] * self.resolution_factor,
+                                    self.resolution[1] * self.resolution_factor,
+                                    device = self.device
+                                   )
+            amplitude[:, ::self.resolution_factor, ::self.resolution_factor] = 1.
+        if self.resolution_factor != 1:
+            hologram_phases_scaled = torch.zeros_like(amplitude)
+            hologram_phases_scaled[
+                                   :,
+                                   ::self.resolution_factor,
+                                   ::self.resolution_factor
+                                  ] = hologram_phases
+        else:
+            hologram_phases_scaled = hologram_phases
+        for frame_id in range(self.number_of_frames):
+            for depth_id in range(self.number_of_depth_layers):
+                for channel_id in range(self.number_of_channels):
+                    laser_power = self.get_laser_powers()[frame_id][channel_id]
+                    phase = hologram_phases_scaled[frame_id]
+                    hologram = generate_complex_field(
+                                                      laser_power * amplitude[channel_id],
+                                                      phase * self.phase_scale[channel_id]
+                                                     )
+                    reconstruction_field = self.__call__(hologram, channel_id, depth_id)
+                    if get_complex == True:
+                        result = reconstruction_field
+                    else:
+                        result = calculate_amplitude(reconstruction_field) ** 2
+                    reconstructions[
+                                    frame_id,
+                                    depth_id,
+                                    channel_id
+                                   ] = result.detach().clone()
+        return reconstructions
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ __call__(input_field, channel_id, depth_id) + +

+ + +
+ +

Function that represents the forward model in hologram optimization.

+ + +

Parameters:

  • input_field – Input complex field.
  • channel_id – Identifying the color primary to be used.
  • depth_id – Identifying the depth layer to be used.
+ + +

Returns:

  • output_field (tensor) – Propagated output complex field.
+ +
+ Source code in odak/learn/wave/propagators.py +
def __call__(self, input_field, channel_id, depth_id):
+    """
+    Function that represents the forward model in hologram optimization.
+
+    Parameters
+    ----------
+    input_field         : torch.tensor
+                          Input complex field.
+    channel_id          : int
+                          Identifying the color primary to be used.
+    depth_id            : int
+                          Identifying the depth layer to be used.
+
+    Returns
+    -------
+    output_field        : torch.tensor
+                          Propagated output complex field.
+    """
+    distance = self.distances[depth_id]
+    if not self.generated_kernels[depth_id, channel_id]:
+        if self.propagator_type == 'forward':
+            H = get_propagation_kernel(
+                                       nu = self.resolution[0] * 2,
+                                       nv = self.resolution[1] * 2,
+                                       dx = self.pixel_pitch,
+                                       wavelength = self.wavelengths[channel_id],
+                                       distance = distance,
+                                       device = self.device,
+                                       propagation_type = self.propagation_type,
+                                       samples = self.aperture_samples,
+                                       scale = self.resolution_factor
+                                      )
+        elif self.propagator_type == 'back and forth':
+            H_forward = get_propagation_kernel(
+                                               nu = self.resolution[0] * 2,
+                                               nv = self.resolution[1] * 2,
+                                               dx = self.pixel_pitch,
+                                               wavelength = self.wavelengths[channel_id],
+                                               distance = self.zero_mode_distance,
+                                               device = self.device,
+                                               propagation_type = self.propagation_type,
+                                               samples = self.aperture_samples,
+                                               scale = self.resolution_factor
+                                              )
+            distance_back = -(self.zero_mode_distance + self.image_location_offset - distance)
+            H_back = get_propagation_kernel(
+                                            nu = self.resolution[0] * 2,
+                                            nv = self.resolution[1] * 2,
+                                            dx = self.pixel_pitch,
+                                            wavelength = self.wavelengths[channel_id],
+                                            distance = distance_back,
+                                            device = self.device,
+                                            propagation_type = self.propagation_type,
+                                            samples = self.aperture_samples,
+                                            scale = self.resolution_factor
+                                           )
+            H = H_forward * H_back
+        self.kernels[depth_id, channel_id] = H
+        self.generated_kernels[depth_id, channel_id] = True
+    else:
+        H = self.kernels[depth_id, channel_id].detach().clone()
+    field_scale = input_field
+    field_scale_padded = zero_pad(field_scale)
+    output_field_padded = custom(field_scale_padded, H, aperture = self.aperture)
+    output_field = crop_center(output_field_padded)
+    return output_field
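A minimal usage sketch for this forward model. It assumes the class documented on this page is exposed as propagator under odak.learn.wave (the class name and import path are assumptions and may differ across odak versions); generate_complex_field and calculate_amplitude are the utility functions documented further below.

import torch
from odak.learn.wave import propagator, generate_complex_field, calculate_amplitude  # assumed exports

# Single color primary, single depth layer, forward-only propagation.
forward_model = propagator(
                           resolution = [512, 512],
                           wavelengths = [515e-9],
                           pixel_pitch = 8e-6,
                           propagator_type = 'forward',
                           device = torch.device('cpu')
                          )
phase = torch.rand(512, 512) * 2. * torch.pi                      # random phase-only hologram
hologram = generate_complex_field(torch.ones_like(phase), phase)  # unit amplitude, random phase
output_field = forward_model(hologram, channel_id = 0, depth_id = 0)
intensity = calculate_amplitude(output_field) ** 2                # image plane intensity, 512 x 512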
+
+
+
+ +
+ +
+ + +

+ __init__(resolution=[1920, 1080], wavelengths=[5.15e-07], pixel_pitch=8e-06, resolution_factor=1, number_of_frames=1, number_of_depth_layers=1, volume_depth=0.01, image_location_offset=0.005, propagation_type='Bandlimited Angular Spectrum', propagator_type='back and forth', back_and_forth_distance=0.3, laser_channel_power=None, aperture=None, aperture_size=None, distances=None, aperture_samples=[20, 20, 5, 5], method='conventional', device=torch.device('cpu')) + +

+ + +
+ + + +

Parameters:

+
    +
  • + resolution + – +
    +
                      Resolution.
    +
    +
    +
  • +
  • + wavelengths + – +
    +
                      Wavelength of light in meters.
    +
    +
    +
  • +
  • + pixel_pitch + – +
    +
                      Pixel pitch in meters.
    +
    +
    +
  • +
  • + resolution_factor + – +
    +
                      Resolution factor for scaled simulations.
    +
    +
    +
  • +
  • + number_of_frames + – +
    +
                      Number of hologram frames.
    +                  Typically, there are three frames, each one for a single color primary.
    +
    +
    +
  • +
  • + number_of_depth_layers + – +
    +
                      Equidistant number of depth layers within the desired volume. If the `distances` parameter is passed, this value will automatically be set to the length of the `distances` provided.
    +
    +
    +
  • +
  • + volume_depth + – +
    +
                      Width of the volume along the propagation direction.
    +
    +
    +
  • +
  • + image_location_offset + – +
    +
                      Center of the volume along the propagation direction.
    +
    +
    +
  • +
  • + propagation_type + – +
    +
                      Propagation type. 
+                  See propagate_beam() and odak.learn.wave.get_propagation_kernel() for more.
    +
    +
    +
  • +
  • + propagator_type + – +
    +
                      Propagator type.
    +                  The options are `back and forth` and `forward` propagators.
    +
    +
    +
  • +
  • + back_and_forth_distance + (float, default: + 0.3 +) + – +
    +
                      Zero mode distance for `back and forth` propagator type.
    +
    +
    +
  • +
  • + laser_channel_power + – +
    +
                      Laser channel powers for given number of frames and number of wavelengths.
    +
    +
    +
  • +
  • + aperture + – +
    +
                      Aperture at the Fourier plane.
    +
    +
    +
  • +
  • + aperture_size + – +
    +
                      Aperture width for a circular aperture.
    +
    +
    +
  • +
  • + aperture_samples + – +
    +
                      When using `Impulse Response Fresnel` propagation, these sample counts along X and Y are used to represent a rectangular aperture. The first two are for a hologram plane pixel and the last two are for an image plane pixel.
    +
    +
    +
  • +
  • + distances + – +
    +
                      Propagation distances in meters.
    +
    +
    +
  • +
  • + method + – +
    +
                      Hologram type conventional or multi-color.
    +
    +
    +
  • +
  • + device + – +
    +
                      Device to be used for computation. For more see torch.device().
    +
    +
    +
  • +
+ +
+ Source code in odak/learn/wave/propagators.py +
def __init__(
+             self,
+             resolution = [1920, 1080],
+             wavelengths = [515e-9,],
+             pixel_pitch = 8e-6,
+             resolution_factor = 1,
+             number_of_frames = 1,
+             number_of_depth_layers = 1,
+             volume_depth = 1e-2,
+             image_location_offset = 5e-3,
+             propagation_type = 'Bandlimited Angular Spectrum',
+             propagator_type = 'back and forth',
+             back_and_forth_distance = 0.3,
+             laser_channel_power = None,
+             aperture = None,
+             aperture_size = None,
+             distances = None,
+             aperture_samples = [20, 20, 5, 5],
+             method = 'conventional',
+             device = torch.device('cpu')
+            ):
+    """
+    Parameters
+    ----------
+    resolution              : list
+                              Resolution.
+    wavelengths             : list
+                              Wavelengths of light in meters.
+    pixel_pitch             : float
+                              Pixel pitch in meters.
+    resolution_factor       : int
+                              Resolution factor for scaled simulations.
+    number_of_frames        : int
+                              Number of hologram frames.
+                              Typically, there are three frames, each one for a single color primary.
+    number_of_depth_layers  : int
+                              Equidistant number of depth layers within the desired volume. If the `distances` parameter is passed, this value will automatically be set to the length of the `distances` provided.
+    volume_depth            : float
+                              Width of the volume along the propagation direction.
+    image_location_offset   : float
+                              Center of the volume along the propagation direction.
+    propagation_type        : str
+                              Propagation type. 
+                              See propagate_beam() and odak.learn.wave.get_propagation_kernel() for more.
+    propagator_type         : str
+                              Propagator type.
+                              The options are `back and forth` and `forward` propagators.
+    back_and_forth_distance : float
+                              Zero mode distance for `back and forth` propagator type.
+    laser_channel_power     : torch.tensor
+                              Laser channel powers for given number of frames and number of wavelengths.
+    aperture                : torch.tensor
+                              Aperture at the Fourier plane.
+    aperture_size           : float
+                              Aperture width for a circular aperture.
+    aperture_samples        : list
+                              When using `Impulse Response Fresnel` propagation, these sample counts along X and Y are used to represent a rectangular aperture. The first two are for a hologram plane pixel and the last two are for an image plane pixel.
+    distances               : torch.tensor
+                              Propagation distances in meters.
+    method                  : str
+                              Hologram type conventional or multi-color.
+    device                  : torch.device
+                              Device to be used for computation. For more see torch.device().
+    """
+    self.device = device
+    self.pixel_pitch = pixel_pitch
+    self.wavelengths = wavelengths
+    self.resolution = resolution
+    self.propagation_type = propagation_type
+    if self.propagation_type != 'Impulse Response Fresnel':
+        resolution_factor = 1
+    self.resolution_factor = resolution_factor
+    self.number_of_frames = number_of_frames
+    self.number_of_depth_layers = number_of_depth_layers
+    self.number_of_channels = len(self.wavelengths)
+    self.volume_depth = volume_depth
+    self.image_location_offset = image_location_offset
+    self.propagator_type = propagator_type
+    self.aperture_samples = aperture_samples
+    self.zero_mode_distance = torch.tensor(back_and_forth_distance, device = device)
+    self.method = method
+    self.aperture = aperture
+    self.init_distances(distances)
+    self.init_kernels()
+    self.init_channel_power(laser_channel_power)
+    self.init_phase_scale()
+    self.set_aperture(aperture, aperture_size)
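A hedged construction sketch, again assuming the class is exported as propagator from odak.learn.wave. It shows a three-frame, three-primary configuration where explicit distances override number_of_depth_layers, as noted in the parameter descriptions above.

import torch
from odak.learn.wave import propagator  # assumed export of the class documented here

holographic_display = propagator(
                                 resolution = [1920, 1080],
                                 wavelengths = [639e-9, 515e-9, 473e-9],
                                 pixel_pitch = 8e-6,
                                 number_of_frames = 3,
                                 distances = [4e-3, 5e-3, 6e-3],  # overrides number_of_depth_layers
                                 propagation_type = 'Bandlimited Angular Spectrum',
                                 propagator_type = 'back and forth',
                                 device = torch.device('cpu')
                                )
print(holographic_display.number_of_depth_layers)                 # 3, taken from len(distances)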
+
+
+
+ +
+ +
+ + +

+ get_kernels() + +

+ + +
+ +

Function to return the kernels used in the light transport.

+ + +

Returns:

+
    +
  • +kernels ( tensor +) – +
    +

    Kernel amplitudes and phases.

    +
    +
  • +
+ +
+ Source code in odak/learn/wave/propagators.py +
def get_kernels(self):
+    """
+    Function to return the kernels used in the light transport.
+
+    Returns
+    -------
+    kernels_amplitude : torch.tensor
+                        Kernel amplitudes.
+    kernels_phase     : torch.tensor
+                        Kernel phases.
+    """
+    h = torch.fft.ifftshift(torch.fft.ifft2(torch.fft.ifftshift(self.kernels)))
+    kernels_amplitude = calculate_amplitude(h)
+    kernels_phase = calculate_phase(h)
+    return kernels_amplitude, kernels_phase
+
+
+
+ +
+ +
+ + +

+ get_laser_powers() + +

+ + +
+ +

Internal function to get the laser powers.

+ + +

Returns:

+
    +
  • +laser_power ( tensor +) – +
    +

    Laser powers.

    +
    +
  • +
+ +
+ Source code in odak/learn/wave/propagators.py +
def get_laser_powers(self):
+    """
+    Internal function to get the laser powers.
+
+    Returns
+    -------
+    laser_power      : torch.tensor
+                       Laser powers.
+    """
+    if self.method == 'conventional':
+        laser_power = self.channel_power
+    if self.method == 'multi-color':
+        laser_power = torch.abs(torch.cos(self.channel_power))
+    return laser_power
+
+
+
+ +
+ +
+ + +

+ init_channel_power(channel_power) + +

+ + +
+ +

Internal function to set the laser channel powers.

+ +
+ Source code in odak/learn/wave/propagators.py +
def init_channel_power(self, channel_power):
+    """
+    Internal function to set the laser channel powers.
+    """
+    self.channel_power = channel_power
+    if isinstance(self.channel_power, type(None)):
+        self.channel_power = torch.eye(
+                                       self.number_of_frames,
+                                       self.number_of_channels,
+                                       device = self.device,
+                                       requires_grad = False
+                                      )
+
+
+
+ +
+ +
+ + +

+ init_distances(distances) + +

+ + +
+ +

Internal function to initialize distances.

+ + +

Parameters:

+
    +
  • + distances + – +
    +
                      Propagation distances.
    +
    +
    +
  • +
+ +
+ Source code in odak/learn/wave/propagators.py +
def init_distances(self, distances):
+    """
+    Internal function to initialize distances.
+
+    Parameters
+    ----------
+    distances               : torch.tensor
+                              Propagation distances.
+    """
+    if isinstance(distances, type(None)):
+        self.distances = torch.linspace(-self.volume_depth / 2., self.volume_depth / 2., self.number_of_depth_layers) + self.image_location_offset
+    else:
+        self.distances = torch.as_tensor(distances)
+        self.number_of_depth_layers = self.distances.shape[0]
+    logging.warning('Distances: {}'.format(self.distances))
+
+
+
+ +
+ +
+ + +

+ init_kernels() + +

+ + +
+ +

Internal function to initialize kernels.

+ +
+ Source code in odak/learn/wave/propagators.py +
def init_kernels(self):
+    """
+    Internal function to initialize kernels.
+    """
+    self.generated_kernels = torch.zeros(
+                                         self.number_of_depth_layers,
+                                         self.number_of_channels,
+                                         device = self.device
+                                        )
+    self.kernels = torch.zeros(
+                               self.number_of_depth_layers,
+                               self.number_of_channels,
+                               self.resolution[0] * self.resolution_factor * 2,
+                               self.resolution[1] * self.resolution_factor * 2,
+                               dtype = torch.complex64,
+                               device = self.device
+                              )
+
+
+
+ +
+ +
+ + +

+ init_phase_scale() + +

+ + +
+ +

Internal function to set the phase scale. +In some cases, you may want to modify this initialization to scale phases differently for each color primary, as an SLM is calibrated for a specific central wavelength.

+ +
+ Source code in odak/learn/wave/propagators.py +
def init_phase_scale(self):
+    """
+    Internal function to set the phase scale.
+    In some cases, you may want to modify this initialization to scale phases differently for each color primary, as an SLM is calibrated for a specific central wavelength.
+    """
+    self.phase_scale = torch.tensor(
+                                    [
+                                     1.,
+                                     1.,
+                                     1.
+                                    ],
+                                    requires_grad = False,
+                                    device = self.device
+                                   )
+
+
+
+ +
+ +
+ + +

+ reconstruct(hologram_phases, amplitude=None, no_grad=True, get_complex=False) + +

+ + +
+ +

Internal function to reconstruct a given hologram.

+ + +

Parameters:

+
    +
  • + hologram_phases + – +
    +
                         Hologram phases [ch x m x n].
    +
    +
    +
  • +
  • + amplitude + – +
    +
                         Amplitude profiles for each color primary [ch x m x n]
    +
    +
    +
  • +
  • + no_grad + – +
    +
                         If set True, uses torch.no_grad in reconstruction.
    +
    +
    +
  • +
  • + get_complex + – +
    +
                         If set True, the reconstructor returns the complex field instead of the intensities.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +reconstructions ( tensor +) – +
    +

    Reconstructed frames.

    +
    +
  • +
+ +
+ Source code in odak/learn/wave/propagators.py +
def reconstruct(self, hologram_phases, amplitude = None, no_grad = True, get_complex = False):
+    """
+    Internal function to reconstruct a given hologram.
+
+
+    Parameters
+    ----------
+    hologram_phases            : torch.tensor
+                                 Hologram phases [ch x m x n].
+    amplitude                  : torch.tensor
+                                 Amplitude profiles for each color primary [ch x m x n]
+    no_grad                    : bool
+                                 If set True, uses torch.no_grad in reconstruction.
+    get_complex                : bool
+                                 If set True, the reconstructor returns the complex field instead of the intensities.
+
+    Returns
+    -------
+    reconstructions            : torch.tensor
+                                 Reconstructed frames.
+    """
+    if no_grad:
+        torch.no_grad()
+    if len(hologram_phases.shape) > 3:
+        hologram_phases = hologram_phases.squeeze(0)
+    if get_complex == True:
+        reconstruction_type = torch.complex64
+    else:
+        reconstruction_type = torch.float32
+    reconstructions = torch.zeros(
+                                  self.number_of_frames,
+                                  self.number_of_depth_layers,
+                                  self.number_of_channels,
+                                  self.resolution[0] * self.resolution_factor,
+                                  self.resolution[1] * self.resolution_factor,
+                                  dtype = reconstruction_type,
+                                  device = self.device
+                                 )
+    if isinstance(amplitude, type(None)):
+        amplitude = torch.zeros(
+                                self.number_of_channels,
+                                self.resolution[0] * self.resolution_factor,
+                                self.resolution[1] * self.resolution_factor,
+                                device = self.device
+                               )
+        amplitude[:, ::self.resolution_factor, ::self.resolution_factor] = 1.
+    if self.resolution_factor != 1:
+        hologram_phases_scaled = torch.zeros_like(amplitude)
+        hologram_phases_scaled[
+                               :,
+                               ::self.resolution_factor,
+                               ::self.resolution_factor
+                              ] = hologram_phases
+    else:
+        hologram_phases_scaled = hologram_phases
+    for frame_id in range(self.number_of_frames):
+        for depth_id in range(self.number_of_depth_layers):
+            for channel_id in range(self.number_of_channels):
+                laser_power = self.get_laser_powers()[frame_id][channel_id]
+                phase = hologram_phases_scaled[frame_id]
+                hologram = generate_complex_field(
+                                                  laser_power * amplitude[channel_id],
+                                                  phase * self.phase_scale[channel_id]
+                                                 )
+                reconstruction_field = self.__call__(hologram, channel_id, depth_id)
+                if get_complex == True:
+                    result = reconstruction_field
+                else:
+                    result = calculate_amplitude(reconstruction_field) ** 2
+                reconstructions[
+                                frame_id,
+                                depth_id,
+                                channel_id
+                               ] = result.detach().clone()
+    return reconstructions
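A sketch of reconstructing a stack of phase-only holograms with this method; the propagator construction assumes the same export as in the earlier examples, and the random phases stand in for optimized hologram phases.

import torch
from odak.learn.wave import propagator  # assumed export; see the __init__ example above

forward_model = propagator(
                           resolution = [512, 512],
                           wavelengths = [639e-9, 515e-9, 473e-9],
                           pixel_pitch = 8e-6,
                           number_of_frames = 3,
                           number_of_depth_layers = 2,
                           device = torch.device('cpu')
                          )
hologram_phases = torch.rand(3, 512, 512) * 2. * torch.pi   # [frames x m x n]
reconstructions = forward_model.reconstruct(hologram_phases)
print(reconstructions.shape)  # [frames x depth layers x channels x m x n] -> [3, 2, 3, 512, 512]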
+
+
+
+ +
+ +
+ + +

+ set_aperture(aperture=None, aperture_size=None) + +

+ + +
+ +

Set aperture in the Fourier plane.

+ + +

Parameters:

+
    +
  • + aperture + – +
    +
              Aperture at the original resolution of a hologram.
              If aperture is provided as None, it will assign a circular aperture sized by the longer edge (width or height).
    +
    +
    +
  • +
  • + aperture_size + – +
    +
              If no aperture is provided, this will determine the size of the circular aperture.
    +
    +
    +
  • +
+ +
+ Source code in odak/learn/wave/propagators.py +
def set_aperture(self, aperture = None, aperture_size = None):
+    """
+    Set aperture in the Fourier plane.
+
+
+    Parameters
+    ----------
+    aperture        : torch.tensor
+                      Aperture at the original resolution of a hologram.
+                      If aperture is provided as None, it will assign a circular aperture sized by the longer edge (width or height).
+    aperture_size   : int
+                      If no aperture is provided, this will determine the size of the circular aperture.
+    """
+    if isinstance(aperture, type(None)):
+        if isinstance(aperture_size, type(None)):
+            aperture_size = torch.max(
+                                      torch.tensor([
+                                                    self.resolution[0] * self.resolution_factor, 
+                                                    self.resolution[1] * self.resolution_factor
+                                                   ])
+                                     )
+        self.aperture = circular_binary_mask(
+                                             self.resolution[0] * self.resolution_factor * 2,
+                                             self.resolution[1] * self.resolution_factor * 2,
+                                             aperture_size,
+                                            ).to(self.device) * 1.
+    else:
+        self.aperture = zero_pad(aperture).to(self.device) * 1.
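A short sketch of replacing the default full-size aperture with a tighter circular one; the propagator export is the same assumption as above, and treating aperture_size as a pixel count on the padded Fourier plane is inferred from the default behaviour shown in the source.

import torch
from odak.learn.wave import propagator  # assumed export; see the earlier examples

forward_model = propagator(resolution = [512, 512], device = torch.device('cpu'))
# Replace the default aperture with a smaller circular one (aperture_size in pixels, assumption).
forward_model.set_aperture(aperture_size = 128)
print(forward_model.aperture.shape)  # the aperture lives on the zero-padded (doubled) hologram plane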
+
+
+
+ +
+ +
+ + +

+ set_laser_powers(laser_power) + +

+ + +
+ +

Internal function to set the laser powers.

+ + +

Parameters:

+
    +
  • + laser_power + – +
    +
               Laser powers.
    +
    +
    +
  • +
+ +
+ Source code in odak/learn/wave/propagators.py +
def set_laser_powers(self, laser_power):
+    """
+    Internal function to set the laser powers.
+
+    Parameters
+    -------
+    laser_power      : torch.tensor
+                       Laser powers.
+    """
+    self.channel_power = laser_power
+
+
+
+ +

+ calculate_amplitude(field) + +

+ + +
+ +

Definition to calculate amplitude of a single or multiple given electric field(s).

+ + +

Parameters:

+
    +
  • + field + – +
    +
           Electric fields or an electric field.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +amplitude ( float +) – +
    +

    Amplitude or amplitudes of electric field(s).

    +
    +
  • +
+ +
+ Source code in odak/learn/wave/util.py +
def calculate_amplitude(field):
+    """ 
+    Definition to calculate amplitude of a single or multiple given electric field(s).
+
+    Parameters
+    ----------
+    field        : torch.cfloat
+                   Electric fields or an electric field.
+
+    Returns
+    -------
+    amplitude    : torch.float
+                   Amplitude or amplitudes of electric field(s).
+    """
+    amplitude = torch.abs(field)
+    return amplitude
+
+
+
+ +
+ +
+ + +

+ calculate_phase(field, deg=False) + +

+ + +
+ +

Definition to calculate phase of a single or multiple given electric field(s).

+ + +

Parameters:

+
    +
  • + field + – +
    +
           Electric fields or an electric field.
    +
    +
    +
  • +
  • + deg + – +
    +
           If set True, the angles will be returned in degrees.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +phase ( float +) – +
    +

    Phase or phases of electric field(s) in radians.

    +
    +
  • +
+ +
+ Source code in odak/learn/wave/util.py +
def calculate_phase(field, deg = False):
+    """ 
+    Definition to calculate phase of a single or multiple given electric field(s).
+
+    Parameters
+    ----------
+    field        : torch.cfloat
+                   Electric fields or an electric field.
+    deg          : bool
+                   If set True, the angles will be returned in degrees.
+
+    Returns
+    -------
+    phase        : torch.float
+                   Phase or phases of electric field(s) in radians.
+    """
+    phase = field.imag.atan2(field.real)
+    if deg:
+        phase *= 180. / np.pi
+    return phase
+
+
+
+ +
+ +
+ + +

+ generate_complex_field(amplitude, phase) + +

+ + +
+ +

Definition to generate a complex field with a given amplitude and phase.

+ + +

Parameters:

+
    +
  • + amplitude + – +
    +
                Amplitude of the field.
    +            The expected size is [m x n] or [1 x m x n].
    +
    +
    +
  • +
  • + phase + – +
    +
                Phase of the field.
    +            The expected size is [m x n] or [1 x m x n].
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +field ( ndarray +) – +
    +

    Complex field. +Depending on the input, the expected size is [m x n] or [1 x m x n].

    +
    +
  • +
+ +
+ Source code in odak/learn/wave/util.py +
def generate_complex_field(amplitude, phase):
+    """
+    Definition to generate a complex field with a given amplitude and phase.
+
+    Parameters
+    ----------
+    amplitude         : torch.tensor
+                        Amplitude of the field.
+                        The expected size is [m x n] or [1 x m x n].
+    phase             : torch.tensor
+                        Phase of the field.
+                        The expected size is [m x n] or [1 x m x n].
+
+    Returns
+    -------
+    field             : torch.tensor
+                        Complex field.
+                        Depending on the input, the expected size is [m x n] or [1 x m x n].
+    """
+    field = amplitude * torch.cos(phase) + 1j * amplitude * torch.sin(phase)
+    return field
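A small round trip tying this together with calculate_amplitude and calculate_phase above, assuming the usual odak.learn.wave exports of these utility functions.

import torch
from odak.learn.wave import generate_complex_field, calculate_amplitude, calculate_phase

amplitude = torch.ones(4, 4) * 0.5
phase = torch.ones(4, 4) * 0.25 * torch.pi
field = generate_complex_field(amplitude, phase)                  # 0.5 * exp(1j * pi / 4)
print(torch.allclose(calculate_amplitude(field), amplitude))      # True
print(torch.allclose(calculate_phase(field), phase))              # True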
+
+
+
+ +
+ +
+ + +

+ set_amplitude(field, amplitude) + +

+ + +
+ +

Definition to keep phase as is and change the amplitude of a given field.

+ + +

Parameters:

+
    +
  • + field + – +
    +
           Complex field.
    +
    +
    +
  • +
  • + amplitude + – +
    +
           Amplitudes.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +new_field ( cfloat +) – +
    +

    Complex field.

    +
    +
  • +
+ +
+ Source code in odak/learn/wave/util.py +
def set_amplitude(field, amplitude):
+    """
+    Definition to keep phase as is and change the amplitude of a given field.
+
+    Parameters
+    ----------
+    field        : torch.cfloat
+                   Complex field.
+    amplitude    : torch.cfloat or torch.float
+                   Amplitudes.
+
+    Returns
+    -------
+    new_field    : torch.cfloat
+                   Complex field.
+    """
+    amplitude = calculate_amplitude(amplitude)
+    phase = calculate_phase(field)
+    new_field = amplitude * torch.cos(phase) + 1j * amplitude * torch.sin(phase)
+    return new_field
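A brief sketch of swapping in a new amplitude while keeping the phase, under the same assumption about odak.learn.wave exports.

import torch
from odak.learn.wave import generate_complex_field, set_amplitude, calculate_amplitude

field = generate_complex_field(torch.ones(8, 8), torch.rand(8, 8) * 2. * torch.pi)
new_amplitude = torch.full((8, 8), 2.)
new_field = set_amplitude(field, new_amplitude)                   # phase preserved, amplitude replaced
print(torch.allclose(calculate_amplitude(new_field), new_amplitude))  # True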
+
+
+
+ +
+ +
+ + +

+ wavenumber(wavelength) + +

+ + +
+ +

Definition for calculating the wavenumber of a plane wave.

+ + +

Parameters:

+
    +
  • + wavelength + – +
    +
           Wavelength of a wave in mm.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +k ( float +) – +
    +

    Wave number for a given wavelength.

    +
    +
  • +
+ +
+ Source code in odak/learn/wave/util.py +
def wavenumber(wavelength):
+    """
+    Definition for calculating the wavenumber of a plane wave.
+
+    Parameters
+    ----------
+    wavelength   : float
+                   Wavelength of a wave in mm.
+
+    Returns
+    -------
+    k            : float
+                   Wave number for a given wavelength.
+    """
+    k = 2 * np.pi / wavelength
+    return k
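The function simply evaluates k = 2 * pi / wavelength, so the wavenumber comes out in radians per whatever length unit the wavelength is given in. A quick check, assuming the usual odak.learn.wave export:

from odak.learn.wave import wavenumber

k = wavenumber(515e-9)        # wavelength given in meters here
print(k)                      # ~1.22e7 rad/m, i.e. 2 * pi / 515e-9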
+
+
+
+ +
\ No newline at end of file
diff --git a/odak/raytracing/index.html b/odak/raytracing/index.html
new file mode 100644
index 00000000..6153e2f5
--- /dev/null
+++ b/odak/raytracing/index.html
@@ -0,0 +1,9926 @@
+ odak.raytracing - Odak

odak.raytracing

+ +
+ + + + +
+ +

odak.raytracing

+

Provides necessary definitions for geometric optics. See "General Ray Tracing Procedure" by G. H. Spencer and M. V. R. K. Murty for the theoretical explanation.

+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ bring_plane_to_origin(point, plane, shape=[10.0, 10.0], center=[0.0, 0.0, 0.0], angles=[0.0, 0.0, 0.0], mode='XYZ') + +

+ + +
+ +

Definition to bring points back to reference origin with respect to a plane.

+ + +

Parameters:

+
    +
  • + point + – +
    +
                 Point(s) to be tested.
    +
    +
    +
  • +
  • + shape + – +
    +
                 Dimensions of the rectangle along X and Y axes.
    +
    +
    +
  • +
  • + center + – +
    +
                 Center of the rectangle.
    +
    +
    +
  • +
  • + angles + – +
    +
                 Rotation angle of the rectangle.
    +
    +
    +
  • +
  • + mode + – +
    +
                 Rotation mode of the rectangle, for more see odak.tools.rotate_point and odak.tools.rotate_points.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +transformed_points ( ndarray +) – +
    +

    Point(s) that are brought back to reference origin with respect to given plane.

    +
    +
  • +
+ +
+ Source code in odak/raytracing/primitives.py +
def bring_plane_to_origin(point, plane, shape=[10., 10.], center=[0., 0., 0.], angles=[0., 0., 0.], mode='XYZ'):
+    """
+    Definition to bring points back to reference origin with respect to a plane.
+
+    Parameters
+    ----------
+    point              : ndarray
+                         Point(s) to be tested.
+    shape              : list
+                         Dimensions of the rectangle along X and Y axes.
+    center             : list
+                         Center of the rectangle.
+    angles             : list
+                         Rotation angle of the rectangle.
+    mode               : str
+                         Rotation mode of the rectangle, for more see odak.tools.rotate_point and odak.tools.rotate_points.
+
+    Returns
+    ----------
+    transformed_points : ndarray
+                         Point(s) that are brought back to reference origin with respect to given plane.
+    """
+    if point.shape[0] == 3:
+        point = point.reshape((1, 3))
+    reverse_mode = mode[::-1]
+    angles = [-angles[0], -angles[1], -angles[2]]
+    center = np.asarray(center).reshape((1, 3))
+    transformed_points = point-center
+    transformed_points = rotate_points(
+        transformed_points,
+        angles=angles,
+        mode=reverse_mode,
+    )
+    if transformed_points.shape[0] == 1:
+        transformed_points = transformed_points.reshape((3,))
+    return transformed_points
+
+
+
+ +
+ +
+ + +

+ calculate_intersection_of_two_rays(ray0, ray1) + +

+ + +
+ +

Definition to calculate the intersection of two rays.

+ + +

Parameters:

+
    +
  • + ray0 + – +
    +
         A ray.
    +
    +
    +
  • +
  • + ray1 + – +
    +
         A ray.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +point ( ndarray +) – +
    +

    Point in X,Y,Z.

    +
    +
  • +
  • +distances ( ndarray +) – +
    +

    Distances.

    +
    +
  • +
+ +
+ Source code in odak/raytracing/ray.py +
def calculate_intersection_of_two_rays(ray0, ray1):
+    """
+    Definition to calculate the intersection of two rays.
+
+    Parameters
+    ----------
+    ray0       : ndarray
+                 A ray.
+    ray1       : ndarray
+                 A ray.
+
+    Returns
+    ----------
+    point      : ndarray
+                 Point in X,Y,Z.
+    distances  : ndarray
+                 Distances.
+    """
+    A = np.array([
+        [float(ray0[1][0]), float(ray1[1][0])],
+        [float(ray0[1][1]), float(ray1[1][1])],
+        [float(ray0[1][2]), float(ray1[1][2])]
+    ])
+    B = np.array([
+        ray0[0][0]-ray1[0][0],
+        ray0[0][1]-ray1[0][1],
+        ray0[0][2]-ray1[0][2]
+    ])
+    distances = np.linalg.lstsq(A, B, rcond=None)[0]
+    if np.allclose(np.dot(A, distances), B) == False:
+        distances = np.array([0, 0])
+    distances = distances[np.argsort(-distances)]
+    point = propagate_a_ray(ray0, distances[0])[0]
+    return point, distances
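A hedged sketch that builds two rays crossing at the origin with create_ray_from_two_points (documented below) and recovers their intersection; the import assumes the standard odak.raytracing exports.

import numpy as np
import odak.raytracing as raytracer

# Two rays that both pass through the origin.
ray0 = raytracer.create_ray_from_two_points([1., 0., 0.], [0., 0., 0.])
ray1 = raytracer.create_ray_from_two_points([0., 1., 0.], [0., 0., 0.])
point, distances = raytracer.calculate_intersection_of_two_rays(ray0, ray1)
print(np.round(point, 6))     # expected: [0. 0. 0.]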
+
+
+
+ +
+ +
+ + +

+ center_of_triangle(triangle) + +

+ + +
+ +

Definition to calculate center of a triangle.

+ + +

Parameters:

+
    +
  • + triangle + – +
    +
            An array that contains three points defining a triangle (Mx3). It can also parallel process many triangles (NxMx3).
    +
    +
    +
  • +
+ +
+ Source code in odak/raytracing/primitives.py +
def center_of_triangle(triangle):
+    """
+    Definition to calculate center of a triangle.
+
+    Parameters
+    ----------
+    triangle      : ndarray
+                    An array that contains three points defining a triangle (Mx3). It can also parallel process many triangles (NxMx3).
+    """
+    if len(triangle.shape) == 2:
+        triangle = triangle.reshape((1, 3, 3))
+    center = np.mean(triangle, axis=1)
+    return center
+
+
+
+ +
+ +
+ + +

+ closest_point_to_a_ray(point, ray) + +

+ + +
+ +

Definition to calculate the point on a ray that is closest to given point.

+ + +

Parameters:

+
    +
  • + point + – +
    +
            Given point in X,Y,Z.
    +
    +
    +
  • +
  • + ray + – +
    +
            Given ray.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +closest_point ( ndarray +) – +
    +

    Calculated closest point.

    +
    +
  • +
+ +
+ Source code in odak/tools/vector.py +
def closest_point_to_a_ray(point, ray):
+    """
+    Definition to calculate the point on a ray that is closest to given point.
+
+    Parameters
+    ----------
+    point         : list
+                    Given point in X,Y,Z.
+    ray           : ndarray
+                    Given ray.
+
+    Returns
+    ---------
+    closest_point : ndarray
+                    Calculated closest point.
+    """
+    from odak.raytracing import propagate_a_ray
+    if len(ray.shape) == 2:
+        ray = ray.reshape((1, 2, 3))
+    p0 = ray[:, 0]
+    p1 = propagate_a_ray(ray, 1.)
+    if len(p1.shape) == 2:
+        p1 = p1.reshape((1, 2, 3))
+    p1 = p1[:, 0]
+    p1 = p1.reshape(3)
+    p0 = p0.reshape(3)
+    point = point.reshape(3)
+    closest_distance = -np.dot((p0-point), (p1-p0))/np.sum((p1-p0)**2)
+    closest_point = propagate_a_ray(ray, closest_distance)[0]
+    return closest_point
+
+
+
+ +
+ +
+ + +

+ create_ray(x0y0z0, abg) + +

+ + +
+ +

Definition to create a ray.

+ + +

Parameters:

+
    +
  • + x0y0z0 + – +
    +
           List that contains X,Y and Z start locations of a ray.
    +
    +
    +
  • +
  • + abg + – +
    +
           List that contains angles in degrees with respect to the X, Y and Z axes.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +ray ( ndarray +) – +
    +

    Array that contains starting points and cosines of a created ray.

    +
    +
  • +
+ +
+ Source code in odak/raytracing/ray.py +
def create_ray(x0y0z0, abg):
+    """
+    Definition to create a ray.
+
+    Parameters
+    ----------
+    x0y0z0       : list
+                   List that contains X,Y and Z start locations of a ray.
+    abg          : list
+                   List that contains angles in degrees with respect to the X, Y and Z axes.
+
+    Returns
+    ----------
+    ray          : ndarray
+                   Array that contains starting points and cosines of a created ray.
+    """
+    # Due to Python 2 -> Python 3.
+    x0, y0, z0 = x0y0z0
+    alpha, beta, gamma = abg
+    # Create a vector with the given points and angles in each direction
+    point = np.array([x0, y0, z0], dtype=np.float64)
+    alpha = np.cos(np.radians(alpha))
+    beta = np.cos(np.radians(beta))
+    gamma = np.cos(np.radians(gamma))
+    # Cosines vector.
+    cosines = np.array([alpha, beta, gamma], dtype=np.float64)
+    ray = np.array([point, cosines], dtype=np.float64)
+    return ray
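A quick sketch, assuming the standard odak.raytracing exports: a ray starting at the origin with angles of 90, 90 and 0 degrees to the X, Y and Z axes has direction cosines pointing along +Z.

import numpy as np
import odak.raytracing as raytracer

ray = raytracer.create_ray([0., 0., 0.], [90., 90., 0.])
print(np.round(ray, 6))
# [[0. 0. 0.]      <- start point
#  [0. 0. 1.]]     <- direction cosines, i.e. propagation along +Z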
+
+
+
+ +
+ +
+ + +

+ create_ray_from_angles(point, angles, mode='XYZ') + +

+ + +
+ +

Definition to create a ray from a point and angles.

+ + +

Parameters:

+
    +
  • + point + – +
    +
         Point in X,Y and Z.
    +
    +
    +
  • +
  • + angles + – +
    +
         Angles with X,Y,Z axes in degrees. All zeros point Z axis.
    +
    +
    +
  • +
  • + mode + – +
    +
         Rotation mode determines the ordering of the rotations at each axis. There are XYZ, YXZ, ZXY and ZYX modes.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +ray ( ndarray +) – +
    +

    Created ray.

    +
    +
  • +
+ +
+ Source code in odak/raytracing/ray.py +
def create_ray_from_angles(point, angles, mode='XYZ'):
+    """
+    Definition to create a ray from a point and angles.
+
+    Parameters
+    ----------
+    point      : ndarray
+                 Point in X,Y and Z.
+    angles     : ndarray
+                 Angles with X,Y,Z axes in degrees. All zeros point Z axis.
+    mode       : str
+                 Rotation mode determines the ordering of the rotations at each axis. There are XYZ, YXZ, ZXY and ZYX modes.
+
+    Returns
+    ----------
+    ray        : ndarray
+                 Created ray.
+    """
+    if len(point.shape) == 1:
+        point = point.reshape((1, 3))
+    new_point = np.zeros(point.shape)
+    new_point[:, 2] += 5.
+    new_point = rotate_points(new_point, angles, mode=mode, offset=point[:, 0])
+    ray = create_ray_from_two_points(point, new_point)
+    if ray.shape[0] == 1:
+        ray = ray.reshape((2, 3))
+    return ray
+
+
+
+ +
+ +
+ + +

+ create_ray_from_two_points(x0y0z0, x1y1z1) + +

+ + +
+ +

Definition to create a ray from two given points. Note that both inputs must match in shape.

+ + +

Parameters:

+
    +
  • + x0y0z0 + – +
    +
           List that contains X,Y and Z start locations of a ray (3). It can also be a list of points as well (mx3). This is the starting point.
    +
    +
    +
  • +
  • + x1y1z1 + – +
    +
           List that contains X,Y and Z ending locations of a ray (3). It can also be a list of points as well (mx3). This is the end point.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +ray ( ndarray +) – +
    +

    Array that contains starting points and cosines of a created ray.

    +
    +
  • +
+ +
+ Source code in odak/raytracing/ray.py +
def create_ray_from_two_points(x0y0z0, x1y1z1):
+    """
+    Definition to create a ray from two given points. Note that both inputs must match in shape.
+
+    Parameters
+    ----------
+    x0y0z0       : list
+                   List that contains X,Y and Z start locations of a ray (3). It can also be a list of points as well (mx3). This is the starting point.
+    x1y1z1       : list
+                   List that contains X,Y and Z ending locations of a ray (3). It can also be a list of points as well (mx3). This is the end point.
+
+    Returns
+    ----------
+    ray          : ndarray
+                   Array that contains starting points and cosines of a created ray.
+    """
+    x0y0z0 = np.asarray(x0y0z0, dtype=np.float64)
+    x1y1z1 = np.asarray(x1y1z1, dtype=np.float64)
+    if len(x0y0z0.shape) == 1:
+        x0y0z0 = x0y0z0.reshape((1, 3))
+    if len(x1y1z1.shape) == 1:
+        x1y1z1 = x1y1z1.reshape((1, 3))
+    xdiff = x1y1z1[:, 0] - x0y0z0[:, 0]
+    ydiff = x1y1z1[:, 1] - x0y0z0[:, 1]
+    zdiff = x1y1z1[:, 2] - x0y0z0[:, 2]
+    s = np.sqrt(xdiff ** 2 + ydiff ** 2 + zdiff ** 2)
+    s[s == 0] = np.nan
+    cosines = np.zeros((xdiff.shape[0], 3))
+    cosines[:, 0] = xdiff/s
+    cosines[:, 1] = ydiff/s
+    cosines[:, 2] = zdiff/s
+    ray = np.zeros((xdiff.shape[0], 2, 3), dtype=np.float64)
+    ray[:, 0] = x0y0z0
+    ray[:, 1] = cosines
+    if ray.shape[0] == 1:
+        ray = ray.reshape((2, 3))
+    return ray
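A sketch of the batched form, under the same odak.raytracing export assumption: m start points and m end points give m rays of shape (m x 2 x 3).

import numpy as np
import odak.raytracing as raytracer

starts = np.array([[0., 0., 0.],
                   [1., 0., 0.]])
ends   = np.array([[0., 0., 10.],
                   [1., 0., 10.]])
rays = raytracer.create_ray_from_two_points(starts, ends)
print(rays.shape)             # (2, 2, 3): per ray, a start point and a direction cosine vector
print(rays[0, 1])             # [0. 0. 1.] -> both rays point along +Z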
+
+
+
+ +
+ +
+ + +

+ cylinder_function(point, cylinder) + +

+ + +
+ +

Definition of a cylinder function. Evaluate a point against a cylinder function. Inspired from https://mathworld.wolfram.com/Point-LineDistance3-Dimensional.html

+ + +

Parameters:

+
    +
  • + cylinder + – +
    +
         Cylinder parameters, XYZ center and radius.
    +
    +
    +
  • +
  • + point + – +
    +
         Point in XYZ.
    +
    +
    +
  • +
+ + +
+ Return +

result : float + Result of the evaluation. Zero if the point is on the cylinder.

+
+
+ Source code in odak/raytracing/primitives.py +
def cylinder_function(point, cylinder):
+    """
+    Definition of a cylinder function. Evaluate a point against a cylinder function. Inspired from https://mathworld.wolfram.com/Point-LineDistance3-Dimensional.html
+
+    Parameters
+    ----------
+    cylinder   : ndarray
+                 Cylinder parameters, XYZ center and radius.
+    point      : ndarray
+                 Point in XYZ.
+
+    Return
+    ----------
+    result     : float
+                 Result of the evaluation. Zero if point is on sphere.
+                 Result of the evaluation. Zero if the point is on the cylinder.
+    point = np.asarray(point)
+    if len(point.shape) == 1:
+        point = point.reshape((1, 3))
+    distance = point_to_ray_distance(
+        point,
+        np.array([cylinder[0], cylinder[1], cylinder[2]], dtype=np.float64),
+        np.array([cylinder[4], cylinder[5], cylinder[6]], dtype=np.float64)
+    )
+    r = cylinder[3]
+    result = distance - r ** 2
+    return result
+
+
+
+ +
+ +
+ + +

+ define_circle(center, radius, angles) + +

+ + +
+ +

Definition to describe a circle in a single variable packed form.

+ + +

Parameters:

+
    +
  • + center + – +
    +
      Center of a circle to be defined.
    +
    +
    +
  • +
  • + radius + – +
    +
      Radius of a circle to be defined.
    +
    +
    +
  • +
  • + angles + – +
    +
      Angular tilt of a circle.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +circle ( list +) – +
    +

    Single variable packed form.

    +
    +
  • +
+ +
+ Source code in odak/raytracing/primitives.py +
def define_circle(center, radius, angles):
+    """
+    Definition to describe a circle in a single variable packed form.
+
+    Parameters
+    ----------
+    center  : float
+              Center of a circle to be defined.
+    radius  : float
+              Radius of a circle to be defined.
+    angles  : float
+              Angular tilt of a circle.
+
+    Returns
+    ----------
+    circle  : list
+              Single variable packed form.
+    """
+    points = define_plane(center, angles=angles)
+    circle = [
+        points,
+        center,
+        radius
+    ]
+    return circle
+
+
+
+ +
+ +
+ + +

+ define_cylinder(center, radius, rotation=[0.0, 0.0, 0.0]) + +

+ + +
+ +

Definition to define a cylinder

+ + +

Parameters:

+
    +
  • + center + – +
    +
         Center of a cylinder in X,Y,Z.
    +
    +
    +
  • +
  • + radius + – +
    +
         Radius of a cylinder along X axis.
    +
    +
    +
  • +
  • + rotation + – +
    +
         Direction angles in degrees for the orientation of a cylinder.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +cylinder ( ndarray +) – +
    +

    Single variable packed form.

    +
    +
  • +
+ +
+ Source code in odak/raytracing/primitives.py +
def define_cylinder(center, radius, rotation=[0., 0., 0.]):
+    """
+    Definition to define a cylinder
+
+    Parameters
+    ----------
+    center     : ndarray
+                 Center of a cylinder in X,Y,Z.
+    radius     : float
+                 Radius of a cylinder along X axis.
+    rotation   : list
+                 Direction angles in degrees for the orientation of a cylinder.
+
+    Returns
+    ----------
+    cylinder   : ndarray
+                 Single variable packed form.
+    """
+    cylinder_ray = create_ray_from_angles(
+        np.asarray(center), np.asarray(rotation))
+    cylinder = np.array(
+        [
+            center[0],
+            center[1],
+            center[2],
+            radius,
+            center[0]+cylinder_ray[1, 0],
+            center[1]+cylinder_ray[1, 1],
+            center[2]+cylinder_ray[1, 2]
+        ],
+        dtype=np.float64
+    )
+    return cylinder
+
+
+
+ +
+ +
+ + +

+ define_plane(point, angles=[0.0, 0.0, 0.0]) + +

+ + +
+ +

Definition to define a plane from a center point and rotation angles.

+ + +

Parameters:

+
    +
  • + point + – +
    +
           A point that is at the center of a plane.
    +
    +
    +
  • +
  • + angles + – +
    +
           Rotation angles in degrees.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +plane ( ndarray +) – +
    +

    Points defining plane.

    +
    +
  • +
+ +
+ Source code in odak/raytracing/primitives.py +
def define_plane(point, angles=[0., 0., 0.]):
+    """ 
+    Definition to define a plane from a center point and rotation angles.
+
+    Parameters
+    ----------
+    point        : ndarray
+                   A point that is at the center of a plane.
+    angles       : list
+                   Rotation angles in degrees.
+
+    Returns
+    ----------
+    plane        : ndarray
+                   Points defining plane.
+    """
+    plane = np.array([
+        [10., 10., 0.],
+        [0., 10., 0.],
+        [0.,  0., 0.]
+    ], dtype=np.float64)
+    point = np.asarray(point)
+    for i in range(0, plane.shape[0]):
+        plane[i], _, _, _ = rotate_point(plane[i], angles=angles)
+        plane[i] = plane[i]+point
+    return plane
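A sketch under the same odak.raytracing export assumption: an untilted plane shifted to pass through [0, 0, 10] is returned as three spanning points.

import numpy as np
import odak.raytracing as raytracer

plane = raytracer.define_plane([0., 0., 10.], angles = [0., 0., 0.])
print(plane)
# [[10. 10. 10.]
#  [ 0. 10. 10.]
#  [ 0.  0. 10.]]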
+
+
+
+ +
+ +
+ + +

+ define_sphere(center, radius) + +

+ + +
+ +

Definition to define a sphere.

+ + +

Parameters:

+
    +
  • + center + – +
    +
         Center of a sphere in X,Y,Z.
    +
    +
    +
  • +
  • + radius + – +
    +
         Radius of a sphere.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +sphere ( ndarray +) – +
    +

    Single variable packed form.

    +
    +
  • +
+ +
+ Source code in odak/raytracing/primitives.py +
def define_sphere(center, radius):
+    """
+    Definition to define a sphere.
+
+    Parameters
+    ----------
+    center     : ndarray
+                 Center of a sphere in X,Y,Z.
+    radius     : float
+                 Radius of a sphere.
+
+    Returns
+    ----------
+    sphere     : ndarray
+                 Single variable packed form.
+    """
+    sphere = np.array(
+        [center[0], center[1], center[2], radius], dtype=np.float64)
+    return sphere
+
+
+
+ +
+ +
+ + +

+ distance_between_two_points(point1, point2) + +

+ + +
+ +

Definition to calculate distance between two given points.

+ + +

Parameters:

+
    +
  • + point1 + – +
    +
          First point in X,Y,Z.
    +
    +
    +
  • +
  • + point2 + – +
    +
          Second point in X,Y,Z.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +distance ( float +) – +
    +

    Distance in between given two points.

    +
    +
  • +
+ +
+ Source code in odak/tools/vector.py +
def distance_between_two_points(point1, point2):
+    """
+    Definition to calculate distance between two given points.
+
+    Parameters
+    ----------
+    point1      : list
+                  First point in X,Y,Z.
+    point2      : list
+                  Second point in X,Y,Z.
+
+    Returns
+    ----------
+    distance    : float
+                  Distance in between given two points.
+    """
+    point1 = np.asarray(point1)
+    point2 = np.asarray(point2)
+    if len(point1.shape) == 1 and len(point2.shape) == 1:
+        distance = np.sqrt(np.sum((point1-point2)**2))
+    elif len(point1.shape) == 2 or len(point2.shape) == 2:
+        distance = np.sqrt(np.sum((point1-point2)**2, axis=1))
+    return distance
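A short sketch; the import path follows the source location shown above (odak/tools/vector.py), and the batched form returns one distance per row.

import numpy as np
from odak.tools.vector import distance_between_two_points

print(distance_between_two_points([0., 0., 0.], [3., 4., 0.]))    # 5.0
p0 = np.array([[0., 0., 0.], [1., 1., 1.]])
p1 = np.array([[3., 4., 0.], [1., 1., 2.]])
print(distance_between_two_points(p0, p1))                        # [5. 1.]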
+
+
+
+ +
+ +
+ + +

+ find_nearest_points(ray0, ray1) + +

+ + +
+ +

Find the nearest points on given rays with respect to the other ray.

+ + +

Parameters:

+
    +
  • + ray0 + – +
    +
         A ray.
    +
    +
    +
  • +
  • + ray1 + – +
    +
         A ray.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +c0 ( ndarray +) – +
    +

    Closest point on ray0.

    +
    +
  • +
  • +c1 ( ndarray +) – +
    +

    Closest point on ray1.

    +
    +
  • +
+ +
+ Source code in odak/raytracing/ray.py +
def find_nearest_points(ray0, ray1):
+    """
+    Find the nearest points on given rays with respect to the other ray.
+
+    Parameters
+    ----------
+    ray0       : ndarray
+                 A ray.
+    ray1       : ndarray
+                 A ray.
+
+    Returns
+    ----------
+    c0         : ndarray
+                 Closest point on ray0.
+    c1         : ndarray
+                 Closest point on ray1.
+    """
+    p0 = ray0[0].reshape(3,)
+    d0 = ray0[1].reshape(3,)
+    p1 = ray1[0].reshape(3,)
+    d1 = ray1[1].reshape(3,)
+    n = np.cross(d0, d1)
+    if np.all(n) == 0:
+        point, distances = calculate_intersection_of_two_rays(ray0, ray1)
+        c0 = c1 = point
+    else:
+        n0 = np.cross(d0, n)
+        n1 = np.cross(d1, n)
+        c0 = p0+(np.dot((p1-p0), n1)/np.dot(d0, n1))*d0
+        c1 = p1+(np.dot((p0-p1), n0)/np.dot(d1, n0))*d1
+    return c0, c1
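A hedged sketch under the same odak.raytracing export assumption: for two rays that actually intersect, both nearest points coincide with the intersection.

import numpy as np
import odak.raytracing as raytracer

# Two rays that intersect at [1, 1, 1], approached from different directions.
ray0 = raytracer.create_ray_from_two_points([0., -1., -2.], [1., 1., 1.])
ray1 = raytracer.create_ray_from_two_points([-2., -1., 0.], [1., 1., 1.])
c0, c1 = raytracer.find_nearest_points(ray0, ray1)
print(np.round(c0, 6), np.round(c1, 6))                           # both ~[1. 1. 1.]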
+
+
+
+ +
+ +
+ + +

+ get_cylinder_normal(point, cylinder) + +

+ + +
+ + + +

Parameters:

+
    +
  • + point + – +
    +
            Point on a cylinder defined in X,Y,Z.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +normal_vector ( ndarray +) – +
    +

    Normal vector.

    +
    +
  • +
+ +
+ Source code in odak/raytracing/boundary.py +
def get_cylinder_normal(point, cylinder):
+    """
+    Parameters
+    ----------
+    point         : ndarray
+                    Point on a cylinder defined in X,Y,Z.
+
+    Returns
+    ----------
+    normal_vector : ndarray
+                    Normal vector.
+    """
+    cylinder_ray = create_ray_from_two_points(cylinder[0:3], cylinder[4:7])
+    closest_point = closest_point_to_a_ray(
+        point,
+        cylinder_ray
+    )
+    normal_vector = create_ray_from_two_points(closest_point, point)
+    return normal_vector
+
+
+
+ +
+ +
+ + +

+ get_sphere_normal(point, sphere) + +

+ + +
+ +

Definition to get a normal of a point on a given sphere.

+ + +

Parameters:

+
    +
  • + point + – +
    +
            Point on sphere in X,Y,Z.
    +
    +
    +
  • +
  • + sphere + – +
    +
            Center defined in X,Y,Z and radius.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +normal_vector ( ndarray +) – +
    +

    Normal vector.

    +
    +
  • +
+ +
+ Source code in odak/raytracing/boundary.py +
def get_sphere_normal(point, sphere):
+    """
+    Definition to get a normal of a point on a given sphere.
+
+    Parameters
+    ----------
+    point         : ndarray
+                    Point on sphere in X,Y,Z.
+    sphere        : ndarray
+                    Center defined in X,Y,Z and radius.
+
+    Returns
+    ----------
+    normal_vector : ndarray
+                    Normal vector.
+    """
+    if len(point.shape) == 1:
+        point = point.reshape((1, 3))
+    normal_vector = create_ray_from_two_points(point, sphere[0:3])
+    return normal_vector
+
+
+
+ +
+ +
+ + +

+ get_triangle_normal(triangle, triangle_center=None) + +

+ + +
+ +

Definition to calculate surface normal of a triangle.

+ + +

Parameters:

+
    +
  • + triangle + – +
    +
              Set of points in X,Y and Z to define a planar surface (3,3). It can also be list of triangles (mx3x3).
    +
    +
    +
  • +
  • + triangle_center + (ndarray, default: + None +) + – +
    +
              Center point of the given triangle. See odak.raytracing.center_of_triangle for more. In many scenarios you can accelerate things by precomputing triangle centers.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +normal ( ndarray +) – +
    +

    Surface normal at the point of intersection.

    +
    +
  • +
+ +
+ Source code in odak/raytracing/boundary.py +
def get_triangle_normal(triangle, triangle_center=None):
+    """
+    Definition to calculate surface normal of a triangle.
+
+    Parameters
+    ----------
+    triangle        : ndarray
+                      Set of points in X,Y and Z to define a planar surface (3,3). It can also be list of triangles (mx3x3).
+    triangle_center : ndarray
+                      Center point of the given triangle. See odak.raytracing.center_of_triangle for more. In many scenarios you can accelerate things by precomputing triangle centers.
+
+    Returns
+    ----------
+    normal          : ndarray
+                      Surface normal at the point of intersection.
+    """
+    triangle = np.asarray(triangle)
+    if len(triangle.shape) == 2:
+        triangle = triangle.reshape((1, 3, 3))
+    normal = np.zeros((triangle.shape[0], 2, 3))
+    direction = np.cross(
+        triangle[:, 0]-triangle[:, 1], triangle[:, 2]-triangle[:, 1])
+    if type(triangle_center) == type(None):
+        normal[:, 0] = center_of_triangle(triangle)
+    else:
+        normal[:, 0] = triangle_center
+    normal[:, 1] = direction/np.sum(direction, axis=1)[0]
+    if normal.shape[0] == 1:
+        normal = normal.reshape((2, 3))
+    return normal
+
+
+
+ +
+ +
+ + +

+ intersect_parametric(ray, parametric_surface, surface_function, surface_normal_function, target_error=1e-08, iter_no_limit=100000) + +

+ + +
+ +

Definition to intersect a ray with a parametric surface.

+ + +

Parameters:

+
    +
  • + ray + – +
    +
                      Ray.
    +
    +
    +
  • +
  • + parametric_surface + – +
    +
                      Parameters of the surfaces.
    +
    +
    +
  • +
  • + surface_function + – +
    +
                      Function to evaluate a point against a surface.
    +
    +
    +
  • +
  • + surface_normal_function + (function) + – +
    +
                      Function to calculate surface normal for a given point on a surface.
    +
    +
    +
  • +
  • + target_error + – +
    +
                      Target error that defines the precision.
    +
    +
    +
  • +
  • + iter_no_limit + – +
    +
                      Maximum number of iterations.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +distance ( float +) – +
    +

    Propagation distance.

    +
    +
  • +
  • +normal ( ndarray +) – +
    +

    Ray that defines a surface normal for the intersection.

    +
    +
  • +
+ +
+ Source code in odak/raytracing/boundary.py +
def intersect_parametric(ray, parametric_surface, surface_function, surface_normal_function, target_error=0.00000001, iter_no_limit=100000):
+    """
+    Definition to intersect a ray with a parametric surface.
+
+    Parameters
+    ----------
+    ray                     : ndarray
+                              Ray.
+    parametric_surface      : ndarray
+                              Parameters of the surfaces.
+    surface_function        : function
+                              Function to evaluate a point against a surface.
+    surface_normal_function : function
+                              Function to calculate surface normal for a given point on a surface.
+    target_error            : float
+                              Target error that defines the precision.  
+    iter_no_limit           : int
+                              Maximum number of iterations.
+
+    Returns
+    ----------
+    distance                : float
+                              Propagation distance.
+    normal                  : ndarray
+                              Ray that defines a surface normal for the intersection.
+    """
+    if len(ray.shape) == 2:
+        ray = ray.reshape((1, 2, 3))
+    error = [150, 100]
+    distance = [0, 0.1]
+    iter_no = 0
+    while np.abs(np.max(np.asarray(error[1]))) > target_error:
+        error[1], point = intersection_kernel_for_parametric_surfaces(
+            distance[1],
+            ray,
+            parametric_surface,
+            surface_function
+        )
+        distance, error = propagate_parametric_intersection_error(
+            distance,
+            error
+        )
+        iter_no += 1
+        if iter_no > iter_no_limit:
+            return False, False
+        if np.isnan(np.sum(point)):
+            return False, False
+    normal = surface_normal_function(
+        point,
+        parametric_surface
+    )
+    return distance[1], normal
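
A minimal usage sketch (not part of the generated reference), assuming the modules are importable under the source paths listed in this page and using the 2 x 3 ray convention (first row the start point, second row the direction cosines) that the listings here rely on:

import numpy as np
from odak.raytracing.primitives import define_sphere, sphere_function
from odak.raytracing.boundary import intersect_parametric, get_sphere_normal

ray = np.array([[0., 0., -10.],   # start point
                [0., 0.,   1.]])  # direction cosines, travelling along +z
sphere = define_sphere([0., 0., 0.], 5.)
distance, normal = intersect_parametric(ray, sphere, sphere_function, get_sphere_normal)
print(distance)  # expected to converge to roughly 5.0 for this geometry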

intersect_w_circle(ray, circle)

Definition to find the intersection point of a ray with a circle. Returns False for each variable if the ray doesn't intersect with a given circle. Returns distance as zero if there isn't an intersection.

Parameters:

  • ray – A vector/ray.
  • circle – A list that contains (0) a set of points in X, Y and Z defining the plane of the circle, (1) the circle center, and (2) the circle radius.

Returns:

  • normal (ndarray) – Surface normal at the point of intersection.
  • distance (float) – Distance between the starting point of the ray and the intersection point with the given circle.

Source code in odak/raytracing/boundary.py
def intersect_w_circle(ray, circle):
+    """
+    Definition to find intersection point of a ray with a circle. Returns False for each variable if the ray doesn't intersect with a given circle. Returns distance as zero if there isn't an intersection.
+
+    Parameters
+    ----------
+    ray          : ndarray
+                   A vector/ray.
+    circle       : list
+                   A list that contains (0) Set of points in X,Y and Z to define plane of a circle, (1) circle center, and (2) circle radius.
+
+    Returns
+    ----------
+    normal       : ndarray
+                   Surface normal at the point of intersection.
+    distance     : float
+                   Distance in between a starting point of a ray and the intersection point with a given triangle.
+    """
+    normal, distance = intersect_w_surface(ray, circle[0])
+    if len(normal.shape) == 2:
+        normal = normal.reshape((1, 2, 3))
+    distance_to_center = distance_between_two_points(normal[:, 0], circle[1])
+    distance[np.nonzero(distance_to_center > circle[2])] = 0
+    if len(ray.shape) == 2:
+        normal = normal.reshape((2, 3))
+    return normal, distance
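
A short usage sketch (illustration only, not part of the generated reference); define_circle from odak/raytracing/primitives.py, documented later in this page, is assumed to provide the packed circle form expected here:

import numpy as np
from odak.raytracing.primitives import define_circle
from odak.raytracing.boundary import intersect_w_circle

ray = np.array([[0., 0., -10.],
                [0., 0.,   1.]])
circle = define_circle([0., 0., 0.], 5., [0., 0., 0.])  # radius 5 circle in the z = 0 plane
normal, distance = intersect_w_circle(ray, circle)
print(distance)   # ~10, the ray lands inside the circle
print(normal[0])  # intersection point, ~[0, 0, 0]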

intersect_w_cylinder(ray, cylinder)

Definition to intersect a ray with a cylinder.

Parameters:

  • ray – A ray definition.
  • cylinder – A cylinder defined with a center in XYZ and radius of curvature.

Returns:

  • normal (ndarray) – A ray defining surface normal at the point of intersection.
  • distance (float) – Total optical propagation distance.

Source code in odak/raytracing/boundary.py
def intersect_w_cylinder(ray, cylinder):
+    """
+    Definition to intersect a ray with a cylinder.
+
+    Parameters
+    ----------
+    ray        : ndarray
+                 A ray definition.
+    cylinder   : ndarray
+                 A cylinder defined with a center in XYZ and radius of curvature.
+
+    Returns
+    ----------
+    normal     : ndarray
+                 A ray defining surface normal at the point of intersection.
+    distance   : float
+                 Total optical propagation distance.
+    """
+    distance, normal = intersect_parametric(
+        ray,
+        cylinder,
+        cylinder_function,
+        get_cylinder_normal
+    )
+    return normal, distance

intersect_w_sphere(ray, sphere)

Definition to intersect a ray with a sphere.

Parameters:

  • ray – A ray definition.
  • sphere – A sphere defined with a center in XYZ and radius of curvature.

Returns:

  • normal (ndarray) – A ray defining surface normal at the point of intersection.
  • distance (float) – Total optical propagation distance.

Source code in odak/raytracing/boundary.py
def intersect_w_sphere(ray, sphere):
+    """
+    Definition to intersect a ray with a sphere.
+
+    Parameters
+    ----------
+    ray        : ndarray
+                 A ray definition.
+    sphere     : ndarray
+                 A sphere defined with a center in XYZ and radius of curvature.
+
+    Returns
+    ----------
+    normal     : ndarray
+                 A ray defining surface normal at the point of intersection.
+    distance   : float
+                 Total optical propagation distance.
+    """
+    distance, normal = intersect_parametric(
+        ray,
+        sphere,
+        sphere_function,
+        get_sphere_normal
+    )
+    return normal, distance
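
A brief sketch (an illustration, not library documentation) of the wrapper above; note that it returns (normal, distance), while intersect_parametric returns (distance, normal):

import numpy as np
from odak.raytracing.primitives import define_sphere
from odak.raytracing.boundary import intersect_w_sphere

ray = np.array([[0., 0., -10.],
                [0., 0.,   1.]])
sphere = define_sphere([0., 0., 0.], 5.)
normal, distance = intersect_w_sphere(ray, sphere)
print(distance)  # ~5.0 for this geometry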

intersect_w_surface(ray, points)

Definition to find the intersection point between a planar surface and a ray. For more see: http://geomalgorithms.com/a06-_intersect-2.html

Parameters:

  • ray – A vector/ray.
  • points – Set of points in X, Y and Z to define a planar surface.

Returns:

  • normal (ndarray) – Surface normal at the point of intersection.
  • distance (float) – Distance between the starting point of the ray and its intersection with the planar surface.

Source code in odak/raytracing/boundary.py
def intersect_w_surface(ray, points):
+    """
+    Definition to find intersection point inbetween a surface and a ray. For more see: http://geomalgorithms.com/a06-_intersect-2.html
+
+    Parameters
+    ----------
+    ray          : ndarray
+                   A vector/ray.
+    points       : ndarray
+                   Set of points in X,Y and Z to define a planar surface.
+
+    Returns
+    ----------
+    normal       : ndarray
+                   Surface normal at the point of intersection.
+    distance     : float
+                   Distance in between starting point of a ray with it's intersection with a planar surface.
+    """
+    points = np.asarray(points)
+    normal = get_triangle_normal(points)
+    if len(ray.shape) == 2:
+        ray = ray.reshape((1, 2, 3))
+    if len(points) == 2:
+        points = points.reshape((1, 3, 3))
+    if len(normal.shape) == 2:
+        normal = normal.reshape((1, 2, 3))
+    f = normal[:, 0]-ray[:, 0]
+    distance = np.dot(normal[:, 1], f.T)/np.dot(normal[:, 1], ray[:, 1].T)
+    n = np.int64(np.amax(np.array([ray.shape[0], normal.shape[0]])))
+    normal = np.zeros((n, 2, 3))
+    normal[:, 0] = ray[:, 0]+distance.T*ray[:, 1]
+    distance = np.abs(distance)
+    if normal.shape[0] == 1:
+        normal = normal.reshape((2, 3))
+        distance = distance.reshape((1))
+    if distance.shape[0] == 1 and len(distance.shape) > 1:
+        distance = distance.reshape((distance.shape[1]))
+    return normal, distance
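
A minimal sketch (not part of the generated reference), pairing this routine with define_plane from odak/raytracing/primitives.py documented further below:

import numpy as np
from odak.raytracing.primitives import define_plane
from odak.raytracing.boundary import intersect_w_surface

ray = np.array([[0., 0., -10.],
                [0., 0.,   1.]])
plane = define_plane([0., 0., 0.])              # planar patch centred at the origin
normal, distance = intersect_w_surface(ray, plane)
print(distance)   # ~10 for this geometry
print(normal[0])  # intersection point, ~[0, 0, 0]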

intersect_w_triangle(ray, triangle)

Definition to find intersection point of a ray with a triangle. Returns False for each variable if the ray doesn't intersect with a given triangle.

Parameters:

  • ray – A vector/ray (2 x 3). It can also be a list of rays (n x 2 x 3).
  • triangle – Set of points in X, Y and Z to define a planar surface. It can also be a list of triangles (m x 3 x 3).

Returns:

  • normal (ndarray) – Surface normal at the point of intersection.
  • distance (float) – Distance in between a starting point of a ray and the intersection point with a given triangle.

Source code in odak/raytracing/boundary.py
def intersect_w_triangle(ray, triangle):
+    """
+    Definition to find intersection point of a ray with a triangle. Returns False for each variable if the ray doesn't intersect with a given triangle.
+
+    Parameters
+    ----------
+    ray          : ndarray
+                   A vector/ray (2 x 3). It can also be a list of rays (n x 2 x 3).
+    triangle     : ndarray
+                   Set of points in X,Y and Z to define a planar surface. It can also be a list of triangles (m x 3 x 3).
+
+    Returns
+    ----------
+    normal       : ndarray
+                   Surface normal at the point of intersection.
+    distance     : float
+                   Distance in between a starting point of a ray and the intersection point with a given triangle.
+    """
+    normal, distance = intersect_w_surface(ray, triangle)
+    if is_it_on_triangle(normal[0], triangle[0], triangle[1], triangle[2]) == False:
+        return 0, 0
+    return normal, distance
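
An illustrative call with assumed geometry (not from the library's own examples):

import numpy as np
from odak.raytracing.boundary import intersect_w_triangle

triangle = np.array([[ 5., -5., 0.],
                     [ 0.,  5., 0.],
                     [-5., -5., 0.]])
ray = np.array([[0., 0., -10.],
                [0., 0.,   1.]])
normal, distance = intersect_w_triangle(ray, triangle)
print(distance)  # ~10; a ray that misses the triangle yields (0, 0) instead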

intersection_kernel_for_parametric_surfaces(distance, ray, parametric_surface, surface_function)

Definition for the intersection kernel when dealing with parametric surfaces.

Parameters:

  • distance – Distance.
  • ray – Ray.
  • parametric_surface (ndarray) – Array that defines a parametric surface.
  • surface_function – Function to evaluate a point against a parametric surface.

Returns:

  • error (float) – Error of the evaluated point with respect to the surface.
  • point (ndarray) – Location in X, Y, Z after propagation.

Source code in odak/raytracing/boundary.py
def intersection_kernel_for_parametric_surfaces(distance, ray, parametric_surface, surface_function):
+    """
+    Definition for the intersection kernel when dealing with parametric surfaces.
+
+    Parameters
+    ----------
+    distance           : float
+                         Distance.
+    ray                : ndarray
+                         Ray.
+    parametric_surface : ndarray
+                         Array that defines a parametric surface.
+    surface_function   : ndarray
+                         Function to evaluate a point against a parametric surface.
+
+    Returns
+    ----------
+    point              : ndarray
+                         Location in X,Y,Z after propagation.
+    error              : float
+                         Error.
+    """
+    new_ray = propagate_a_ray(ray, distance)
+    if len(new_ray) == 2:
+        new_ray = new_ray.reshape((1, 2, 3))
+    point = new_ray[:, 0]
+    error = surface_function(point, parametric_surface)
+    return error, point

is_it_on_triangle(pointtocheck, point0, point1, point2)

Definition to check if a given point is inside a triangle. If the given point is inside a defined triangle, this definition returns True.

Parameters:

  • pointtocheck – Point to check.
  • point0 – First point of a triangle.
  • point1 – Second point of a triangle.
  • point2 – Third point of a triangle.

Source code in odak/raytracing/primitives.py
def is_it_on_triangle(pointtocheck, point0, point1, point2):
+    """
+    Definition to check if a given point is inside a triangle. If the given point is inside a defined triangle, this definition returns True.
+
+    Parameters
+    ----------
+    pointtocheck  : list
+                    Point to check.
+    point0        : list
+                    First point of a triangle.
+    point1        : list
+                    Second point of a triangle.
+    point2        : list
+                    Third point of a triangle.
+    """
+    # point0, point1 and point2 are the corners of the triangle.
+    pointtocheck = np.asarray(pointtocheck).reshape(3)
+    point0 = np.asarray(point0)
+    point1 = np.asarray(point1)
+    point2 = np.asarray(point2)
+    side0 = same_side(pointtocheck, point0, point1, point2)
+    side1 = same_side(pointtocheck, point1, point0, point2)
+    side2 = same_side(pointtocheck, point2, point0, point1)
+    if side0 == True and side1 == True and side2 == True:
+        return True
+    return False
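
A quick sketch with assumed coordinates (illustration only):

import numpy as np
from odak.raytracing.primitives import is_it_on_triangle

triangle = np.array([[ 5., -5., 0.],
                     [ 0.,  5., 0.],
                     [-5., -5., 0.]])
print(is_it_on_triangle([ 0., 0., 0.], triangle[0], triangle[1], triangle[2]))  # True
print(is_it_on_triangle([20., 0., 0.], triangle[0], triangle[1], triangle[2]))  # False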

point_to_ray_distance(point, ray_point_0, ray_point_1)

Definition to find point's closest distance to a line represented with two points.

Parameters:

  • point – Point to be tested.
  • ray_point_0 (ndarray) – First point to represent a line.
  • ray_point_1 (ndarray) – Second point to represent a line.

Returns:

  • distance (float) – Calculated distance.

Source code in odak/tools/vector.py
def point_to_ray_distance(point, ray_point_0, ray_point_1):
+    """
+    Definition to find point's closest distance to a line represented with two points.
+
+    Parameters
+    ----------
+    point       : ndarray
+                  Point to be tested.
+    ray_point_0 : ndarray
+                  First point to represent a line.
+    ray_point_1 : ndarray
+                  Second point to represent a line.
+
+    Returns
+    ----------
+    distance    : float
+                  Calculated distance.
+    """
+    distance = np.sum(np.cross((point-ray_point_0), (point-ray_point_1))
+                      ** 2)/np.sum((ray_point_1-ray_point_0)**2)
+    return distance
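
A small illustration with assumed coordinates. Note that, as written, the listed expression evaluates |(p - a) x (p - b)|^2 / |b - a|^2, i.e. the squared perpendicular distance, so the value below is 5^2 rather than 5:

import numpy as np
from odak.tools.vector import point_to_ray_distance

point = np.array([0., 0., 5.])
ray_point_0 = np.array([0., 0., 0.])
ray_point_1 = np.array([1., 0., 0.])  # the line is the X axis
print(point_to_ray_distance(point, ray_point_0, ray_point_1))  # 25.0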

propagate_a_ray(ray, distance)

Definition to propagate a ray at a certain given distance.

Parameters:

  • ray – A ray.
  • distance – Distance.

Returns:

  • new_ray (ndarray) – Propagated ray.

Source code in odak/raytracing/ray.py
def propagate_a_ray(ray, distance):
+    """
+    Definition to propagate a ray at a certain given distance.
+
+    Parameters
+    ----------
+    ray        : ndarray
+                 A ray.
+    distance   : float
+                 Distance.
+
+    Returns
+    ----------
+    new_ray    : ndarray
+                 Propagated ray.
+    """
+    if len(ray.shape) == 2:
+        ray = ray.reshape((1, 2, 3))
+    new_ray = np.copy(ray)
+    new_ray[:, 0, 0] = distance*new_ray[:, 1, 0] + new_ray[:, 0, 0]
+    new_ray[:, 0, 1] = distance*new_ray[:, 1, 1] + new_ray[:, 0, 1]
+    new_ray[:, 0, 2] = distance*new_ray[:, 1, 2] + new_ray[:, 0, 2]
+    if new_ray.shape[0] == 1:
+        new_ray = new_ray.reshape((2, 3))
+    return new_ray
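
A minimal sketch (illustration only): the start point moves along the direction cosines by the given distance, while the direction row is left unchanged.

import numpy as np
from odak.raytracing.ray import propagate_a_ray

ray = np.array([[0., 0., 0.],
                [0., 0., 1.]])
print(propagate_a_ray(ray, 5.))  # start point becomes [0, 0, 5]; direction stays [0, 0, 1]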

propagate_parametric_intersection_error(distance, error)

Definition to propagate the error in parametric intersection to find the next distance to try.

Parameters:

  • distance – List that contains the new and the old distance.
  • error – List that contains the new and the old error.

Returns:

  • distance (list) – New distance.
  • error (list) – New error.

Source code in odak/raytracing/boundary.py
def propagate_parametric_intersection_error(distance, error):
+    """
+    Definition to propagate the error in parametric intersection to find the next distance to try.
+
+    Parameters
+    ----------
+    distance     : list
+                   List that contains the new and the old distance.
+    error        : list
+                   List that contains the new and the old error.
+
+    Returns
+    ----------
+    distance     : list
+                   New distance.
+    error        : list
+                   New error.
+    """
+    new_distance = distance[1]-error[1] * \
+        (distance[1]-distance[0])/(error[1]-error[0])
+    distance[0] = distance[1]
+    distance[1] = np.abs(new_distance)
+    error[0] = error[1]
+    return distance, error
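
This is a secant-style root-finding update, d_new = d_1 - e_1 (d_1 - d_0) / (e_1 - e_0). A hypothetical one-dimensional sketch (the error function below is an assumption for illustration, not part of odak) mirrors how intersect_parametric drives it:

from odak.raytracing.boundary import propagate_parametric_intersection_error

f = lambda d: d ** 2 - 4          # hypothetical error function with a root at d = 2
distance = [0., 0.1]
error = [f(0.), f(0.1)]
for _ in range(20):
    distance, error = propagate_parametric_intersection_error(distance, error)
    error[1] = f(distance[1])     # re-evaluate the error, as intersect_parametric does
print(distance[1])                # approaches 2.0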

reflect(input_ray, normal)

Definition to reflect an incoming ray from a surface defined by a surface normal. Uses the method described in G.H. Spencer and M.V.R.K. Murty, "General Ray-Tracing Procedure", 1961.

Parameters:

  • input_ray – A vector/ray (2 x 3). It can also be a list of rays (n x 2 x 3).
  • normal – A surface normal (2 x 3). It can also be a list of normals (n x 2 x 3).

Returns:

  • output_ray (ndarray) – Array that contains starting points and cosines of a reflected ray.

Source code in odak/raytracing/boundary.py
def reflect(input_ray, normal):
+    """ 
+    Definition to reflect an incoming ray from a surface defined by a surface normal. Used method described in G.H. Spencer and M.V.R.K. Murty, "General Ray-Tracing Procedure", 1961.
+
+    Parameters
+    ----------
+    input_ray    : ndarray
+                   A vector/ray (2x3). It can also be a list of rays (nx2x3).
+    normal       : ndarray
+                   A surface normal (2x3). It can also be a list of normals (nx2x3).
+
+    Returns
+    ----------
+    output_ray   : ndarray
+                   Array that contains starting points and cosines of a reflected ray.
+    """
+    input_ray = np.asarray(input_ray)
+    normal = np.asarray(normal)
+    if len(input_ray.shape) == 2:
+        input_ray = input_ray.reshape((1, 2, 3))
+    if len(normal.shape) == 2:
+        normal = normal.reshape((1, 2, 3))
+    mu = 1
+    div = normal[:, 1, 0]**2 + normal[:, 1, 1]**2 + normal[:, 1, 2]**2
+    a = mu * (input_ray[:, 1, 0]*normal[:, 1, 0]
+              + input_ray[:, 1, 1]*normal[:, 1, 1]
+              + input_ray[:, 1, 2]*normal[:, 1, 2]) / div
+    n = np.int64(np.amax(np.array([normal.shape[0], input_ray.shape[0]])))
+    output_ray = np.zeros((n, 2, 3))
+    output_ray[:, 0] = normal[:, 0]
+    output_ray[:, 1] = input_ray[:, 1]-2*a*normal[:, 1]
+    if output_ray.shape[0] == 1:
+        output_ray = output_ray.reshape((2, 3))
+    return output_ray
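
A minimal sketch with a hand-built surface normal ray (assumed values, illustration only): a ray travelling along +z bounces off a surface whose normal points back along +z.

import numpy as np
from odak.raytracing.boundary import reflect

input_ray = np.array([[0., 0., -10.],
                      [0., 0.,   1.]])
surface_normal = np.array([[0., 0., 0.],    # point on the surface
                           [0., 0., 1.]])   # normal direction cosines
output_ray = reflect(input_ray, surface_normal)
print(output_ray[1])  # [ 0.  0. -1.], the reflected ray heads back along -z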

rotate_point(point, angles=[0, 0, 0], mode='XYZ', origin=[0, 0, 0], offset=[0, 0, 0])

Definition to rotate a given point. The rotation is applied about the given origin (0, 0, 0 by default), and the result is then shifted by the given offset.

Parameters:

  • point – A point.
  • angles – Rotation angles in degrees.
  • mode – Rotation mode, determining the ordering of the rotations about each axis. XYZ, XZY, YXZ, ZXY and ZYX modes are available.
  • origin – Reference point for a rotation.
  • offset – Shift with the given offset.

Returns:

  • result (ndarray) – Result of the rotation.
  • rotx (ndarray) – Rotation matrix along X axis.
  • roty (ndarray) – Rotation matrix along Y axis.
  • rotz (ndarray) – Rotation matrix along Z axis.

Source code in odak/tools/transformation.py
def rotate_point(point, angles = [0, 0, 0], mode = 'XYZ', origin = [0, 0, 0], offset = [0, 0, 0]):
+    """
+    Definition to rotate a given point. Note that rotation is always with respect to 0,0,0.
+
+    Parameters
+    ----------
+    point        : ndarray
+                   A point.
+    angles       : list
+                   Rotation angles in degrees. 
+    mode         : str
+                   Rotation mode determines ordering of the rotations at each axis. There are XYZ,YXZ,ZXY and ZYX modes.
+    origin       : list
+                   Reference point for a rotation.
+    offset       : list
+                   Shift with the given offset.
+
+    Returns
+    ----------
+    result       : ndarray
+                   Result of the rotation
+    rotx         : ndarray
+                   Rotation matrix along X axis.
+    roty         : ndarray
+                   Rotation matrix along Y axis.
+    rotz         : ndarray
+                   Rotation matrix along Z axis.
+    """
+    point = np.asarray(point)
+    point -= np.asarray(origin)
+    rotx = rotmatx(angles[0])
+    roty = rotmaty(angles[1])
+    rotz = rotmatz(angles[2])
+    if mode == 'XYZ':
+        result = np.dot(rotz, np.dot(roty, np.dot(rotx, point)))
+    elif mode == 'XZY':
+        result = np.dot(roty, np.dot(rotz, np.dot(rotx, point)))
+    elif mode == 'YXZ':
+        result = np.dot(rotz, np.dot(rotx, np.dot(roty, point)))
+    elif mode == 'ZXY':
+        result = np.dot(roty, np.dot(rotx, np.dot(rotz, point)))
+    elif mode == 'ZYX':
+        result = np.dot(rotx, np.dot(roty, np.dot(rotz, point)))
+    result += np.asarray(origin)
+    result += np.asarray(offset)
+    return result, rotx, roty, rotz
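
A short sketch (illustration only, assuming odak's rotation matrices follow the usual right-handed convention):

import numpy as np
from odak.tools.transformation import rotate_point

# Rotate the point (1, 0, 0) by 90 degrees about the Z axis.
result, rotx, roty, rotz = rotate_point([1., 0., 0.], angles=[0., 0., 90.])
print(np.round(result, 3))  # approximately [0., 1., 0.]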

rotate_points(points, angles=[0, 0, 0], mode='XYZ', origin=[0, 0, 0], offset=[0, 0, 0])

Definition to rotate points.

Parameters:

  • points – Points.
  • angles – Rotation angles in degrees.
  • mode – Rotation mode, determining the ordering of the rotations about each axis. XYZ, XZY, YXZ, ZXY and ZYX modes are available.
  • origin – Reference point for a rotation.
  • offset – Shift with the given offset.

Returns:

  • result (ndarray) – Result of the rotation.

Source code in odak/tools/transformation.py
def rotate_points(points, angles = [0, 0, 0], mode = 'XYZ', origin = [0, 0, 0], offset = [0, 0, 0]):
+    """
+    Definition to rotate points.
+
+    Parameters
+    ----------
+    points       : ndarray
+                   Points.
+    angles       : list
+                   Rotation angles in degrees. 
+    mode         : str
+                   Rotation mode determines ordering of the rotations at each axis. There are XYZ,YXZ,ZXY and ZYX modes.
+    origin       : list
+                   Reference point for a rotation.
+    offset       : list
+                   Shift with the given offset.
+
+    Returns
+    ----------
+    result       : ndarray
+                   Result of the rotation   
+    """
+    points = np.asarray(points)
+    if angles[0] == 0 and angles[1] == 0 and angles[2] == 0:
+        result = np.array(offset) + points
+        return result
+    points -= np.array(origin)
+    rotx = rotmatx(angles[0])
+    roty = rotmaty(angles[1])
+    rotz = rotmatz(angles[2])
+    if mode == 'XYZ':
+        result = np.dot(rotz, np.dot(roty, np.dot(rotx, points.T))).T
+    elif mode == 'XZY':
+        result = np.dot(roty, np.dot(rotz, np.dot(rotx, points.T))).T
+    elif mode == 'YXZ':
+        result = np.dot(rotz, np.dot(rotx, np.dot(roty, points.T))).T
+    elif mode == 'ZXY':
+        result = np.dot(roty, np.dot(rotx, np.dot(rotz, points.T))).T
+    elif mode == 'ZYX':
+        result = np.dot(rotx, np.dot(roty, np.dot(rotz, points.T))).T
+    result += np.array(origin)
+    result += np.array(offset)
+    return result

same_side(p1, p2, a, b)

Definition to figure out which side of a line a point is on, with respect to a reference point. See http://www.blackpawn.com/texts/pointinpoly/ for more. If p1 and p2 are on the same side, this definition returns True.

Parameters:

  • p1 – Point(s) to check.
  • p2 – The point to check against.
  • a – First point that forms the line.
  • b – Second point that forms the line.

Source code in odak/tools/vector.py
def same_side(p1, p2, a, b):
+    """
+    Definition to figure which side a point is on with respect to a line and a point. See http://www.blackpawn.com/texts/pointinpoly/ for more. If p1 and p2 are on the sameside, this definition returns True.
+
+    Parameters
+    ----------
+    p1          : list
+                  Point(s) to check.
+    p2          : list
+                  This is the point check against.
+    a           : list
+                  First point that forms the line.
+    b           : list
+                  Second point that forms the line.
+    """
+    ba = np.subtract(b, a)
+    p1a = np.subtract(p1, a)
+    p2a = np.subtract(p2, a)
+    cp1 = np.cross(ba, p1a)
+    cp2 = np.cross(ba, p2a)
+    test = np.dot(cp1, cp2)
+    if len(p1.shape) > 1:
+        return test >= 0
+    if test >= 0:
+        return True
+    return False

sphere_function(point, sphere)

Definition of a sphere function. Evaluate a point against a sphere function.

Parameters:

  • sphere – Sphere parameters, XYZ center and radius.
  • point – Point in XYZ.

Returns:

  • result (float) – Result of the evaluation. Zero if point is on sphere.

Source code in odak/raytracing/primitives.py
def sphere_function(point, sphere):
+    """
+    Definition of a sphere function. Evaluate a point against a sphere function.
+
+    Parameters
+    ----------
+    sphere     : ndarray
+                 Sphere parameters, XYZ center and radius.
+    point      : ndarray
+                 Point in XYZ.
+
+    Return
+    ----------
+    result     : float
+                 Result of the evaluation. Zero if point is on sphere.
+    """
+    point = np.asarray(point)
+    if len(point.shape) == 1:
+        point = point.reshape((1, 3))
+    result = (point[:, 0]-sphere[0])**2 + (point[:, 1]-sphere[1]
+                                           )**2 + (point[:, 2]-sphere[2])**2 - sphere[3]**2
+    return result
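
A quick check with assumed values (illustration only); the result comes back as a one-element array for a single point:

import numpy as np
from odak.raytracing.primitives import define_sphere, sphere_function

sphere = define_sphere([0., 0., 0.], 5.)
print(sphere_function([5., 0., 0.], sphere))  # [0.], the point lies on the sphere
print(sphere_function([0., 0., 0.], sphere))  # [-25.], negative inside the sphere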

get_cylinder_normal(point, cylinder)

Definition to get a surface normal ray for a given point on a cylinder.

Parameters:

  • point – Point on a cylinder defined in X,Y,Z.
  • cylinder – Cylinder parameters in the packed form produced by define_cylinder.

Returns:

  • normal_vector (ndarray) – Normal vector.

Source code in odak/raytracing/boundary.py
def get_cylinder_normal(point, cylinder):
+    """
+    Parameters
+    ----------
+    point         : ndarray
+                    Point on a cylinder defined in X,Y,Z.
+
+    Returns
+    ----------
+    normal_vector : ndarray
+                    Normal vector.
+    """
+    cylinder_ray = create_ray_from_two_points(cylinder[0:3], cylinder[4:7])
+    closest_point = closest_point_to_a_ray(
+        point,
+        cylinder_ray
+    )
+    normal_vector = create_ray_from_two_points(closest_point, point)
+    return normal_vector

get_sphere_normal(point, sphere)

Definition to get a normal of a point on a given sphere.

Parameters:

  • point – Point on sphere in X,Y,Z.
  • sphere – Center defined in X,Y,Z and radius.

Returns:

  • normal_vector (ndarray) – Normal vector.

Source code in odak/raytracing/boundary.py
def get_sphere_normal(point, sphere):
+    """
+    Definition to get a normal of a point on a given sphere.
+
+    Parameters
+    ----------
+    point         : ndarray
+                    Point on sphere in X,Y,Z.
+    sphere        : ndarray
+                    Center defined in X,Y,Z and radius.
+
+    Returns
+    ----------
+    normal_vector : ndarray
+                    Normal vector.
+    """
+    if len(point.shape) == 1:
+        point = point.reshape((1, 3))
+    normal_vector = create_ray_from_two_points(point, sphere[0:3])
+    return normal_vector

get_triangle_normal(triangle, triangle_center=None)

Definition to calculate surface normal of a triangle.

Parameters:

  • triangle – Set of points in X, Y and Z to define a planar surface (3, 3). It can also be a list of triangles (m x 3 x 3).
  • triangle_center (ndarray, default: None) – Center point of the given triangle. See odak.raytracing.center_of_triangle for more. In many scenarios you can accelerate things by precomputing triangle centers.

Returns:

  • normal (ndarray) – Surface normal at the point of intersection.

Source code in odak/raytracing/boundary.py
def get_triangle_normal(triangle, triangle_center=None):
+    """
+    Definition to calculate surface normal of a triangle.
+
+    Parameters
+    ----------
+    triangle        : ndarray
+                      Set of points in X,Y and Z to define a planar surface (3,3). It can also be list of triangles (mx3x3).
+    triangle_center : ndarray
+                      Center point of the given triangle. See odak.raytracing.center_of_triangle for more. In many scenarios you can accelerate things by precomputing triangle centers.
+
+    Returns
+    ----------
+    normal          : ndarray
+                      Surface normal at the point of intersection.
+    """
+    triangle = np.asarray(triangle)
+    if len(triangle.shape) == 2:
+        triangle = triangle.reshape((1, 3, 3))
+    normal = np.zeros((triangle.shape[0], 2, 3))
+    direction = np.cross(
+        triangle[:, 0]-triangle[:, 1], triangle[:, 2]-triangle[:, 1])
+    if type(triangle_center) == type(None):
+        normal[:, 0] = center_of_triangle(triangle)
+    else:
+        normal[:, 0] = triangle_center
+    normal[:, 1] = direction/np.sum(direction, axis=1)[0]
+    if normal.shape[0] == 1:
+        normal = normal.reshape((2, 3))
+    return normal

bring_plane_to_origin(point, plane, shape=[10.0, 10.0], center=[0.0, 0.0, 0.0], angles=[0.0, 0.0, 0.0], mode='XYZ')

Definition to bring points back to reference origin with respect to a plane.

Parameters:

  • point – Point(s) to be tested.
  • shape – Dimensions of the rectangle along X and Y axes.
  • center – Center of the rectangle.
  • angles – Rotation angle of the rectangle.
  • mode – Rotation mode of the rectangle, for more see odak.tools.rotate_point and odak.tools.rotate_points.

Returns:

  • transformed_points (ndarray) – Point(s) that are brought back to reference origin with respect to given plane.

Source code in odak/raytracing/primitives.py
def bring_plane_to_origin(point, plane, shape=[10., 10.], center=[0., 0., 0.], angles=[0., 0., 0.], mode='XYZ'):
+    """
+    Definition to bring points back to reference origin with respect to a plane.
+
+    Parameters
+    ----------
+    point              : ndarray
+                         Point(s) to be tested.
+    shape              : list
+                         Dimensions of the rectangle along X and Y axes.
+    center             : list
+                         Center of the rectangle.
+    angles             : list
+                         Rotation angle of the rectangle.
+    mode               : str
+                         Rotation mode of the rectangle, for more see odak.tools.rotate_point and odak.tools.rotate_points.
+
+    Returns
+    ----------
+    transformed_points : ndarray
+                         Point(s) that are brought back to reference origin with respect to given plane.
+    """
+    if point.shape[0] == 3:
+        point = point.reshape((1, 3))
+    reverse_mode = mode[::-1]
+    angles = [-angles[0], -angles[1], -angles[2]]
+    center = np.asarray(center).reshape((1, 3))
+    transformed_points = point-center
+    transformed_points = rotate_points(
+        transformed_points,
+        angles=angles,
+        mode=reverse_mode,
+    )
+    if transformed_points.shape[0] == 1:
+        transformed_points = transformed_points.reshape((3,))
+    return transformed_points

center_of_triangle(triangle)

Definition to calculate center of a triangle.

Parameters:

  • triangle – An array that contains three points defining a triangle (M x 3). It can also parallel process many triangles (N x M x 3).

Source code in odak/raytracing/primitives.py
def center_of_triangle(triangle):
+    """
+    Definition to calculate center of a triangle.
+
+    Parameters
+    ----------
+    triangle      : ndarray
+                    An array that contains three points defining a triangle (Mx3). It can also parallel process many triangles (NxMx3).
+    """
+    if len(triangle.shape) == 2:
+        triangle = triangle.reshape((1, 3, 3))
+    center = np.mean(triangle, axis=1)
+    return center

cylinder_function(point, cylinder)

Definition of a cylinder function. Evaluate a point against a cylinder function. Inspired from https://mathworld.wolfram.com/Point-LineDistance3-Dimensional.html

Parameters:

  • cylinder – Cylinder parameters, XYZ center and radius.
  • point – Point in XYZ.

Returns:

  • result (float) – Result of the evaluation. Zero if the point is on the cylinder.

Source code in odak/raytracing/primitives.py
def cylinder_function(point, cylinder):
+    """
+    Definition of a cylinder function. Evaluate a point against a cylinder function. Inspired from https://mathworld.wolfram.com/Point-LineDistance3-Dimensional.html
+
+    Parameters
+    ----------
+    cylinder   : ndarray
+                 Cylinder parameters, XYZ center and radius.
+    point      : ndarray
+                 Point in XYZ.
+
+    Return
+    ----------
+    result     : float
+                 Result of the evaluation. Zero if the point is on the cylinder.
+    """
+    point = np.asarray(point)
+    if len(point.shape) == 1:
+        point = point.reshape((1, 3))
+    distance = point_to_ray_distance(
+        point,
+        np.array([cylinder[0], cylinder[1], cylinder[2]], dtype=np.float64),
+        np.array([cylinder[4], cylinder[5], cylinder[6]], dtype=np.float64)
+    )
+    r = cylinder[3]
+    result = distance - r ** 2
+    return result

define_circle(center, radius, angles)

Definition to describe a circle in a single variable packed form.

Parameters:

  • center – Center of a circle to be defined.
  • radius – Radius of a circle to be defined.
  • angles – Angular tilt of a circle.

Returns:

  • circle (list) – Single variable packed form.

Source code in odak/raytracing/primitives.py
def define_circle(center, radius, angles):
+    """
+    Definition to describe a circle in a single variable packed form.
+
+    Parameters
+    ----------
+    center  : float
+              Center of a circle to be defined.
+    radius  : float
+              Radius of a circle to be defined.
+    angles  : float
+              Angular tilt of a circle.
+
+    Returns
+    ----------
+    circle  : list
+              Single variable packed form.
+    """
+    points = define_plane(center, angles=angles)
+    circle = [
+        points,
+        center,
+        radius
+    ]
+    return circle

define_cylinder(center, radius, rotation=[0.0, 0.0, 0.0])

Definition to define a cylinder.

Parameters:

  • center – Center of a cylinder in X,Y,Z.
  • radius – Radius of a cylinder along X axis.
  • rotation – Direction angles in degrees for the orientation of a cylinder.

Returns:

  • cylinder (ndarray) – Single variable packed form.

Source code in odak/raytracing/primitives.py
def define_cylinder(center, radius, rotation=[0., 0., 0.]):
+    """
+    Definition to define a cylinder
+
+    Parameters
+    ----------
+    center     : ndarray
+                 Center of a cylinder in X,Y,Z.
+    radius     : float
+                 Radius of a cylinder along X axis.
+    rotation   : list
+                 Direction angles in degrees for the orientation of a cylinder.
+
+    Returns
+    ----------
+    cylinder   : ndarray
+                 Single variable packed form.
+    """
+    cylinder_ray = create_ray_from_angles(
+        np.asarray(center), np.asarray(rotation))
+    cylinder = np.array(
+        [
+            center[0],
+            center[1],
+            center[2],
+            radius,
+            center[0]+cylinder_ray[1, 0],
+            center[1]+cylinder_ray[1, 1],
+            center[2]+cylinder_ray[1, 2]
+        ],
+        dtype=np.float64
+    )
+    return cylinder
+
+
+
+ +
+ +
+ + +

+ define_plane(point, angles=[0.0, 0.0, 0.0]) + +

+ + +
+ +

Definition to generate a rotation matrix along X axis.

+ + +

Parameters:

+
    +
  • + point + – +
    +
           A point that is at the center of a plane.
    +
    +
    +
  • +
  • + angles + – +
    +
           Rotation angles in degrees.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +plane ( ndarray +) – +
    +

    Points defining plane.

    +
    +
  • +
+ +
+ Source code in odak/raytracing/primitives.py +
def define_plane(point, angles=[0., 0., 0.]):
+    """ 
+    Definition to generate a rotation matrix along X axis.
+
+    Parameters
+    ----------
+    point        : ndarray
+                   A point that is at the center of a plane.
+    angles       : list
+                   Rotation angles in degrees.
+
+    Returns
+    ----------
+    plane        : ndarray
+                   Points defining plane.
+    """
+    plane = np.array([
+        [10., 10., 0.],
+        [0., 10., 0.],
+        [0.,  0., 0.]
+    ], dtype=np.float64)
+    point = np.asarray(point)
+    for i in range(0, plane.shape[0]):
+        plane[i], _, _, _ = rotate_point(plane[i], angles=angles)
+        plane[i] = plane[i]+point
+    return plane
+
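A brief sketch of define_plane, assuming it is importable from odak.raytracing; the returned array holds three points describing the tilted plane.

import odak.raytracing as raytracer

# Plane centred at z = 10, tilted 30 degrees about the X axis.
plane = raytracer.define_plane([0., 0., 10.], angles = [30., 0., 0.])
print(plane.shape)  # (3, 3): three points defining the plane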
+
+
+ +
+ +
+ + +

+ define_sphere(center, radius) + +

+ + +
+ +

Definition to define a sphere.

+ + +

Parameters:

+
    +
  • + center + – +
    +
         Center of a sphere in X,Y,Z.
    +
    +
    +
  • +
  • + radius + – +
    +
         Radius of a sphere.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +sphere ( ndarray +) – +
    +

    Single variable packed form.

    +
    +
  • +
+ +
+ Source code in odak/raytracing/primitives.py +
def define_sphere(center, radius):
+    """
+    Definition to define a sphere.
+
+    Parameters
+    ----------
+    center     : ndarray
+                 Center of a sphere in X,Y,Z.
+    radius     : float
+                 Radius of a sphere.
+
+    Returns
+    ----------
+    sphere     : ndarray
+                 Single variable packed form.
+    """
+    sphere = np.array(
+        [center[0], center[1], center[2], radius], dtype=np.float64)
+    return sphere
+
+
+
+ +
+ +
+ + +

+ is_it_on_triangle(pointtocheck, point0, point1, point2) + +

+ + +
+ +

Definition to check if a given point is inside a triangle. If the given point is inside a defined triangle, this definition returns True.

+ + +

Parameters:

+
    +
  • + pointtocheck + – +
    +
            Point to check.
    +
    +
    +
  • +
  • + point0 + – +
    +
            First point of a triangle.
    +
    +
    +
  • +
  • + point1 + – +
    +
            Second point of a triangle.
    +
    +
    +
  • +
  • + point2 + – +
    +
            Third point of a triangle.
    +
    +
    +
  • +
+ +
+ Source code in odak/raytracing/primitives.py +
def is_it_on_triangle(pointtocheck, point0, point1, point2):
+    """
+    Definition to check if a given point is inside a triangle. If the given point is inside a defined triangle, this definition returns True.
+
+    Parameters
+    ----------
+    pointtocheck  : list
+                    Point to check.
+    point0        : list
+                    First point of a triangle.
+    point1        : list
+                    Second point of a triangle.
+    point2        : list
+                    Third point of a triangle.
+
+    Returns
+    ----------
+    result        : bool
+                    True when the point lies inside the triangle, False otherwise.
+    """
+    # point0, point1 and point2 are the corners of the triangle.
+    pointtocheck = np.asarray(pointtocheck).reshape(3)
+    point0 = np.asarray(point0)
+    point1 = np.asarray(point1)
+    point2 = np.asarray(point2)
+    side0 = same_side(pointtocheck, point0, point1, point2)
+    side1 = same_side(pointtocheck, point1, point0, point2)
+    side2 = same_side(pointtocheck, point2, point0, point1)
+    if side0 == True and side1 == True and side2 == True:
+        return True
+    return False
+
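A quick check with is_it_on_triangle, assuming it is importable from odak.raytracing; the triangle corners are arbitrary illustration values.

import odak.raytracing as raytracer

triangle = [[10., 10., 0.], [0., 10., 0.], [0., 0., 0.]]
# A point inside the triangle returns True, a point outside returns False.
print(raytracer.is_it_on_triangle([3., 6., 0.], triangle[0], triangle[1], triangle[2]))
print(raytracer.is_it_on_triangle([20., 5., 0.], triangle[0], triangle[1], triangle[2]))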
+
+
+ +
+ +
+ + +

+ sphere_function(point, sphere) + +

+ + +
+ +

Definition of a sphere function. Evaluate a point against a sphere function.

+ + +

Parameters:

+
    +
  • + sphere + – +
    +
         Sphere parameters, XYZ center and radius.
    +
    +
    +
  • +
  • + point + – +
    +
         Point in XYZ.
    +
    +
    +
  • +
+ + +
Returns:

  • result ( float ) – Result of the evaluation. Zero if the point is on the sphere.

+
+
+ Source code in odak/raytracing/primitives.py +
def sphere_function(point, sphere):
+    """
+    Definition of a sphere function. Evaluate a point against a sphere function.
+
+    Parameters
+    ----------
+    sphere     : ndarray
+                 Sphere parameters, XYZ center and radius.
+    point      : ndarray
+                 Point in XYZ.
+
+    Returns
+    ----------
+    result     : float
+                 Result of the evaluation. Zero if point is on sphere.
+    """
+    point = np.asarray(point)
+    if len(point.shape) == 1:
+        point = point.reshape((1, 3))
+    result = (point[:, 0]-sphere[0])**2 + (point[:, 1]-sphere[1]
+                                           )**2 + (point[:, 2]-sphere[2])**2 - sphere[3]**2
+    return result
+
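A small sketch evaluating points against a sphere, assuming define_sphere and sphere_function are importable from odak.raytracing.

import numpy as np
import odak.raytracing as raytracer

sphere = raytracer.define_sphere([0., 0., 0.], 5.)
points = np.array([
                   [5., 0., 0.],   # on the surface, evaluates to zero
                   [0., 0., 10.]   # outside the sphere, evaluates to a positive value
                  ])
print(raytracer.sphere_function(points, sphere))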
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ calculate_intersection_of_two_rays(ray0, ray1) + +

+ + +
+ +

Definition to calculate the intersection of two rays.

+ + +

Parameters:

+
    +
  • + ray0 + – +
    +
         A ray.
    +
    +
    +
  • +
  • + ray1 + – +
    +
         A ray.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +point ( ndarray +) – +
    +

    Point in X,Y,Z.

    +
    +
  • +
  • +distances ( ndarray +) – +
    +

    Distances.

    +
    +
  • +
+ +
+ Source code in odak/raytracing/ray.py +
def calculate_intersection_of_two_rays(ray0, ray1):
+    """
+    Definition to calculate the intersection of two rays.
+
+    Parameters
+    ----------
+    ray0       : ndarray
+                 A ray.
+    ray1       : ndarray
+                 A ray.
+
+    Returns
+    ----------
+    point      : ndarray
+                 Point in X,Y,Z.
+    distances  : ndarray
+                 Distances.
+    """
+    A = np.array([
+        [float(ray0[1][0]), float(ray1[1][0])],
+        [float(ray0[1][1]), float(ray1[1][1])],
+        [float(ray0[1][2]), float(ray1[1][2])]
+    ])
+    B = np.array([
+        ray0[0][0]-ray1[0][0],
+        ray0[0][1]-ray1[0][1],
+        ray0[0][2]-ray1[0][2]
+    ])
+    distances = np.linalg.lstsq(A, B, rcond=None)[0]
+    if np.allclose(np.dot(A, distances), B) == False:
+        distances = np.array([0, 0])
+    distances = distances[np.argsort(-distances)]
+    point = propagate_a_ray(ray0, distances[0])[0]
+    return point, distances
+
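A sketch intersecting two rays that cross at the origin, assuming create_ray_from_two_points and calculate_intersection_of_two_rays are importable from odak.raytracing.

import odak.raytracing as raytracer

# Two rays in the XY plane that cross at the origin.
ray0 = raytracer.create_ray_from_two_points([-5., 0., 0.], [5., 0., 0.])
ray1 = raytracer.create_ray_from_two_points([0., -5., 0.], [0., 5., 0.])
point, distances = raytracer.calculate_intersection_of_two_rays(ray0, ray1)
print(point)  # expected to be close to the origin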
+
+
+ +
+ +
+ + +

+ create_ray(x0y0z0, abg) + +

+ + +
+ +

Definition to create a ray.

+ + +

Parameters:

+
    +
  • + x0y0z0 + – +
    +
           List that contains X,Y and Z start locations of a ray.
    +
    +
    +
  • +
  • + abg + – +
    +
           List that contains angles in degrees with respect to the X, Y and Z axes.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +ray ( ndarray +) – +
    +

    Array that contains starting points and cosines of a created ray.

    +
    +
  • +
+ +
+ Source code in odak/raytracing/ray.py +
def create_ray(x0y0z0, abg):
+    """
+    Definition to create a ray.
+
+    Parameters
+    ----------
+    x0y0z0       : list
+                   List that contains X,Y and Z start locations of a ray.
+    abg          : list
+                   List that contains angles in degrees with respect to the X, Y and Z axes.
+
+    Returns
+    ----------
+    ray          : ndarray
+                   Array that contains starting points and cosines of a created ray.
+    """
+    # Due to Python 2 -> Python 3.
+    x0, y0, z0 = x0y0z0
+    alpha, beta, gamma = abg
+    # Create a vector with the given points and angles in each direction
+    point = np.array([x0, y0, z0], dtype=np.float64)
+    alpha = np.cos(np.radians(alpha))
+    beta = np.cos(np.radians(beta))
+    gamma = np.cos(np.radians(gamma))
+    # Cosines vector.
+    cosines = np.array([alpha, beta, gamma], dtype=np.float64)
+    ray = np.array([point, cosines], dtype=np.float64)
+    return ray
+
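A minimal sketch for create_ray, assuming it is importable from odak.raytracing; angles of [90, 90, 0] degrees give direction cosines pointing along +Z.

import odak.raytracing as raytracer

ray = raytracer.create_ray([0., 0., 0.], [90., 90., 0.])
print(ray[0])  # start point
print(ray[1])  # direction cosines, approximately [0, 0, 1]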
+
+
+ +
+ +
+ + +

+ create_ray_from_angles(point, angles, mode='XYZ') + +

+ + +
+ +

Definition to create a ray from a point and angles.

+ + +

Parameters:

+
    +
  • + point + – +
    +
         Point in X,Y and Z.
    +
    +
    +
  • +
  • + angles + – +
    +
         Angles with X,Y,Z axes in degrees. All zeros point Z axis.
    +
    +
    +
  • +
  • + mode + – +
    +
         Rotation mode determines the ordering of the rotations at each axis. There are XYZ, YXZ, ZXY and ZYX modes.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +ray ( ndarray +) – +
    +

    Created ray.

    +
    +
  • +
+ +
+ Source code in odak/raytracing/ray.py +
def create_ray_from_angles(point, angles, mode='XYZ'):
+    """
+    Definition to create a ray from a point and angles.
+
+    Parameters
+    ----------
+    point      : ndarray
+                 Point in X,Y and Z.
+    angles     : ndarray
+                 Angles with X,Y,Z axes in degrees. All zeros point Z axis.
+    mode       : str
+                 Rotation mode determines the ordering of the rotations at each axis. There are XYZ, YXZ, ZXY and ZYX modes.
+
+    Returns
+    ----------
+    ray        : ndarray
+                 Created ray.
+    """
+    if len(point.shape) == 1:
+        point = point.reshape((1, 3))
+    new_point = np.zeros(point.shape)
+    new_point[:, 2] += 5.
+    new_point = rotate_points(new_point, angles, mode=mode, offset=point[:, 0])
+    ray = create_ray_from_two_points(point, new_point)
+    if ray.shape[0] == 1:
+        ray = ray.reshape((2, 3))
+    return ray
+
+
+
+ +
+ +
+ + +

+ create_ray_from_two_points(x0y0z0, x1y1z1) + +

+ + +
+ +

Definition to create a ray from two given points. Note that both inputs must match in shape.

+ + +

Parameters:

+
    +
  • + x0y0z0 + – +
    +
           List that contains X,Y and Z start locations of a ray (3). It can also be a list of points as well (mx3). This is the starting point.
    +
    +
    +
  • +
  • + x1y1z1 + – +
    +
           List that contains X,Y and Z ending locations of a ray (3). It can also be a list of points as well (mx3). This is the end point.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +ray ( ndarray +) – +
    +

    Array that contains starting points and cosines of a created ray.

    +
    +
  • +
+ +
+ Source code in odak/raytracing/ray.py +
def create_ray_from_two_points(x0y0z0, x1y1z1):
+    """
+    Definition to create a ray from two given points. Note that both inputs must match in shape.
+
+    Parameters
+    ----------
+    x0y0z0       : list
+                   List that contains X,Y and Z start locations of a ray (3). It can also be a list of points as well (mx3). This is the starting point.
+    x1y1z1       : list
+                   List that contains X,Y and Z ending locations of a ray (3). It can also be a list of points as well (mx3). This is the end point.
+
+    Returns
+    ----------
+    ray          : ndarray
+                   Array that contains starting points and cosines of a created ray.
+    """
+    x0y0z0 = np.asarray(x0y0z0, dtype=np.float64)
+    x1y1z1 = np.asarray(x1y1z1, dtype=np.float64)
+    if len(x0y0z0.shape) == 1:
+        x0y0z0 = x0y0z0.reshape((1, 3))
+    if len(x1y1z1.shape) == 1:
+        x1y1z1 = x1y1z1.reshape((1, 3))
+    xdiff = x1y1z1[:, 0] - x0y0z0[:, 0]
+    ydiff = x1y1z1[:, 1] - x0y0z0[:, 1]
+    zdiff = x1y1z1[:, 2] - x0y0z0[:, 2]
+    s = np.sqrt(xdiff ** 2 + ydiff ** 2 + zdiff ** 2)
+    s[s == 0] = np.nan
+    cosines = np.zeros((xdiff.shape[0], 3))
+    cosines[:, 0] = xdiff/s
+    cosines[:, 1] = ydiff/s
+    cosines[:, 2] = zdiff/s
+    ray = np.zeros((xdiff.shape[0], 2, 3), dtype=np.float64)
+    ray[:, 0] = x0y0z0
+    ray[:, 1] = cosines
+    if ray.shape[0] == 1:
+        ray = ray.reshape((2, 3))
+    return ray
+
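A sketch creating several rays at once with create_ray_from_two_points, assuming it is importable from odak.raytracing; passing mx3 arrays for both ends yields an mx2x3 ray array.

import numpy as np
import odak.raytracing as raytracer

starts = np.array([[0., 0., 0.], [1., 0., 0.]])
ends = np.array([[0., 0., 10.], [1., 0., 10.]])
rays = raytracer.create_ray_from_two_points(starts, ends)
print(rays.shape)  # (2, 2, 3): start point and direction cosines per ray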
+
+
+ +
+ +
+ + +

+ find_nearest_points(ray0, ray1) + +

+ + +
+ +

Find the nearest points on given rays with respect to the other ray.

+ + +

Parameters:

+
    +
  • + ray0 + – +
    +
         A ray.
    +
    +
    +
  • +
  • + ray1 + – +
    +
         A ray.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +c0 ( ndarray +) – +
    +

    Closest point on ray0.

    +
    +
  • +
  • +c1 ( ndarray +) – +
    +

    Closest point on ray1.

    +
    +
  • +
+ +
+ Source code in odak/raytracing/ray.py +
def find_nearest_points(ray0, ray1):
+    """
+    Find the nearest points on given rays with respect to the other ray.
+
+    Parameters
+    ----------
+    ray0       : ndarray
+                 A ray.
+    ray1       : ndarray
+                 A ray.
+
+    Returns
+    ----------
+    c0         : ndarray
+                 Closest point on ray0.
+    c1         : ndarray
+                 Closest point on ray1.
+    """
+    p0 = ray0[0].reshape(3,)
+    d0 = ray0[1].reshape(3,)
+    p1 = ray1[0].reshape(3,)
+    d1 = ray1[1].reshape(3,)
+    n = np.cross(d0, d1)
+    if np.all(n == 0):
+        point, distances = calculate_intersection_of_two_rays(ray0, ray1)
+        c0 = c1 = point
+    else:
+        n0 = np.cross(d0, n)
+        n1 = np.cross(d1, n)
+        c0 = p0+(np.dot((p1-p0), n1)/np.dot(d0, n1))*d0
+        c1 = p1+(np.dot((p0-p1), n0)/np.dot(d1, n0))*d1
+    return c0, c1
+
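A sketch for find_nearest_points with two skew rays, assuming the functions are importable from odak.raytracing.

import odak.raytracing as raytracer

ray0 = raytracer.create_ray_from_two_points([0., 0., 0.], [0., 0., 10.])  # along +Z
ray1 = raytracer.create_ray_from_two_points([0., 1., 5.], [10., 1., 5.])  # along +X, offset in Y
c0, c1 = raytracer.find_nearest_points(ray0, ray1)
print(c0, c1)  # one closest point per ray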
+
+
+ +
+ +
+ + +

+ propagate_a_ray(ray, distance) + +

+ + +
+ +

Definition to propagate a ray at a certain given distance.

+ + +

Parameters:

+
    +
  • + ray + – +
    +
         A ray.
    +
    +
    +
  • +
  • + distance + – +
    +
         Distance.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +new_ray ( ndarray +) – +
    +

    Propagated ray.

    +
    +
  • +
+ +
+ Source code in odak/raytracing/ray.py +
def propagate_a_ray(ray, distance):
+    """
+    Definition to propagate a ray at a certain given distance.
+
+    Parameters
+    ----------
+    ray        : ndarray
+                 A ray.
+    distance   : float
+                 Distance.
+
+    Returns
+    ----------
+    new_ray    : ndarray
+                 Propagated ray.
+    """
+    if len(ray.shape) == 2:
+        ray = ray.reshape((1, 2, 3))
+    new_ray = np.copy(ray)
+    new_ray[:, 0, 0] = distance*new_ray[:, 1, 0] + new_ray[:, 0, 0]
+    new_ray[:, 0, 1] = distance*new_ray[:, 1, 1] + new_ray[:, 0, 1]
+    new_ray[:, 0, 2] = distance*new_ray[:, 1, 2] + new_ray[:, 0, 2]
+    if new_ray.shape[0] == 1:
+        new_ray = new_ray.reshape((2, 3))
+    return new_ray
+
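A sketch propagating a ray with propagate_a_ray, assuming it is importable from odak.raytracing.

import odak.raytracing as raytracer

ray = raytracer.create_ray_from_two_points([0., 0., 0.], [0., 0., 1.])
moved_ray = raytracer.propagate_a_ray(ray, 10.)
print(moved_ray[0])  # new start point, roughly [0, 0, 10]
print(moved_ray[1])  # direction cosines are unchanged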
+
+
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/odak/tools/index.html b/odak/tools/index.html new file mode 100644 index 00000000..c82d8a7a --- /dev/null +++ b/odak/tools/index.html @@ -0,0 +1,18820 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + odak.tools - Odak + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + +
+ + + + + + +
+ + + + +
+ +
+ + + + +
+
+ + + +
+
+
+ + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

odak.tools

+ +
+ + + + +
+ +

odak.tools

+

Provides necessary definitions for general tools used across the library.

+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ latex + + +

+ + +
+ + +

A class to work with latex documents.

+ + + + + + +
+ Source code in odak/tools/latex.py +
class latex():
+    """
+    A class to work with latex documents.
+    """
+    def __init__(
+                 self,
+                 filename
+                ):
+        """
+        Parameters
+        ----------
+        filename     : str
+                       Source filename (i.e. sample.tex).
+        """
+        self.filename = filename
+        self.content = read_text_file(self.filename)
+        self.content_type = []
+        self.latex_dictionary = [
+                                 '\\documentclass',
+                                 '\\if',
+                                 '\\pdf',
+                                 '\\else',
+                                 '\\fi',
+                                 '\\vgtc',
+                                 '\\teaser',
+                                 '\\abstract',
+                                 '\\CCS',
+                                 '\\usepackage',
+                                 '\\PassOptionsToPackage',
+                                 '\\definecolor',
+                                 '\\AtBeginDocument',
+                                 '\\providecommand',
+                                 '\\setcopyright',
+                                 '\\copyrightyear',
+                                 '\\acmYear',
+                                 '\\citestyle',
+                                 '\\newcommand',
+                                 '\\acmDOI',
+                                 '\\newabbreviation',
+                                 '\\global',
+                                 '\\begin{document}',
+                                 '\\author',
+                                 '\\affiliation',
+                                 '\\email',
+                                 '\\institution',
+                                 '\\streetaddress',
+                                 '\\city',
+                                 '\\country',
+                                 '\\postcode',
+                                 '\\ccsdesc',
+                                 '\\received',
+                                 '\\includegraphics',
+                                 '\\caption',
+                                 '\\centering',
+                                 '\\label',
+                                 '\\maketitle',
+                                 '\\toprule',
+                                 '\\multirow',
+                                 '\\multicolumn',
+                                 '\\cmidrule',
+                                 '\\addlinespace',
+                                 '\\midrule',
+                                 '\\cellcolor',
+                                 '\\bibliography',
+                                 '}',
+                                 '\\title',
+                                 '</ccs2012>',
+                                 '\\bottomrule',
+                                 '<concept>',
+                                 '<concept',
+                                 '<ccs',
+                                 '\\item',
+                                 '</concept',
+                                 '\\begin{abstract}',
+                                 '\\end{abstract}',
+                                 '\\endinput',
+                                 '\\\\'
+                                ]
+        self.latex_begin_dictionary = [
+                                       '\\begin{figure}',
+                                       '\\begin{figure*}',
+                                       '\\begin{equation}',
+                                       '\\begin{CCSXML}',
+                                       '\\begin{teaserfigure}',
+                                       '\\begin{table*}',
+                                       '\\begin{table}',
+                                       '\\begin{gather}',
+                                       '\\begin{align}',
+                                      ]
+        self.latex_end_dictionary = [
+                                     '\\end{figure}',
+                                     '\\end{figure*}',
+                                     '\\end{equation}',
+                                     '\\end{CCSXML}',
+                                     '\\end{teaserfigure}',
+                                     '\\end{table*}',
+                                     '\\end{table}',
+                                     '\\end{gather}',
+                                     '\\end{align}',
+                                    ]
+        self._label_lines()
+
+
+    def set_latex_dictonaries(self, begin_dictionary, end_dictionary, syntax_dictionary):
+        """
+        Set document-specific dictionaries so that the lines can be labelled accordingly.
+
+
+        Parameters
+        ----------
+        begin_dictionary     : list
+                               Pythonic list containing latex syntax for begin commands (i.e. \\begin{align}).
+        end_dictionary       : list
+                               Pythonic list containing latex syntax for end commands (i.e. \\end{table}).
+        syntax_dictionary    : list
+                               Pythonic list containing latex syntax (i.e. \\item).
+
+        """
+        self.latex_begin_dictionary = begin_dictionary
+        self.latex_end_dictionary = end_dictionary
+        self.latex_dictionary = syntax_dictionary
+        self.content_type = []
+        self._label_lines()
+
+
+    def _label_lines(self):
+        """
+        Internal function for labelling lines.
+        """
+        content_type_flag = False
+        for line_id, line in enumerate(self.content):
+            while len(line) > 0 and line[0] == ' ':
+                 line = line[1::]
+            self.content[line_id] = line
+            if len(line) == 0:
+                content_type = 'empty'
+            elif line[0] == '%':
+                content_type = 'comment'
+            else:
+                content_type = 'text'
+            for syntax in self.latex_begin_dictionary:
+                if line.find(syntax) != -1:
+                    content_type_flag = True
+                    content_type = 'latex'
+            for syntax in self.latex_dictionary:
+                if line.find(syntax) != -1:
+                    content_type = 'latex'
+            if content_type_flag == True:
+                content_type = 'latex'
+                for syntax in self.latex_end_dictionary:
+                    if line.find(syntax) != -1:
+                         content_type_flag = False
+            self.content_type.append(content_type)
+
+
+    def get_line_count(self):
+        """
+        Definition to get the line count.
+
+
+        Returns
+        -------
+        line_count     : int
+                         Number of lines in the loaded latex document.
+        """
+        self.line_count = len(self.content)
+        return self.line_count
+
+
+    def get_line(self, line_id = 0):
+        """
+        Definition to get a specific line by inputting a line number.
+
+
+        Returns
+        ----------
+        line           : str
+                         Requested line.
+        content_type   : str
+                         Line's content type (e.g., latex, comment, text).
+        """
+        line = self.content[line_id]
+        content_type = self.content_type[line_id]
+        return line, content_type
+
+
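A usage sketch for the latex class, assuming it is importable from odak.tools as documented here; 'manuscript.tex' is a hypothetical filename used only for illustration.

from odak.tools import latex

document = latex('manuscript.tex')  # hypothetical input file
print(document.get_line_count())
line, content_type = document.get_line(0)
print(content_type, line)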
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(filename) + +

+ + +
+ + + +

Parameters:

+
    +
  • + filename + – +
    +
           Source filename (i.e. sample.tex).
    +
    +
    +
  • +
+ +
+ Source code in odak/tools/latex.py +
def __init__(
+             self,
+             filename
+            ):
+    """
+    Parameters
+    ----------
+    filename     : str
+                   Source filename (i.e. sample.tex).
+    """
+    self.filename = filename
+    self.content = read_text_file(self.filename)
+    self.content_type = []
+    self.latex_dictionary = [
+                             '\\documentclass',
+                             '\\if',
+                             '\\pdf',
+                             '\\else',
+                             '\\fi',
+                             '\\vgtc',
+                             '\\teaser',
+                             '\\abstract',
+                             '\\CCS',
+                             '\\usepackage',
+                             '\\PassOptionsToPackage',
+                             '\\definecolor',
+                             '\\AtBeginDocument',
+                             '\\providecommand',
+                             '\\setcopyright',
+                             '\\copyrightyear',
+                             '\\acmYear',
+                             '\\citestyle',
+                             '\\newcommand',
+                             '\\acmDOI',
+                             '\\newabbreviation',
+                             '\\global',
+                             '\\begin{document}',
+                             '\\author',
+                             '\\affiliation',
+                             '\\email',
+                             '\\institution',
+                             '\\streetaddress',
+                             '\\city',
+                             '\\country',
+                             '\\postcode',
+                             '\\ccsdesc',
+                             '\\received',
+                             '\\includegraphics',
+                             '\\caption',
+                             '\\centering',
+                             '\\label',
+                             '\\maketitle',
+                             '\\toprule',
+                             '\\multirow',
+                             '\\multicolumn',
+                             '\\cmidrule',
+                             '\\addlinespace',
+                             '\\midrule',
+                             '\\cellcolor',
+                             '\\bibliography',
+                             '}',
+                             '\\title',
+                             '</ccs2012>',
+                             '\\bottomrule',
+                             '<concept>',
+                             '<concept',
+                             '<ccs',
+                             '\\item',
+                             '</concept',
+                             '\\begin{abstract}',
+                             '\\end{abstract}',
+                             '\\endinput',
+                             '\\\\'
+                            ]
+    self.latex_begin_dictionary = [
+                                   '\\begin{figure}',
+                                   '\\begin{figure*}',
+                                   '\\begin{equation}',
+                                   '\\begin{CCSXML}',
+                                   '\\begin{teaserfigure}',
+                                   '\\begin{table*}',
+                                   '\\begin{table}',
+                                   '\\begin{gather}',
+                                   '\\begin{align}',
+                                  ]
+    self.latex_end_dictionary = [
+                                 '\\end{figure}',
+                                 '\\end{figure*}',
+                                 '\\end{equation}',
+                                 '\\end{CCSXML}',
+                                 '\\end{teaserfigure}',
+                                 '\\end{table*}',
+                                 '\\end{table}',
+                                 '\\end{gather}',
+                                 '\\end{align}',
+                                ]
+    self._label_lines()
+
+
+
+ +
+ +
+ + +

+ get_line(line_id=0) + +

+ + +
+ +

Definition to get a specific line by inputting a line number.

+ + +

Returns:

+
    +
  • +line ( str +) – +
    +

    Requested line.

    +
    +
  • +
  • +content_type ( str +) – +
    +

    Line's content type (e.g., latex, comment, text).

    +
    +
  • +
+ +
+ Source code in odak/tools/latex.py +
def get_line(self, line_id = 0):
+    """
+    Definition to get a specific line by inputting a line number.
+
+
+    Returns
+    ----------
+    line           : str
+                     Requested line.
+    content_type   : str
+                     Line's content type (e.g., latex, comment, text).
+    """
+    line = self.content[line_id]
+    content_type = self.content_type[line_id]
+    return line, content_type
+
+
+
+ +
+ +
+ + +

+ get_line_count() + +

+ + +
+ +

Definition to get the line count.

+ + +

Returns:

+
    +
  • +line_count ( int +) – +
    +

    Number of lines in the loaded latex document.

    +
    +
  • +
+ +
+ Source code in odak/tools/latex.py +
def get_line_count(self):
+    """
+    Definition to get the line count.
+
+
+    Returns
+    -------
+    line_count     : int
+                     Number of lines in the loaded latex document.
+    """
+    self.line_count = len(self.content)
+    return self.line_count
+
+
+
+ +
+ +
+ + +

+ set_latex_dictonaries(begin_dictionary, end_dictionary, syntax_dictionary) + +

+ + +
+ +

Set document-specific dictionaries so that the lines can be labelled accordingly.

+ + +

Parameters:

+
    +
  • + begin_dictionary + – +
    +
                   Pythonic list containing latex syntax for begin commands (i.e. \begin{align}).
    +
    +
    +
  • +
  • + end_dictionary + – +
    +
                   Pythonic list containing latex syntax for end commands (i.e. \end{table}).
    +
    +
    +
  • +
  • + syntax_dictionary + – +
    +
                   Pythonic list containing latex syntax (i.e. \item).
    +
    +
    +
  • +
+ +
+ Source code in odak/tools/latex.py +
def set_latex_dictonaries(self, begin_dictionary, end_dictionary, syntax_dictionary):
+    """
+    Set document-specific dictionaries so that the lines can be labelled accordingly.
+
+
+    Parameters
+    ----------
+    begin_dictionary     : list
+                           Pythonic list containing latex syntax for begin commands (i.e. \\begin{align}).
+    end_dictionary       : list
+                           Pythonic list containing latex syntax for end commands (i.e. \\end{table}).
+    syntax_dictionary    : list
+                           Pythonic list containing latex syntax (i.e. \\item).
+
+    """
+    self.latex_begin_dictionary = begin_dictionary
+    self.latex_end_dictionary = end_dictionary
+    self.latex_dictionary = syntax_dictionary
+    self.content_type = []
+    self._label_lines()
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ markdown + + +

+ + +
+ + +

A class to work with markdown documents.

+ + + + + + +
+ Source code in odak/tools/markdown.py +
class markdown():
+    """
+    A class to work with markdown documents.
+    """
+    def __init__(
+                 self,
+                 filename
+                ):
+        """
+        Parameters
+        ----------
+        filename     : str
+                       Source filename (i.e. sample.md).
+        """
+        self.filename = filename
+        self.content = read_text_file(self.filename)
+        self.content_type = []
+        self.markdown_dictionary = [
+                                     '#',
+                                   ]
+        self.markdown_begin_dictionary = [
+                                          '```bash',
+                                          '```python',
+                                          '```',
+                                         ]
+        self.markdown_end_dictionary = [
+                                        '```',
+                                       ]
+        self._label_lines()
+
+
+    def set_dictonaries(self, begin_dictionary, end_dictionary, syntax_dictionary):
+        """
+        Set document-specific dictionaries so that the lines can be labelled accordingly.
+
+
+        Parameters
+        ----------
+        begin_dictionary     : list
+                               Pythonic list containing markdown syntax for beginning of blocks (e.g., code, html).
+        end_dictionary       : list
+                               Pythonic list containing markdown syntax for end of blocks (e.g., code, html).
+        syntax_dictionary    : list
+                               Pythonic list containing markdown syntax (i.e. \\item).
+
+        """
+        self.markdown_begin_dictionary = begin_dictionary
+        self.markdown_end_dictionary = end_dictionary
+        self.markdown_dictionary = syntax_dictionary
+        self.content_type = []
+        self._label_lines()
+
+
+    def _label_lines(self):
+        """
+        Internal function for labelling lines.
+        """
+        content_type_flag = False
+        for line_id, line in enumerate(self.content):
+            while len(line) > 0 and line[0] == ' ':
+                 line = line[1::]
+            self.content[line_id] = line
+            if len(line) == 0:
+                content_type = 'empty'
+            elif line[0] == '%':
+                content_type = 'comment'
+            else:
+                content_type = 'text'
+            for syntax in self.markdown_begin_dictionary:
+                if line.find(syntax) != -1:
+                    content_type_flag = True
+                    content_type = 'markdown'
+            for syntax in self.markdown_dictionary:
+                if line.find(syntax) != -1:
+                    content_type = 'markdown'
+            if content_type_flag == True:
+                content_type = 'markdown'
+                for syntax in self.markdown_end_dictionary:
+                    if line.find(syntax) != -1:
+                         content_type_flag = False
+            self.content_type.append(content_type)
+
+
+    def get_line_count(self):
+        """
+        Definition to get the line count.
+
+
+        Returns
+        -------
+        line_count     : int
+                         Number of lines in the loaded markdown document.
+        """
+        self.line_count = len(self.content)
+        return self.line_count
+
+
+    def get_line(self, line_id = 0):
+        """
+        Definition to get a specific line by inputting a line number.
+
+
+        Returns
+        ----------
+        line           : str
+                         Requested line.
+        content_type   : str
+                         Line's content type (e.g., markdown, comment, text).
+        """
+        line = self.content[line_id]
+        content_type = self.content_type[line_id]
+        return line, content_type
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(filename) + +

+ + +
+ + + +

Parameters:

+
    +
  • + filename + – +
    +
           Source filename (i.e. sample.md).
    +
    +
    +
  • +
+ +
+ Source code in odak/tools/markdown.py +
def __init__(
+             self,
+             filename
+            ):
+    """
+    Parameters
+    ----------
+    filename     : str
+                   Source filename (i.e. sample.md).
+    """
+    self.filename = filename
+    self.content = read_text_file(self.filename)
+    self.content_type = []
+    self.markdown_dictionary = [
+                                 '#',
+                               ]
+    self.markdown_begin_dictionary = [
+                                      '```bash',
+                                      '```python',
+                                      '```',
+                                     ]
+    self.markdown_end_dictionary = [
+                                    '```',
+                                   ]
+    self._label_lines()
+
+
+
+ +
+ +
+ + +

+ get_line(line_id=0) + +

+ + +
+ +

Definition to get a specific line by inputting a line number.

+ + +

Returns:

+
    +
  • +line ( str +) – +
    +

    Requested line.

    +
    +
  • +
  • +content_type ( str +) – +
    +

    Line's content type (e.g., markdown, comment, text).

    +
    +
  • +
+ +
+ Source code in odak/tools/markdown.py +
def get_line(self, line_id = 0):
+    """
+    Definition to get a specific line by inputting a line number.
+
+
+    Returns
+    ----------
+    line           : str
+                     Requested line.
+    content_type   : str
+                     Line's content type (e.g., markdown, comment, text).
+    """
+    line = self.content[line_id]
+    content_type = self.content_type[line_id]
+    return line, content_type
+
+
+
+ +
+ +
+ + +

+ get_line_count() + +

+ + +
+ +

Definition to get the line count.

+ + +

Returns:

+
    +
  • +line_count ( int +) – +
    +

    Number of lines in the loaded markdown document.

    +
    +
  • +
+ +
+ Source code in odak/tools/markdown.py +
def get_line_count(self):
+    """
+    Definition to get the line count.
+
+
+    Returns
+    -------
+    line_count     : int
+                     Number of lines in the loaded markdown document.
+    """
+    self.line_count = len(self.content)
+    return self.line_count
+
+
+
+ +
+ +
+ + +

+ set_dictonaries(begin_dictionary, end_dictionary, syntax_dictionary) + +

+ + +
+ +

Set document-specific dictionaries so that the lines can be labelled accordingly.

+ + +

Parameters:

+
    +
  • + begin_dictionary + – +
    +
                   Pythonic list containing markdown syntax for beginning of blocks (e.g., code, html).
    +
    +
    +
  • +
  • + end_dictionary + – +
    +
                   Pythonic list containing markdown syntax for end of blocks (e.g., code, html).
    +
    +
    +
  • +
  • + syntax_dictionary + – +
    +
                   Pythonic list containing markdown syntax (i.e. \item).
    +
    +
    +
  • +
+ +
+ Source code in odak/tools/markdown.py +
def set_dictonaries(self, begin_dictionary, end_dictionary, syntax_dictionary):
+    """
+    Set document-specific dictionaries so that the lines can be labelled accordingly.
+
+
+    Parameters
+    ----------
+    begin_dictionary     : list
+                           Pythonic list containing markdown syntax for beginning of blocks (e.g., code, html).
+    end_dictionary       : list
+                           Pythonic list containing markdown syntax for end of blocks (e.g., code, html).
+    syntax_dictionary    : list
+                           Pythonic list containing markdown syntax (i.e. \\item).
+
+    """
+    self.markdown_begin_dictionary = begin_dictionary
+    self.markdown_end_dictionary = end_dictionary
+    self.markdown_dictionary = syntax_dictionary
+    self.content_type = []
+    self._label_lines()
+
+
+
+ +
+ + + +
+ +
+ +
+ + +
+ + +

+ batch_of_rays(entry, exit) + +

+ + +
+ +

Definition to generate a batch of rays with given entry point(s) and exit point(s). Note that the mapping is one to one, meaning the nth item in your entry points list will exit from the nth item in your exit list and generate that particular ray. You can combine nx3 points for entry or exit with a single point for the other, but if both entry and exit contain multiple points, the number of points has to be the same for both.

+ + +

Parameters:

+
    +
  • + entry + – +
    +
         Either a single point with size of 3 or multiple points with the size of nx3.
    +
    +
    +
  • +
  • + exit + – +
    +
         Either a single point with size of 3 or multiple points with the size of nx3.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +rays ( ndarray +) – +
    +

    Generated batch of rays.

    +
    +
  • +
+ +
+ Source code in odak/tools/sample.py +
def batch_of_rays(entry, exit):
+    """
+    Definition to generate a batch of rays with given entry point(s) and exit point(s). Note that the mapping is one to one, meaning the nth item in your entry points list will exit from the nth item in your exit list and generate that particular ray. You can combine nx3 points for entry or exit with a single point for the other, but if both entry and exit contain multiple points, the number of points has to be the same for both.
+
+    Parameters
+    ----------
+    entry      : ndarray
+                 Either a single point with size of 3 or multiple points with the size of nx3.
+    exit       : ndarray
+                 Either a single point with size of 3 or multiple points with the size of nx3.
+
+    Returns
+    ----------
+    rays       : ndarray
+                 Generated batch of rays.
+    """
+    norays = np.array([0, 0])
+    if len(entry.shape) == 1:
+        entry = entry.reshape((1, 3))
+    if len(exit.shape) == 1:
+        exit = exit.reshape((1, 3))
+    norays = np.amax(np.asarray([entry.shape[0], exit.shape[0]]))
+    if norays > exit.shape[0]:
+        exit = np.repeat(exit, norays, axis=0)
+    elif norays > entry.shape[0]:
+        entry = np.repeat(entry, norays, axis=0)
+    rays = []
+    norays = int(norays)
+    for i in range(norays):
+        rays.append(
+            create_ray_from_two_points(
+                entry[i],
+                exit[i]
+            )
+        )
+    rays = np.asarray(rays)
+    return rays
+
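A sketch fanning one entry point out to several exit points with batch_of_rays, assuming it is importable from odak.tools.

import numpy as np
from odak.tools import batch_of_rays

entry = np.array([0., 0., 0.])
exits = np.array([
                  [-1., -1., 10.],
                  [-1.,  1., 10.],
                  [ 1., -1., 10.],
                  [ 1.,  1., 10.]
                 ])
rays = batch_of_rays(entry, exits)
print(rays.shape)  # (4, 2, 3): one ray per exit point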
+
+
+ +
+ +
+ + +

+ blur_gaussian(field, kernel_length=[21, 21], nsigma=[3, 3]) + +

+ + +
+ +

A definition to blur a field using a Gaussian kernel.

+ + +

Parameters:

+
    +
  • + field + – +
    +
            MxN field.
    +
    +
    +
  • +
  • + kernel_length + (list, default: + [21, 21] +) + – +
    +
            Length of the Gaussian kernel along X and Y axes.
    +
    +
    +
  • +
  • + nsigma + – +
    +
            Sigma of the Gaussian kernel along X and Y axes.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +blurred_field ( ndarray +) – +
    +

    Blurred field.

    +
    +
  • +
+ +
+ Source code in odak/tools/matrix.py +
def blur_gaussian(field, kernel_length=[21, 21], nsigma=[3, 3]):
+    """
+    A definition to blur a field using a Gaussian kernel.
+
+    Parameters
+    ----------
+    field         : ndarray
+                    MxN field.
+    kernel_length : list
+                    Length of the Gaussian kernel along X and Y axes.
+    nsigma        : list
+                    Sigma of the Gaussian kernel along X and Y axes.
+
+    Returns
+    ----------
+    blurred_field : ndarray
+                    Blurred field.
+    """
+    kernel = generate_2d_gaussian(kernel_length, nsigma)
+    kernel = zero_pad(kernel, field.shape)
+    blurred_field = convolve2d(field, kernel)
+    blurred_field = blurred_field/np.amax(blurred_field)
+    return blurred_field
+
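A sketch blurring a delta-like field with blur_gaussian, assuming it is importable from odak.tools; the kernel length and sigma below simply spell out the defaults.

import numpy as np
from odak.tools import blur_gaussian

field = np.zeros((128, 128))
field[64, 64] = 1.
blurred = blur_gaussian(field, kernel_length = [21, 21], nsigma = [3, 3])
print(blurred.shape, float(blurred.max()))  # same shape, normalised to a peak of 1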
+
+
+ +
+ +
+ + +

+ box_volume_sample(no=[10, 10, 10], size=[100.0, 100.0, 100.0], center=[0.0, 0.0, 0.0], angles=[0.0, 0.0, 0.0]) + +

+ + +
+ +

Definition to generate samples in a box volume.

+ + +

Parameters:

+
    +
  • + no + – +
    +
          Number of samples.
    +
    +
    +
  • +
  • + size + – +
    +
          Physical size of the volume.
    +
    +
    +
  • +
  • + center + – +
    +
          Center location of the volume.
    +
    +
    +
  • +
  • + angles + – +
    +
          Tilt of the volume.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +samples ( ndarray +) – +
    +

    Samples generated.

    +
    +
  • +
+ +
+ Source code in odak/tools/sample.py +
def box_volume_sample(no=[10, 10, 10], size=[100., 100., 100.], center=[0., 0., 0.], angles=[0., 0., 0.]):
+    """
+    Definition to generate samples in a box volume.
+
+    Parameters
+    ----------
+    no          : list
+                  Number of samples.
+    size        : list
+                  Physical size of the volume.
+    center      : list
+                  Center location of the volume.
+    angles      : list
+                  Tilt of the volume.
+
+    Returns
+    ----------
+    samples     : ndarray
+                  Samples generated.
+    """
+    samples = np.zeros((no[0], no[1], no[2], 3))
+    x, y, z = np.mgrid[0:no[0], 0:no[1], 0:no[2]]
+    step = [
+        size[0]/no[0],
+        size[1]/no[1],
+        size[2]/no[2]
+    ]
+    samples[:, :, :, 0] = x*step[0]+step[0]/2.-size[0]/2.
+    samples[:, :, :, 1] = y*step[1]+step[1]/2.-size[1]/2.
+    samples[:, :, :, 2] = z*step[2]+step[2]/2.-size[2]/2.
+    samples = samples.reshape(
+        (samples.shape[0]*samples.shape[1]*samples.shape[2], samples.shape[3]))
+    samples = rotate_points(samples, angles=angles, offset=center)
+    return samples
+
+
+
+ +
+ +
+ + +

+ check_directory(directory) + +

+ + +
+ +

Definition to check if a directory exists. If it doesn't exist, this definition will create one.

+ + +

Parameters:

+
    +
  • + directory + – +
    +
            Full directory path.
    +
    +
    +
  • +
+ +
+ Source code in odak/tools/file.py +
def check_directory(directory):
+    """
+    Definition to check if a directory exists. If it doesn't exist, this definition will create one.
+
+
+    Parameters
+    ----------
+    directory     : str
+                    Full directory path.
+    """
+    if not os.path.exists(expanduser(directory)):
+        os.makedirs(expanduser(directory))
+        return False
+    return True
+
+
+
+ +
+ +
+ + +

+ circular_sample(no=[10, 10], radius=10.0, center=[0.0, 0.0, 0.0], angles=[0.0, 0.0, 0.0]) + +

+ + +
+ +

Definition to generate samples inside a circle over a surface.

+ + +

Parameters:

+
    +
  • + no + – +
    +
          Number of samples.
    +
    +
    +
  • +
  • + radius + – +
    +
          Radius of the circle.
    +
    +
    +
  • +
  • + center + – +
    +
          Center location of the surface.
    +
    +
    +
  • +
  • + angles + – +
    +
          Tilt of the surface.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +samples ( ndarray +) – +
    +

    Samples generated.

    +
    +
  • +
+ +
+ Source code in odak/tools/sample.py +
def circular_sample(no=[10, 10], radius=10., center=[0., 0., 0.], angles=[0., 0., 0.]):
+    """
+    Definition to generate samples inside a circle over a surface.
+
+    Parameters
+    ----------
+    no          : list
+                  Number of samples.
+    radius      : float
+                  Radius of the circle.
+    center      : list
+                  Center location of the surface.
+    angles      : list
+                  Tilt of the surface.
+
+    Returns
+    ----------
+    samples     : ndarray
+                  Samples generated.
+    """
+    samples = np.zeros((no[0]+1, no[1]+1, 3))
+    r_angles, r = np.mgrid[0:no[0]+1, 0:no[1]+1]
+    r = r/np.amax(r)*radius
+    r_angles = r_angles/np.amax(r_angles)*np.pi*2
+    samples[:, :, 0] = r*np.cos(r_angles)
+    samples[:, :, 1] = r*np.sin(r_angles)
+    samples = samples[1:no[0]+1, 1:no[1]+1, :]
+    samples = samples.reshape(
+        (samples.shape[0]*samples.shape[1], samples.shape[2]))
+    samples = rotate_points(samples, angles=angles, offset=center)
+    return samples
+
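A sketch generating samples on a tilted circular patch with circular_sample, assuming it is importable from odak.tools.

from odak.tools import circular_sample

samples = circular_sample(no = [10, 10], radius = 5., center = [0., 0., 0.], angles = [45., 0., 0.])
print(samples.shape)  # (100, 3): ten radial steps times ten angular steps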
+
+
+ +
+ +
+ + +

+ circular_uniform_random_sample(no=[10, 50], radius=10.0, center=[0.0, 0.0, 0.0], angles=[0.0, 0.0, 0.0]) + +

+ + +
+ +

Definition to generate samples inside a circle uniformly but randomly.

+ + +

Parameters:

+
    +
  • + no + – +
    +
          Number of samples.
    +
    +
    +
  • +
  • + radius + – +
    +
          Radius of the circle.
    +
    +
    +
  • +
  • + center + – +
    +
          Center location of the surface.
    +
    +
    +
  • +
  • + angles + – +
    +
          Tilt of the surface.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +samples ( ndarray +) – +
    +

    Samples generated.

    +
    +
  • +
+ +
+ Source code in odak/tools/sample.py +
def circular_uniform_random_sample(no=[10, 50], radius=10., center=[0., 0., 0.], angles=[0., 0., 0.]):
+    """ 
+    Definition to generate samples inside a circle uniformly but randomly.
+
+    Parameters
+    ----------
+    no          : list
+                  Number of samples.
+    radius      : float
+                  Radius of the circle.
+    center      : list
+                  Center location of the surface.
+    angles      : list
+                  Tilt of the surface.
+
+    Returns
+    ----------
+    samples     : ndarray
+                  Samples generated.
+    """
+    samples = np.empty((0, 3))
+    rs = np.sqrt(np.random.uniform(0, 1, no[0]))
+    angs = np.random.uniform(0, 2*np.pi, no[1])
+    for i in rs:
+        for angle in angs:
+            r = radius*i
+            point = np.array(
+                [float(r*np.cos(angle)), float(r*np.sin(angle)), 0])
+            samples = np.vstack((samples, point))
+    samples = rotate_points(samples, angles=angles, offset=center)
+    return samples
+
+
+
+ +
+ +
+ + +

+ circular_uniform_sample(no=[10, 50], radius=10.0, center=[0.0, 0.0, 0.0], angles=[0.0, 0.0, 0.0]) + +

+ + +
+ +

Definition to generate samples inside a circle uniformly.

+ + +

Parameters:

+
    +
  • + no + – +
    +
          Number of samples.
    +
    +
    +
  • +
  • + radius + – +
    +
          Radius of the circle.
    +
    +
    +
  • +
  • + center + – +
    +
          Center location of the surface.
    +
    +
    +
  • +
  • + angles + – +
    +
          Tilt of the surface.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +samples ( ndarray +) – +
    +

    Samples generated.

    +
    +
  • +
+ +
+ Source code in odak/tools/sample.py +
def circular_uniform_sample(no=[10, 50], radius=10., center=[0., 0., 0.], angles=[0., 0., 0.]):
+    """
+    Definition to generate samples inside a circle uniformly.
+
+    Parameters
+    ----------
+    no          : list
+                  Number of samples.
+    radius      : float
+                  Radius of the circle.
+    center      : list
+                  Center location of the surface.
+    angles      : list
+                  Tilt of the surface.
+
+    Returns
+    ----------
+    samples     : ndarray
+                  Samples generated.
+    """
+    samples = np.empty((0, 3))
+    for i in range(0, no[0]):
+        r = i/no[0]*radius
+        ang_no = no[1]*i/no[0]
+        for j in range(0, int(no[1]*i/no[0])):
+            angle = j/ang_no*2*np.pi
+            point = np.array(
+                [float(r*np.cos(angle)), float(r*np.sin(angle)), 0])
+            samples = np.vstack((samples, point))
+    samples = rotate_points(samples, angles=angles, offset=center)
+    return samples
+
+
+
+ +
+ +
+ + +

+ closest_point_to_a_ray(point, ray) + +

+ + +
+ +

Definition to calculate the point on a ray that is closest to a given point.

+ + +

Parameters:

+
    +
  • + point + – +
    +
            Given point in X,Y,Z.
    +
    +
    +
  • +
  • + ray + – +
    +
            Given ray.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +closest_point ( ndarray +) – +
    +

    Calculated closest point.

    +
    +
  • +
+ +
+ Source code in odak/tools/vector.py +
def closest_point_to_a_ray(point, ray):
+    """
+    Definition to calculate the point on a ray that is closest to a given point.
+
+    Parameters
+    ----------
+    point         : list
+                    Given point in X,Y,Z.
+    ray           : ndarray
+                    Given ray.
+
+    Returns
+    ---------
+    closest_point : ndarray
+                    Calculated closest point.
+    """
+    from odak.raytracing import propagate_a_ray
+    if len(ray.shape) == 2:
+        ray = ray.reshape((1, 2, 3))
+    p0 = ray[:, 0]
+    p1 = propagate_a_ray(ray, 1.)
+    if len(p1.shape) == 2:
+        p1 = p1.reshape((1, 2, 3))
+    p1 = p1[:, 0]
+    p1 = p1.reshape(3)
+    p0 = p0.reshape(3)
+    point = point.reshape(3)
+    closest_distance = -np.dot((p0-point), (p1-p0))/np.sum((p1-p0)**2)
+    closest_point = propagate_a_ray(ray, closest_distance)[0]
+    return closest_point
+
+
+
+ +
+ +
+ + +

+ convert_bytes(num) + +

+ + +
+ +

A definition to convert bytes to a human-readable unit (KB, MB, GB or similar). Inspired by https://stackoverflow.com/questions/2104080/how-can-i-check-file-size-in-python#2104083.

+ + +

Parameters:

+
    +
  • + num + – +
    +
         Size in bytes
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +num ( float +) – +
    +

    Size in new unit.

    +
    +
  • +
  • +x ( str +) – +
    +

    New unit bytes, KB, MB, GB or TB.

    +
    +
  • +
+ +
+ Source code in odak/tools/file.py +
def convert_bytes(num):
+    """
+    A definition to convert bytes to a human-readable unit (KB, MB, GB or similar). Inspired by https://stackoverflow.com/questions/2104080/how-can-i-check-file-size-in-python#2104083.
+
+
+    Parameters
+    ----------
+    num        : float
+                 Size in bytes
+
+
+    Returns
+    ----------
+    num        : float
+                 Size in new unit.
+    x          : str
+                 New unit bytes, KB, MB, GB or TB.
+    """
+    for x in ['bytes', 'KB', 'MB', 'GB', 'TB']:
+        if num < 1024.0:
+            return num, x
+        num /= 1024.0
+    return None, None
+
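A sketch for convert_bytes, assuming it is importable from odak.tools.

from odak.tools import convert_bytes

size, unit = convert_bytes(3 * 1024 * 1024)
print(size, unit)  # 3.0 MB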
+
+
+ +
+ +
+ + +

+ convert_to_numpy(a) + +

+ + +
+ +

A definition to convert Torch to Numpy.

+ + +

Parameters:

+
    +
  • + a + – +
    +
         Input Torch array.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +b ( ndarray +) – +
    +

    Converted array.

    +
    +
  • +
+ +
+ Source code in odak/tools/conversions.py +
def convert_to_numpy(a):
+    """
+    A definition to convert Torch to Numpy.
+
+    Parameters
+    ----------
+    a          : torch.Tensor
+                 Input Torch array.
+
+    Returns
+    ----------
+    b          : numpy.ndarray
+                 Converted array.
+    """
+    b = a.to('cpu').detach().numpy()
+    return b
+
+
+
+ +
+ +
+ + +

+ convert_to_torch(a, grad=True) + +

+ + +
+ +

A definition to convert Numpy arrays to Torch.

+ + +

Parameters:

+
    +
  • + a + – +
    +
         Input Numpy array.
    +
    +
    +
  • +
  • + grad + – +
    +
         Set if the converted array requires gradient.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +c ( Tensor +) – +
    +

    Converted array.

    +
    +
  • +
+ +
+ Source code in odak/tools/conversions.py +
def convert_to_torch(a, grad=True):
+    """
+    A definition to convert Numpy arrays to Torch.
+
+    Parameters
+    ----------
+    a          : ndarray
+                 Input Numpy array.
+    grad       : bool
+                 Set if the converted array requires gradient.
+
+    Returns
+    ----------
+    c          : torch.Tensor
+                 Converted array.
+    """
+    b = np.copy(a)
+    c = torch.from_numpy(b)
+    c.requires_grad_(grad)
+    return c
+
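A round-trip sketch between Numpy and Torch using convert_to_torch and convert_to_numpy, assuming both are importable from odak.tools and that torch is installed.

import numpy as np
from odak.tools import convert_to_torch, convert_to_numpy

a = np.random.rand(4, 4)
t = convert_to_torch(a, grad = True)  # torch tensor that tracks gradients
b = convert_to_numpy(t)               # detached copy back as a numpy array
print(np.allclose(a, b))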
+
+
+ +
+ +
+ + +

+ convolve2d(field, kernel) + +

+ + +
+ +

Definition to convolve a field with a kernel by multiplying in frequency space.

+ + +

Parameters:

+
    +
  • + field + – +
    +
          Input field with MxN shape.
    +
    +
    +
  • +
  • + kernel + – +
    +
          Input kernel with MxN shape.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +new_field ( ndarray +) – +
    +

    Convolved field.

    +
    +
  • +
+ +
+ Source code in odak/tools/matrix.py +
def convolve2d(field, kernel):
+    """
+    Definition to convolve a field with a kernel by multiplying in frequency space.
+
+    Parameters
+    ----------
+    field       : ndarray
+                  Input field with MxN shape.
+    kernel      : ndarray
+                  Input kernel with MxN shape.
+
+    Returns
+    ----------
+    new_field   : ndarray
+                  Convolved field.
+    """
+    fr = np.fft.fft2(field)
+    fr2 = np.fft.fft2(np.flipud(np.fliplr(kernel)))
+    m, n = fr.shape
+    new_field = np.real(np.fft.ifft2(fr*fr2))
+    new_field = np.roll(new_field, int(-m/2+1), axis=0)
+    new_field = np.roll(new_field, int(-n/2+1), axis=1)
+    return new_field
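
A hedged example pairing convolve2d with generate_2d_gaussian from the same module: blurring an impulse with a Gaussian kernel of matching shape (the FFT-based convolution requires field and kernel to share the same MxN size).

    import numpy as np
    from odak.tools import generate_2d_gaussian, convolve2d

    field = np.zeros((64, 64))
    field[32, 32] = 1.                                                        # single impulse
    kernel = generate_2d_gaussian(kernel_length = [63, 63], nsigma = [3, 3])  # 64 x 64 kernel
    blurred = convolve2d(field, kernel)                                       # Gaussian spot centred in the field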

copy_file(source, destination, follow_symlinks=True)

Definition to copy a file from one location to another.

Parameters:

  • source – Source filename.
  • destination – Destination filename.
  • follow_symlinks (bool, default: True) – Set to True to follow the source of symbolic links.

Source code in odak/tools/file.py:
def copy_file(source, destination, follow_symlinks = True):
+    """
+    Definition to copy a file from one location to another.
+
+
+
+    Parameters
+    ----------
+    source          : str
+                      Source filename.
+    destination     : str
+                      Destination filename.
+    follow_symlinks : bool
+                      Set to True to follow the source of symbolic links.
+    """
+    return shutil.copyfile(
+                           expanduser(source),
+                           expanduser(destination),
+                           follow_symlinks = follow_symlinks
+                          )

create_empty_list(dimensions=[1, 1])

A definition to create an empty Pythonic list.

Parameters:

  • dimensions – Dimensions of the list to be created.

Returns:

  • new_list (list) – New empty list.

Source code in odak/tools/matrix.py:
def create_empty_list(dimensions = [1, 1]):
+    """
+    A definition to create an empty Pythonic list.
+
+    Parameters
+    ----------
+    dimensions   : list
+                   Dimensions of the list to be created.
+
+    Returns
+    -------
+    new_list     : list
+                   New empty list.
+    """
+    new_list = 0
+    for n in reversed(dimensions):
+        new_list = [new_list] * n
+    return new_list

create_ray_from_two_points(x0y0z0, x1y1z1)

Definition to create a ray from two given points. Note that both inputs must match in shape.

Parameters:

  • x0y0z0 – List that contains X, Y and Z start locations of a ray (3). It can also be a list of points (mx3). This is the starting point.
  • x1y1z1 – List that contains X, Y and Z ending locations of a ray (3). It can also be a list of points (mx3). This is the end point.

Returns:

  • ray (ndarray) – Array that contains starting points and cosines of a created ray.

Source code in odak/raytracing/ray.py:
def create_ray_from_two_points(x0y0z0, x1y1z1):
+    """
+    Definition to create a ray from two given points. Note that both inputs must match in shape.
+
+    Parameters
+    ----------
+    x0y0z0       : list
+                   List that contains X,Y and Z start locations of a ray (3). It can also be a list of points as well (mx3). This is the starting point.
+    x1y1z1       : list
+                   List that contains X,Y and Z ending locations of a ray (3). It can also be a list of points as well (mx3). This is the end point.
+
+    Returns
+    ----------
+    ray          : ndarray
+                   Array that contains starting points and cosines of a created ray.
+    """
+    x0y0z0 = np.asarray(x0y0z0, dtype=np.float64)
+    x1y1z1 = np.asarray(x1y1z1, dtype=np.float64)
+    if len(x0y0z0.shape) == 1:
+        x0y0z0 = x0y0z0.reshape((1, 3))
+    if len(x1y1z1.shape) == 1:
+        x1y1z1 = x1y1z1.reshape((1, 3))
+    xdiff = x1y1z1[:, 0] - x0y0z0[:, 0]
+    ydiff = x1y1z1[:, 1] - x0y0z0[:, 1]
+    zdiff = x1y1z1[:, 2] - x0y0z0[:, 2]
+    s = np.sqrt(xdiff ** 2 + ydiff ** 2 + zdiff ** 2)
+    s[s == 0] = np.nan
+    cosines = np.zeros((xdiff.shape[0], 3))
+    cosines[:, 0] = xdiff/s
+    cosines[:, 1] = ydiff/s
+    cosines[:, 2] = zdiff/s
+    ray = np.zeros((xdiff.shape[0], 2, 3), dtype=np.float64)
+    ray[:, 0] = x0y0z0
+    ray[:, 1] = cosines
+    if ray.shape[0] == 1:
+        ray = ray.reshape((2, 3))
+    return ray
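
A small sketch (assuming the function is importable from odak.raytracing, where its source file lives): build several rays that share one origin.

    import numpy as np
    from odak.raytracing import create_ray_from_two_points

    starts = np.zeros((3, 3))                                   # all rays start at the origin
    ends = np.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])
    rays = create_ray_from_two_points(starts, ends)
    print(rays.shape)                                           # expected: (3, 2, 3) -> origins and direction cosines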

crop_center(field, size=None)

Definition to crop the center of a field with 2Mx2N size. The outcome is a MxN array.

Parameters:

  • field – Input field 2Mx2N array.
  • size – Size of the cropped region; when None, the central half of the field is returned.

Returns:

  • cropped (ndarray) – Cropped version of the input field.

Source code in odak/tools/matrix.py:
def crop_center(field, size=None):
+    """
+    Definition to crop the center of a field with 2Mx2N size. The outcome is a MxN array.
+
+    Parameters
+    ----------
+    field       : ndarray
+                  Input field 2Mx2N array.
+
+    Returns
+    ----------
+    cropped     : ndarray
+                  Cropped version of the input field.
+    """
+    if type(size) == type(None):
+        qx = int(np.ceil(field.shape[0])/4)
+        qy = int(np.ceil(field.shape[1])/4)
+        cropped = np.copy(field[qx:3*qx, qy:3*qy])
+    else:
+        cx = int(np.ceil(field.shape[0]/2))
+        cy = int(np.ceil(field.shape[1]/2))
+        hx = int(np.ceil(size[0]/2))
+        hy = int(np.ceil(size[1]/2))
+        cropped = np.copy(field[cx-hx:cx+hx, cy-hy:cy+hy])
+    return cropped
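
A quick sketch of the default behaviour, assuming odak.tools.crop_center as listed here: with no size argument the central half of the field is returned.

    import numpy as np
    from odak.tools import crop_center

    field = np.arange(64).reshape(8, 8)
    cropped = crop_center(field)      # keeps the central 4 x 4 region
    print(cropped.shape)              # expected: (4, 4)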

cross_product(vector1, vector2)

Definition to cross product two vectors and return the resultant vector. Uses the method described at: http://en.wikipedia.org/wiki/Cross_product

Parameters:

  • vector1 – A vector/ray.
  • vector2 – A vector/ray.

Returns:

  • ray (ndarray) – Array that contains starting points and cosines of a created ray.

Source code in odak/tools/vector.py:
def cross_product(vector1, vector2):
+    """
+    Definition to cross product two vectors and return the resultant vector. Used method described under: http://en.wikipedia.org/wiki/Cross_product
+
+    Parameters
+    ----------
+    vector1      : ndarray
+                   A vector/ray.
+    vector2      : ndarray
+                   A vector/ray.
+
+    Returns
+    ----------
+    ray          : ndarray
+                   Array that contains starting points and cosines of a created ray.
+    """
+    angle = np.cross(vector1[1].T, vector2[1].T)
+    angle = np.asarray(angle)
+    ray = np.array([vector1[0], angle], dtype=np.float32)
+    return ray

distance_between_point_clouds(points0, points1)

A definition to find the distance from every point in one cloud to every point in the other point cloud.

Parameters:

  • points0 – Mx3 points.
  • points1 – Nx3 points.

Returns:

  • distances (ndarray) – MxN distances.

Source code in odak/tools/vector.py:
def distance_between_point_clouds(points0, points1):
+    """
+    A definition to find distance between every point in one cloud to other points in the other point cloud.
+    Parameters
+    ----------
+    points0     : ndarray
+                  Mx3 points.
+    points1     : ndarray
+                  Nx3 points.
+
+    Returns
+    ----------
+    distances   : ndarray
+                  MxN distances.
+    """
+    c = points1.reshape((1, points1.shape[0], points1.shape[1]))
+    a = np.repeat(c, points0.shape[0], axis=0)
+    b = points0.reshape((points0.shape[0], 1, points0.shape[1]))
+    b = np.repeat(b, a.shape[1], axis=1)
+    distances = np.sqrt(np.sum((a-b)**2, axis=2))
+    return distances
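
A minimal sketch (assuming odak.tools.distance_between_point_clouds as listed here): the result holds one distance per pair of points across the two clouds.

    import numpy as np
    from odak.tools import distance_between_point_clouds

    points0 = np.zeros((4, 3))                          # M = 4 points at the origin
    points1 = np.array([[1., 0., 0.], [0., 2., 0.]])    # N = 2 points
    distances = distance_between_point_clouds(points0, points1)
    print(distances.shape)                              # expected: (4, 2)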

distance_between_two_points(point1, point2)

Definition to calculate distance between two given points.

Parameters:

  • point1 – First point in X,Y,Z.
  • point2 – Second point in X,Y,Z.

Returns:

  • distance (float) – Distance in between given two points.

Source code in odak/tools/vector.py:
def distance_between_two_points(point1, point2):
+    """
+    Definition to calculate distance between two given points.
+
+    Parameters
+    ----------
+    point1      : list
+                  First point in X,Y,Z.
+    point2      : list
+                  Second point in X,Y,Z.
+
+    Returns
+    ----------
+    distance    : float
+                  Distance in between given two points.
+    """
+    point1 = np.asarray(point1)
+    point2 = np.asarray(point2)
+    if len(point1.shape) == 1 and len(point2.shape) == 1:
+        distance = np.sqrt(np.sum((point1-point2)**2))
+    elif len(point1.shape) == 2 or len(point2.shape) == 2:
+        distance = np.sqrt(np.sum((point1-point2)**2, axis=1))
+    return distance

expanduser(filename)

Definition to expand a filename that uses shortcuts such as the home directory (~).

Parameters:

  • filename – Filename.

Returns:

  • new_filename (str) – Expanded filename.

Source code in odak/tools/file.py:
def expanduser(filename):
+    """
+    Definition to decode filename using namespaces and shortcuts.
+
+
+    Parameters
+    ----------
+    filename      : str
+                    Filename.
+
+
+    Returns
+    -------
+    new_filename  : str
+                    Filename.
+    """
+    new_filename = os.path.expanduser(filename)
+    return new_filename

generate_2d_gaussian(kernel_length=[21, 21], nsigma=[3, 3])

Generate 2D Gaussian kernel. Inspired from https://stackoverflow.com/questions/29731726/how-to-calculate-a-gaussian-kernel-matrix-efficiently-in-numpy

Parameters:

  • kernel_length (list, default: [21, 21]) – Length of the Gaussian kernel along X and Y axes.
  • nsigma – Sigma of the Gaussian kernel along X and Y axes.

Returns:

  • kernel_2d (ndarray) – Generated Gaussian kernel.

Source code in odak/tools/matrix.py:
def generate_2d_gaussian(kernel_length=[21, 21], nsigma=[3, 3]):
+    """
+    Generate 2D Gaussian kernel. Inspired from https://stackoverflow.com/questions/29731726/how-to-calculate-a-gaussian-kernel-matrix-efficiently-in-numpy
+
+    Parameters
+    ----------
+    kernel_length : list
+                    Length of the Gaussian kernel along X and Y axes.
+    nsigma        : list
+                    Sigma of the Gaussian kernel along X and Y axes.
+
+    Returns
+    ----------
+    kernel_2d     : ndarray
+                    Generated Gaussian kernel.
+    """
+    x = np.linspace(-nsigma[0], nsigma[0], kernel_length[0]+1)
+    y = np.linspace(-nsigma[1], nsigma[1], kernel_length[1]+1)
+    xx, yy = np.meshgrid(x, y)
+    kernel_2d = np.exp(-0.5*(np.square(xx) /
+                       np.square(nsigma[0]) + np.square(yy)/np.square(nsigma[1])))
+    kernel_2d = kernel_2d/kernel_2d.sum()
+    return kernel_2d

generate_bandlimits(size=[512, 512], levels=9)

A definition to calculate octaves used in bandlimiting frequencies in the frequency domain.

Parameters:

  • size – Size of each mask in octaves.

Returns:

  • masks (ndarray) – Masks (Octaves).

Source code in odak/tools/matrix.py:
def generate_bandlimits(size=[512, 512], levels=9):
+    """
+    A definition to calculate octaves used in bandlimiting frequencies in the frequency domain.
+
+    Parameters
+    ----------
+    size       : list
+                 Size of each mask in octaves.
+
+    Returns
+    ----------
+    masks      : ndarray
+                 Masks (Octaves).
+    """
+    masks = np.zeros((levels, size[0], size[1]))
+    cx = int(size[0]/2)
+    cy = int(size[1]/2)
+    for i in range(0, masks.shape[0]):
+        deltax = int((size[0])/(2**(i+1)))
+        deltay = int((size[1])/(2**(i+1)))
+        masks[
+            i,
+            cx-deltax:cx+deltax,
+            cy-deltay:cy+deltay
+        ] = 1.
+        masks[
+            i,
+            int(cx-deltax/2.):int(cx+deltax/2.),
+            int(cy-deltay/2.):int(cy+deltay/2.)
+        ] = 0.
+    masks = np.asarray(masks)
+    return masks

grid_sample(no=[10, 10], size=[100.0, 100.0], center=[0.0, 0.0, 0.0], angles=[0.0, 0.0, 0.0])

Definition to generate samples over a surface.

Parameters:

  • no – Number of samples.
  • size – Physical size of the surface.
  • center – Center location of the surface.
  • angles – Tilt of the surface.

Returns:

  • samples (ndarray) – Samples generated.

Source code in odak/tools/sample.py:
def grid_sample(no=[10, 10], size=[100., 100.], center=[0., 0., 0.], angles=[0., 0., 0.]):
+    """
+    Definition to generate samples over a surface.
+
+    Parameters
+    ----------
+    no          : list
+                  Number of samples.
+    size        : list
+                  Physical size of the surface.
+    center      : list
+                  Center location of the surface.
+    angles      : list
+                  Tilt of the surface.
+
+    Returns
+    ----------
+    samples     : ndarray
+                  Samples generated.
+    """
+    samples = np.zeros((no[0], no[1], 3))
+    step = [
+        size[0]/(no[0]-1),
+        size[1]/(no[1]-1)
+    ]
+    x, y = np.mgrid[0:no[0], 0:no[1]]
+    samples[:, :, 0] = x*step[0]-size[0]/2.
+    samples[:, :, 1] = y*step[1]-size[1]/2.
+    samples = samples.reshape(
+        (samples.shape[0]*samples.shape[1], samples.shape[2]))
+    samples = rotate_points(samples, angles=angles, offset=center)
+    return samples
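
A hedged example (assuming odak.tools.grid_sample as listed here): sample a tilted plane and get the samples back as an (MxN)x3 array of points.

    from odak.tools import grid_sample

    samples = grid_sample(
                          no = [10, 10],
                          size = [50., 50.],
                          center = [0., 0., 100.],
                          angles = [45., 0., 0.]
                         )
    print(samples.shape)   # expected: (100, 3)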

list_files(path, key='*.*', recursive=True)

Definition to list files in a given path with a given key.

Parameters:

  • path – Path to a folder.
  • key – Key used for scanning a path.
  • recursive – If set True, scan the path recursively.

Returns:

  • files_list (list) – List of files found in a given path.

Source code in odak/tools/file.py:
def list_files(path, key = '*.*', recursive = True):
+    """
+    Definition to list files in a given path with a given key.
+
+
+    Parameters
+    ----------
+    path        : str
+                  Path to a folder.
+    key         : str
+                  Key used for scanning a path.
+    recursive   : bool
+                  If set True, scan the path recursively.
+
+
+    Returns
+    ----------
+    files_list  : ndarray
+                  list of files found in a given path.
+    """
+    if recursive == True:
+        search_result = pathlib.Path(expanduser(path)).rglob(key)
+    elif recursive == False:
+        search_result = pathlib.Path(expanduser(path)).glob(key)
+    files_list = []
+    for item in search_result:
+        files_list.append(str(item))
+    files_list = sorted(files_list)
+    return files_list

load_dictionary(filename)

Definition to load a dictionary (JSON) file.

Parameters:

  • filename – Filename.

Returns:

  • settings (dict) – Dictionary read from the file.

Source code in odak/tools/file.py:
def load_dictionary(filename):
+    """
+    Definition to load a dictionary (JSON) file.
+
+
+    Parameters
+    ----------
+    filename      : str
+                    Filename.
+
+
+    Returns
+    ----------
+    settings      : dict
+                    Dictionary read from the file.
+
+    """
+    settings = json.load(open(expanduser(filename)))
+    return settings

load_image(fn, normalizeby=0.0, torch_style=False)

Definition to load an image from a given location as a Numpy array.

Parameters:

  • fn – Filename.
  • normalizeby – Value to normalize images with. The default value of zero leads to no normalization.
  • torch_style – If set True, it will load an image mxnx3 as 3xmxn.

Returns:

  • image (ndarray) – Image loaded as a Numpy array.

Source code in odak/tools/file.py:
def load_image(fn, normalizeby = 0., torch_style = False):
+    """ 
+    Definition to load an image from a given location as a Numpy array.
+
+
+    Parameters
+    ----------
+    fn           : str
+                   Filename.
+    normalizeby  : float
+                   Value to normalize images with. Default value of zero will lead to no normalization.
+    torch_style  : bool
+                   If set True, it will load an image mxnx3 as 3xmxn.
+
+
+    Returns
+    ----------
+    image        :  ndarray
+                    Image loaded as a Numpy array.
+
+    """
+    image = cv2.imread(expanduser(fn), cv2.IMREAD_UNCHANGED)
+    if isinstance(image, type(None)):
+         logging.warning('Image not properly loaded. Check filename or image type.')    
+         sys.exit()
+    if len(image.shape) > 2:
+        new_image = np.copy(image)
+        new_image[:, :, 0] = image[:, :, 2]
+        new_image[:, :, 2] = image[:, :, 0]
+        image = new_image
+    if normalizeby != 0.:
+        image = image * 1. / normalizeby
+    if torch_style == True and len(image.shape) > 2:
+        image = np.moveaxis(image, -1, 0)
+    return image.astype(float)
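
A hedged end-to-end sketch combining load_image, resize_image and save_image from this page; the filenames are placeholders.

    from odak.tools import load_image, resize_image, save_image

    image = load_image('input.png', normalizeby = 255.)   # values scaled to [0, 1]
    image = resize_image(image, [256, 256])               # resize_image expects a normalized image
    save_image('output.png', image, cmin = 0, cmax = 1)   # map [0, 1] back to an 8 bit file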

nufft2(field, fx, fy, size=None, sign=1, eps=10 ** -12)

A definition to take 2D Non-Uniform Fast Fourier Transform (NUFFT).

Parameters:

  • field – Input field.
  • fx – Frequencies along x axis.
  • fy – Frequencies along y axis.
  • size – Size.
  • sign – Sign of the exponential used in NUFFT kernel.
  • eps – Accuracy of NUFFT.

Returns:

  • result (ndarray) – Inverse NUFFT of the input field.

Source code in odak/tools/matrix.py:
def nufft2(field, fx, fy, size=None, sign=1, eps=10**(-12)):
+    """
+    A definition to take 2D Non-Uniform Fast Fourier Transform (NUFFT).
+
+    Parameters
+    ----------
+    field       : ndarray
+                  Input field.
+    fx          : ndarray
+                  Frequencies along x axis.
+    fy          : ndarray
+                  Frequencies along y axis.
+    size        : list
+                  Size.
+    sign        : float
+                  Sign of the exponential used in NUFFT kernel.
+    eps         : float
+                  Accuracy of NUFFT.
+
+    Returns
+    ----------
+    result      : ndarray
+                  Inverse NUFFT of the input field.
+    """
+    try:
+        import finufft
+    except:
+        print('odak.tools.nufft2 requires finufft to be installed: pip install finufft')
+    image = np.copy(field).astype(np.complex128)
+    result = finufft.nufft2d2(
+        fx.flatten(), fy.flatten(), image, eps=eps, isign=sign)
+    if type(size) == type(None):
+        result = result.reshape(field.shape)
+    else:
+        result = result.reshape(size)
+    return result

nuifft2(field, fx, fy, size=None, sign=1, eps=10 ** -12)

A definition to take 2D Adjoint Non-Uniform Fast Fourier Transform (NUFFT).

Parameters:

  • field – Input field.
  • fx – Frequencies along x axis.
  • fy – Frequencies along y axis.
  • size – Shape of the NUFFT calculated for an input field.
  • sign – Sign of the exponential used in NUFFT kernel.
  • eps – Accuracy of NUFFT.

Returns:

  • result (ndarray) – NUFFT of the input field.

Source code in odak/tools/matrix.py:
def nuifft2(field, fx, fy, size=None, sign=1, eps=10**(-12)):
+    """
+    A definition to take 2D Adjoint Non-Uniform Fast Fourier Transform (NUFFT).
+
+    Parameters
+    ----------
+    field       : ndarray
+                  Input field.
+    fx          : ndarray
+                  Frequencies along x axis.
+    fy          : ndarray
+                  Frequencies along y axis.
+    size        : list or ndarray
+                  Shape of the NUFFT calculated for an input field.
+    sign        : float
+                  Sign of the exponential used in NUFFT kernel.
+    eps         : float
+                  Accuracy of NUFFT.
+
+    Returns
+    ----------
+    result      : ndarray
+                  NUFFT of the input field.
+    """
+    try:
+        import finufft
+    except:
+        print('odak.tools.nuifft2 requires finufft to be installed: pip install finufft')
+    image = np.copy(field).astype(np.complex128)
+    if type(size) == type(None):
+        result = finufft.nufft2d1(
+            fx.flatten(),
+            fy.flatten(),
+            image.flatten(),
+            image.shape,
+            eps=eps,
+            isign=sign
+        )
+    else:
+        result = finufft.nufft2d1(
+            fx.flatten(),
+            fy.flatten(),
+            image.flatten(),
+            (size[0], size[1]),
+            eps=eps,
+            isign=sign
+        )
+    result = np.asarray(result)
+    return result

point_to_ray_distance(point, ray_point_0, ray_point_1)

Definition to find point's closest distance to a line represented with two points.

Parameters:

  • point – Point to be tested.
  • ray_point_0 (ndarray) – First point to represent a line.
  • ray_point_1 (ndarray) – Second point to represent a line.

Returns:

  • distance (float) – Calculated distance.

Source code in odak/tools/vector.py:
def point_to_ray_distance(point, ray_point_0, ray_point_1):
+    """
+    Definition to find point's closest distance to a line represented with two points.
+
+    Parameters
+    ----------
+    point       : ndarray
+                  Point to be tested.
+    ray_point_0 : ndarray
+                  First point to represent a line.
+    ray_point_1 : ndarray
+                  Second point to represent a line.
+
+    Returns
+    ----------
+    distance    : float
+                  Calculated distance.
+    """
+    distance = np.sum(np.cross((point-ray_point_0), (point-ray_point_1))
+                      ** 2)/np.sum((ray_point_1-ray_point_0)**2)
+    return distance

quantize(image_field, bits=4)

Definition to quantize an image field (0-255, 8 bit) to a certain bit level.

Parameters:

  • image_field (ndarray) – Input image field.
  • bits – A value between 1 and 8.

Returns:

  • new_field (ndarray) – Quantized image field.

Source code in odak/tools/matrix.py:
def quantize(image_field, bits=4):
+    """
+    Definition to quantize an image field (0-255, 8 bit) to a certain bit level.
+
+    Parameters
+    ----------
+    image_field : ndarray
+                  Input image field.
+    bits        : int
+                  A value in between 0 to 8. Can not be zero.
+
+    Returns
+    ----------
+    new_field   : ndarray
+                  Quantized image field.
+    """
+    divider = 2**(8-bits)
+    new_field = image_field/divider
+    new_field = new_field.astype(np.int64)
+    return new_field
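
A minimal sketch (assuming odak.tools.quantize as listed here): reduce an 8 bit image to 4 bits, so values end up in the 0-15 range.

    import numpy as np
    from odak.tools import quantize

    image = (np.random.rand(8, 8) * 255).astype(np.uint8)
    quantized = quantize(image, bits = 4)
    print(quantized.min(), quantized.max())   # values fall between 0 and 15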

random_sample_point_cloud(point_cloud, no, p=None)

Definition to pull a subset of points from a point cloud with a given probability.

Parameters:

  • point_cloud – Point cloud array.
  • no – Number of samples.
  • p – Probability list in the same size as no.

Returns:

  • subset (ndarray) – Subset of the given point cloud.

Source code in odak/tools/sample.py:
def random_sample_point_cloud(point_cloud, no, p=None):
+    """
+    Definition to pull a subset of points from a point cloud with a given probability.
+
+    Parameters
+    ----------
+    point_cloud  : ndarray
+                   Point cloud array.
+    no           : list
+                   Number of samples.
+    p            : list
+                   Probability list in the same size as no.
+
+    Returns
+    ----------
+    subset       : ndarray
+                   Subset of the given point cloud.
+    """
+    choice = np.random.choice(point_cloud.shape[0], no, p)
+    subset = point_cloud[choice, :]
+    return subset

read_PLY(fn, offset=[0, 0, 0], angles=[0.0, 0.0, 0.0], mode='XYZ')

Definition to read a PLY file and extract meshes from a given PLY file. Note that rotation is always with respect to 0,0,0.

Parameters:

  • fn – Filename of a PLY file.
  • offset – Offset in X,Y,Z.
  • angles – Rotation angles in degrees.
  • mode – Rotation mode determines ordering of the rotations at each axis. There are XYZ, YXZ, ZXY and ZYX modes.

Returns:

  • triangles (ndarray) – Triangles from a given PLY file. Note that the triangles coming out of this function aren't always structured in the right order and come with the size of (MxN)x3. You can use numpy's reshape to restructure them to mxnx3 if you know what you are doing.

Source code in odak/tools/asset.py:
def read_PLY(fn, offset=[0, 0, 0], angles=[0., 0., 0.], mode='XYZ'):
+    """
+    Definition to read a PLY file and extract meshes from a given PLY file. Note that rotation is always with respect to 0,0,0.
+
+    Parameters
+    ----------
+    fn           : string
+                   Filename of a PLY file.
+    offset       : ndarray
+                   Offset in X,Y,Z.
+    angles       : list
+                   Rotation angles in degrees.
+    mode         : str
+                   Rotation mode determines ordering of the rotations at each axis. There are XYZ,YXZ,ZXY and ZYX modes. 
+
+    Returns
+    ----------
+    triangles    : ndarray
+                   Triangles from a given PLY file. Note that the triangles coming out of this function aren't always structured in the right order and come with the size of (MxN)x3. You can use numpy's reshape to restructure them to mxnx3 if you know what you are doing.
+    """
+    if np.__name__ != 'numpy':
+        import numpy as np_ply
+    else:
+        np_ply = np
+    with open(fn, 'rb') as f:
+        plydata = PlyData.read(f)
+    triangle_ids = np_ply.vstack(plydata['face'].data['vertex_indices'])
+    triangles = []
+    for vertex_ids in triangle_ids:
+        triangle = [
+            rotate_point(plydata['vertex'][int(vertex_ids[0])
+                                           ].tolist(), angles=angles, offset=offset)[0],
+            rotate_point(plydata['vertex'][int(vertex_ids[1])
+                                           ].tolist(), angles=angles, offset=offset)[0],
+            rotate_point(plydata['vertex'][int(vertex_ids[2])
+                                           ].tolist(), angles=angles, offset=offset)[0]
+        ]
+        triangle = np_ply.asarray(triangle)
+        triangles.append(triangle)
+    triangles = np_ply.array(triangles)
+    triangles = np.asarray(triangles, dtype=np.float32)
+    return triangles

read_PLY_point_cloud(filename)

Definition to read a PLY file as a point cloud.

Parameters:

  • filename – Filename of a PLY file.

Returns:

  • point_cloud (ndarray) – An array filled with points from the PLY file.

Source code in odak/tools/asset.py:
def read_PLY_point_cloud(filename):
+    """
+    Definition to read a PLY file as a point cloud.
+
+    Parameters
+    ----------
+    filename     : str
+                   Filename of a PLY file.
+
+    Returns
+    ----------
+    point_cloud  : ndarray
+                   An array filled with points from the PLY file.
+    """
+    plydata = PlyData.read(filename)
+    if np.__name__ != 'numpy':
+        import numpy as np_ply
+        point_cloud = np_ply.zeros((plydata['vertex'][:].shape[0], 3))
+        point_cloud[:, 0] = np_ply.asarray(plydata['vertex']['x'][:])
+        point_cloud[:, 1] = np_ply.asarray(plydata['vertex']['y'][:])
+        point_cloud[:, 2] = np_ply.asarray(plydata['vertex']['z'][:])
+        point_cloud = np.asarray(point_cloud)
+    else:
+        point_cloud = np.zeros((plydata['vertex'][:].shape[0], 3))
+        point_cloud[:, 0] = np.asarray(plydata['vertex']['x'][:])
+        point_cloud[:, 1] = np.asarray(plydata['vertex']['y'][:])
+        point_cloud[:, 2] = np.asarray(plydata['vertex']['z'][:])
+    return point_cloud

read_text_file(filename)

Definition to read a given text file and convert it into a Pythonic list.

Parameters:

  • filename – Source filename (i.e. test.txt).

Returns:

  • content (list) – Pythonic string list containing the text from the file provided.

Source code in odak/tools/file.py:
def read_text_file(filename):
+    """
+    Definition to read a given text file and convert it into a Pythonic list.
+
+
+    Parameters
+    ----------
+    filename        : str
+                      Source filename (i.e. test.txt).
+
+
+    Returns
+    -------
+    content         : list
+                      Pythonic string list containing the text from the file provided.
+    """
+    content = []
+    loaded_file = open(expanduser(filename))
+    while line := loaded_file.readline():
+        content.append(line.rstrip())
+    return content

resize_image(img, target_size)

Definition to resize a given image to a target shape.

Parameters:

  • img – MxN image to be resized. Image must be normalized (0-1).
  • target_size – Target shape.

Returns:

  • img (ndarray) – Resized image.

Source code in odak/tools/file.py:
def resize_image(img, target_size):
+    """
+    Definition to resize a given image to a target shape.
+
+
+    Parameters
+    ----------
+    img           : ndarray
+                    MxN image to be resized.
+                    Image must be normalized (0-1).
+    target_size   : list
+                    Target shape.
+
+
+    Returns
+    ----------
+    img           : ndarray
+                    Resized image.
+
+    """
+    img = cv2.resize(img, dsize=(target_size[0], target_size[1]), interpolation=cv2.INTER_AREA)
+    return img

rotate_point(point, angles=[0, 0, 0], mode='XYZ', origin=[0, 0, 0], offset=[0, 0, 0])

Definition to rotate a given point. Note that rotation is always with respect to 0,0,0.

Parameters:

  • point – A point.
  • angles – Rotation angles in degrees.
  • mode – Rotation mode determines ordering of the rotations at each axis. There are XYZ, XZY, YXZ, ZXY and ZYX modes.
  • origin – Reference point for a rotation.
  • offset – Shift with the given offset.

Returns:

  • result (ndarray) – Result of the rotation.
  • rotx (ndarray) – Rotation matrix along X axis.
  • roty (ndarray) – Rotation matrix along Y axis.
  • rotz (ndarray) – Rotation matrix along Z axis.

Source code in odak/tools/transformation.py:
def rotate_point(point, angles = [0, 0, 0], mode = 'XYZ', origin = [0, 0, 0], offset = [0, 0, 0]):
+    """
+    Definition to rotate a given point. Note that rotation is always with respect to 0,0,0.
+
+    Parameters
+    ----------
+    point        : ndarray
+                   A point.
+    angles       : list
+                   Rotation angles in degrees. 
+    mode         : str
+                   Rotation mode determines ordering of the rotations at each axis. There are XYZ,YXZ,ZXY and ZYX modes.
+    origin       : list
+                   Reference point for a rotation.
+    offset       : list
+                   Shift with the given offset.
+
+    Returns
+    ----------
+    result       : ndarray
+                   Result of the rotation
+    rotx         : ndarray
+                   Rotation matrix along X axis.
+    roty         : ndarray
+                   Rotation matrix along Y axis.
+    rotz         : ndarray
+                   Rotation matrix along Z axis.
+    """
+    point = np.asarray(point)
+    point -= np.asarray(origin)
+    rotx = rotmatx(angles[0])
+    roty = rotmaty(angles[1])
+    rotz = rotmatz(angles[2])
+    if mode == 'XYZ':
+        result = np.dot(rotz, np.dot(roty, np.dot(rotx, point)))
+    elif mode == 'XZY':
+        result = np.dot(roty, np.dot(rotz, np.dot(rotx, point)))
+    elif mode == 'YXZ':
+        result = np.dot(rotz, np.dot(rotx, np.dot(roty, point)))
+    elif mode == 'ZXY':
+        result = np.dot(roty, np.dot(rotx, np.dot(rotz, point)))
+    elif mode == 'ZYX':
+        result = np.dot(rotx, np.dot(roty, np.dot(rotz, point)))
+    result += np.asarray(origin)
+    result += np.asarray(offset)
+    return result, rotx, roty, rotz

rotate_points(points, angles=[0, 0, 0], mode='XYZ', origin=[0, 0, 0], offset=[0, 0, 0])

Definition to rotate points.

Parameters:

  • points – Points.
  • angles – Rotation angles in degrees.
  • mode – Rotation mode determines ordering of the rotations at each axis. There are XYZ, XZY, YXZ, ZXY and ZYX modes.
  • origin – Reference point for a rotation.
  • offset – Shift with the given offset.

Returns:

  • result (ndarray) – Result of the rotation.

Source code in odak/tools/transformation.py:
def rotate_points(points, angles = [0, 0, 0], mode = 'XYZ', origin = [0, 0, 0], offset = [0, 0, 0]):
+    """
+    Definition to rotate points.
+
+    Parameters
+    ----------
+    points       : ndarray
+                   Points.
+    angles       : list
+                   Rotation angles in degrees. 
+    mode         : str
+                   Rotation mode determines ordering of the rotations at each axis. There are XYZ,YXZ,ZXY and ZYX modes.
+    origin       : list
+                   Reference point for a rotation.
+    offset       : list
+                   Shift with the given offset.
+
+    Returns
+    ----------
+    result       : ndarray
+                   Result of the rotation   
+    """
+    points = np.asarray(points)
+    if angles[0] == 0 and angles[1] == 0 and angles[2] == 0:
+        result = np.array(offset) + points
+        return result
+    points -= np.array(origin)
+    rotx = rotmatx(angles[0])
+    roty = rotmaty(angles[1])
+    rotz = rotmatz(angles[2])
+    if mode == 'XYZ':
+        result = np.dot(rotz, np.dot(roty, np.dot(rotx, points.T))).T
+    elif mode == 'XZY':
+        result = np.dot(roty, np.dot(rotz, np.dot(rotx, points.T))).T
+    elif mode == 'YXZ':
+        result = np.dot(rotz, np.dot(rotx, np.dot(roty, points.T))).T
+    elif mode == 'ZXY':
+        result = np.dot(roty, np.dot(rotx, np.dot(rotz, points.T))).T
+    elif mode == 'ZYX':
+        result = np.dot(rotx, np.dot(roty, np.dot(rotz, points.T))).T
+    result += np.array(origin)
+    result += np.array(offset)
+    return result
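
A short sketch (assuming odak.tools.rotate_points as listed here): a 90 degree rotation about Z maps the X axis onto the Y axis.

    import numpy as np
    from odak.tools import rotate_points

    points = np.array([[1., 0., 0.], [0., 1., 0.]])
    rotated = rotate_points(points, angles = [0., 0., 90.], mode = 'XYZ')
    print(np.round(rotated, 3))   # expected: [[0., 1., 0.], [-1., 0., 0.]]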

rotmatx(angle)

Definition to generate a rotation matrix along X axis.

Parameters:

  • angle – Rotation angle in degrees.

Returns:

  • rotx (ndarray) – Rotation matrix along X axis.

Source code in odak/tools/transformation.py:
def rotmatx(angle):
+    """
+    Definition to generate a rotation matrix along X axis.
+
+    Parameters
+    ----------
+    angle        : float
+                   Rotation angle in degrees.
+
+    Returns
+    -------
+    rotx         : ndarray
+                   Rotation matrix along X axis.
+    """
+    angle = np.float64(angle)
+    angle = np.radians(angle)
+    rotx = np.array([
+        [1.,               0.,               0.],
+        [0.,  math.cos(angle), -math.sin(angle)],
+        [0.,  math.sin(angle),  math.cos(angle)]
+    ], dtype=np.float64)
+    return rotx

rotmaty(angle)

Definition to generate a rotation matrix along Y axis.

Parameters:

  • angle – Rotation angle in degrees.

Returns:

  • roty (ndarray) – Rotation matrix along Y axis.

Source code in odak/tools/transformation.py:
def rotmaty(angle):
+    """
+    Definition to generate a rotation matrix along Y axis.
+
+    Parameters
+    ----------
+    angle        : float
+                   Rotation angle in degrees.
+
+    Returns
+    -------
+    roty         : ndarray
+                   Rotation matrix along Y axis.
+    """
+    angle = np.radians(angle)
+    roty = np.array([
+        [math.cos(angle),  0., math.sin(angle)],
+        [0.,               1.,              0.],
+        [-math.sin(angle), 0., math.cos(angle)]
+    ], dtype=np.float64)
+    return roty

rotmatz(angle)

Definition to generate a rotation matrix along Z axis.

Parameters:

  • angle – Rotation angle in degrees.

Returns:

  • rotz (ndarray) – Rotation matrix along Z axis.

Source code in odak/tools/transformation.py:
def rotmatz(angle):
+    """
+    Definition to generate a rotation matrix along Z axis.
+
+    Parameters
+    ----------
+    angle        : float
+                   Rotation angle in degrees.
+
+    Returns
+    -------
+    rotz         : ndarray
+                   Rotation matrix along Z axis.
+    """
+    angle = np.radians(angle)
+    rotz = np.array([
+        [math.cos(angle), -math.sin(angle), 0.],
+        [math.sin(angle),  math.cos(angle), 0.],
+        [0.,               0., 1.]
+    ], dtype=np.float64)
+
+    return rotz

same_side(p1, p2, a, b)

Definition to figure which side a point is on with respect to a line and a point. See http://www.blackpawn.com/texts/pointinpoly/ for more. If p1 and p2 are on the same side, this definition returns True.

Parameters:

  • p1 – Point(s) to check.
  • p2 – This is the point to check against.
  • a – First point that forms the line.
  • b – Second point that forms the line.

Source code in odak/tools/vector.py:
def same_side(p1, p2, a, b):
+    """
+    Definition to figure which side a point is on with respect to a line and a point. See http://www.blackpawn.com/texts/pointinpoly/ for more. If p1 and p2 are on the same side, this definition returns True.
+
+    Parameters
+    ----------
+    p1          : list
+                  Point(s) to check.
+    p2          : list
+                  This is the point to check against.
+    a           : list
+                  First point that forms the line.
+    b           : list
+                  Second point that forms the line.
+    """
+    ba = np.subtract(b, a)
+    p1a = np.subtract(p1, a)
+    p2a = np.subtract(p2, a)
+    cp1 = np.cross(ba, p1a)
+    cp2 = np.cross(ba, p2a)
+    test = np.dot(cp1, cp2)
+    if len(p1.shape) > 1:
+        return test >= 0
+    if test >= 0:
+        return True
+    return False

save_dictionary(settings, filename)

Definition to save a dictionary as a JSON file.

Parameters:

  • settings – Dictionary to be saved.
  • filename – Filename.

Source code in odak/tools/file.py:
def save_dictionary(settings, filename):
+    """
+    Definition to save a dictionary as a (JSON) file.
+
+
+    Parameters
+    ----------
+    settings      : dict
+                    Dictionary to be saved to the file.
+    filename      : str
+                    Filename.
+    """
+    with open(expanduser(filename), 'w', encoding='utf-8') as f:
+        json.dump(settings, f, ensure_ascii=False, indent=4)
+    return settings
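
A hedged round-trip sketch for the dictionary helpers on this page; the filename is a placeholder.

    from odak.tools import save_dictionary, load_dictionary

    settings = {'wavelength': 515e-9, 'resolution': [1080, 1920]}
    save_dictionary(settings, 'settings.json')
    restored = load_dictionary('settings.json')
    assert restored['resolution'] == [1080, 1920]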

save_image(fn, img, cmin=0, cmax=255, color_depth=8)

Definition to save a Numpy array as an image.

Parameters:

  • fn – Filename.
  • img – A numpy array with NxMx3 or NxMx1 shapes.
  • cmin – Minimum value that will be interpreted as 0 level in the final image.
  • cmax – Maximum value that will be interpreted as 255 level in the final image.
  • color_depth – Pixel color depth in bits, default is eight bits.

Returns:

  • bool (bool) – True if successful.

Source code in odak/tools/file.py:
def save_image(fn, img, cmin = 0, cmax = 255, color_depth = 8):
+    """
+    Definition to save a Numpy array as an image.
+
+
+    Parameters
+    ----------
+    fn           : str
+                   Filename.
+    img          : ndarray
+                   A numpy array with NxMx3 or NxMx1 shapes.
+    cmin         : int
+                   Minimum value that will be interpreted as 0 level in the final image.
+    cmax         : int
+                   Maximum value that will be interpreted as 255 level in the final image.
+    color_depth  : int
+                   Pixel color depth in bits, default is eight bits.
+
+
+    Returns
+    ----------
+    bool         :  bool
+                    True if successful.
+
+    """
+    input_img = np.copy(img).astype(np.float32)
+    cmin = float(cmin)
+    cmax = float(cmax)
+    input_img[input_img < cmin] = cmin
+    input_img[input_img > cmax] = cmax
+    input_img /= cmax
+    input_img = input_img * 1. * (2**color_depth - 1)
+    if color_depth == 8:
+        input_img = input_img.astype(np.uint8)
+    elif color_depth == 16:
+        input_img = input_img.astype(np.uint16)
+    if len(input_img.shape) > 2:
+        if input_img.shape[2] > 1:
+            cache_img = np.copy(input_img)
+            cache_img[:, :, 0] = input_img[:, :, 2]
+            cache_img[:, :, 2] = input_img[:, :, 0]
+            input_img = cache_img
+    cv2.imwrite(expanduser(fn), input_img)
+    return True

shell_command(cmd, cwd='.', timeout=None, check=True)

Definition to initiate shell commands.

Parameters:

  • cmd – Command to be executed.
  • cwd – Working directory.
  • timeout – Timeout if the process isn't complete in the given number of seconds.
  • check – Set it to True to return the results and to enable timeout.

Returns:

  • proc (Popen) – Generated process.
  • outs (str) – Outputs of the executed command, returns None when check is set to False.
  • errs (str) – Errors of the executed command, returns None when check is set to False.

Source code in odak/tools/file.py:
def shell_command(cmd, cwd = '.', timeout = None, check = True):
+    """
+    Definition to initiate shell commands.
+
+
+    Parameters
+    ----------
+    cmd          : list
+                   Command to be executed. 
+    cwd          : str
+                   Working directory.
+    timeout      : int
+                   Timeout if the process isn't complete in the given number of seconds.
+    check        : bool
+                   Set it to True to return the results and to enable timeout.
+
+
+    Returns
+    ----------
+    proc         : subprocess.Popen
+                   Generated process.
+    outs         : str
+                   Outputs of the executed command, returns None when check is set to False.
+    errs         : str
+                   Errors of the executed command, returns None when check is set to False.
+
+    """
+    for item_id in range(len(cmd)):
+        cmd[item_id] = expanduser(cmd[item_id])
+    proc = subprocess.Popen(
+                            cmd,
+                            cwd = cwd,
+                            stdout = subprocess.PIPE
+                           )
+    if check == False:
+        return proc, None, None
+    try:
+        outs, errs = proc.communicate(timeout = timeout)
+    except subprocess.TimeoutExpired:
+        proc.kill()
+        outs, errs = proc.communicate()
+    return proc, outs, errs
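
A minimal sketch (assuming odak.tools.shell_command as listed here): run a command, wait for it, and read its standard output; the command itself is only an example.

    from odak.tools import shell_command

    proc, outs, errs = shell_command(['ls', '-lh'], cwd = '.', timeout = 10, check = True)
    print(outs.decode())   # stdout is captured as bytes; errs stays None as stderr is not piped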

size_of_a_file(file_path)

A definition to get size of a file with a relevant unit.

Parameters:

  • file_path – Path of the file.

Returns:

  • a (float) – Size of the file.
  • b (str) – Unit of the size (bytes, KB, MB, GB or TB).

Source code in odak/tools/file.py:
def size_of_a_file(file_path):
+    """
+    A definition to get size of a file with a relevant unit.
+
+
+    Parameters
+    ----------
+    file_path  : str
+                 Path of the file.
+
+
+    Returns
+    ----------
+    a          : float
+                 Size of the file.
+    b          : str
+                 Unit of the size (bytes, KB, MB, GB or TB).
+    """
+    if os.path.isfile(file_path):
+        file_info = os.stat(file_path)
+        a, b = convert_bytes(file_info.st_size)
+        return a, b
+    return None, None

sphere_sample(no=[10, 10], radius=1.0, center=[0.0, 0.0, 0.0], k=[1, 2])

Definition to generate a regular sample set on the surface of a sphere using polar coordinates.

Parameters:

  • no – Number of samples.
  • radius – Radius of a sphere.
  • center – Center of a sphere.
  • k – Multipliers for gathering samples. If you set k=[1,2] it will draw samples from a perfect sphere.

Returns:

  • samples (ndarray) – Samples generated.

Source code in odak/tools/sample.py:
def sphere_sample(no=[10, 10], radius=1., center=[0., 0., 0.], k=[1, 2]):
+    """
+    Definition to generate a regular sample set on the surface of a sphere using polar coordinates.
+
+    Parameters
+    ----------
+    no          : list
+                  Number of samples.
+    radius      : float
+                  Radius of a sphere.
+    center      : list
+                  Center of a sphere.
+    k           : list
+                  Multipliers for gathering samples. If you set k=[1,2] it will draw samples from a perfect sphere.
+
+    Returns
+    ----------
+    samples     : ndarray
+                  Samples generated.
+    """
+    samples = np.zeros((no[0], no[1], 3))
+    psi, teta = np.mgrid[0:no[0], 0:no[1]]
+    psi = k[0]*np.pi/no[0]*psi
+    teta = k[1]*np.pi/no[1]*teta
+    samples[:, :, 0] = center[0]+radius*np.sin(psi)*np.cos(teta)
+    samples[:, :, 1] = center[1]+radius*np.sin(psi)*np.sin(teta)
+    samples[:, :, 2] = center[2]+radius*np.cos(psi)
+    samples = samples.reshape((no[0]*no[1], 3))
+    return samples

sphere_sample_uniform(no=[10, 10], radius=1.0, center=[0.0, 0.0, 0.0], k=[1, 2])

Definition to generate a uniform sample set on the surface of a sphere using polar coordinates.

Parameters:

  • no – Number of samples.
  • radius – Radius of a sphere.
  • center – Center of a sphere.
  • k – Multipliers for gathering samples. If you set k=[1,2] it will draw samples from a perfect sphere.

Returns:

  • samples (ndarray) – Samples generated.

Source code in odak/tools/sample.py:
def sphere_sample_uniform(no=[10, 10], radius=1., center=[0., 0., 0.], k=[1, 2]):
+    """
+    Definition to generate a uniform sample set on the surface of a sphere using polar coordinates.
+
+    Parameters
+    ----------
+    no          : list
+                  Number of samples.
+    radius      : float
+                  Radius of a sphere.
+    center      : list
+                  Center of a sphere.
+    k           : list
+                  Multipliers for gathering samples. If you set k=[1,2] it will draw samples from a perfect sphere.
+
+
+    Returns
+    ----------
+    samples     : ndarray
+                  Samples generated.
+
+    """
+    samples = np.zeros((no[0], no[1], 3))
+    row = np.arange(0, no[0])
+    psi, teta = np.mgrid[0:no[0], 0:no[1]]
+    for psi_id in range(0, no[0]):
+        psi[psi_id] = np.roll(row, psi_id, axis=0)
+        teta[psi_id] = np.roll(row, -psi_id, axis=0)
+    psi = k[0]*np.pi/no[0]*psi
+    teta = k[1]*np.pi/no[1]*teta
+    samples[:, :, 0] = center[0]+radius*np.sin(psi)*np.cos(teta)
+    samples[:, :, 1] = center[1]+radius*np.sin(psi)*np.sin(teta)
+    samples[:, :, 2] = center[2]+radius*np.cos(psi)
+    samples = samples.reshape((no[0]*no[1], 3))
+    return samples
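
A quick sketch (assuming odak.tools.sphere_sample_uniform as listed here): sample a sphere of radius 5 and get the points back as a flat Nx3 array.

    from odak.tools import sphere_sample_uniform

    samples = sphere_sample_uniform(no = [20, 20], radius = 5., center = [0., 0., 0.])
    print(samples.shape)   # expected: (400, 3)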

tilt_towards(location, lookat)

Definition to tilt surface normal of a plane towards a point.

Parameters:

  • location – Center of the plane to be tilted.
  • lookat – Tilt towards this point.

Returns:

  • angles (list) – Rotation angles in degrees.

Source code in odak/tools/transformation.py:
def tilt_towards(location, lookat):
+    """
+    Definition to tilt surface normal of a plane towards a point.
+
+    Parameters
+    ----------
+    location     : list
+                   Center of the plane to be tilted.
+    lookat       : list
+                   Tilt towards this point.
+
+    Returns
+    ----------
+    angles       : list
+                   Rotation angles in degrees.
+    """
+    dx = location[0]-lookat[0]
+    dy = location[1]-lookat[1]
+    dz = location[2]-lookat[2]
+    dist = np.sqrt(dx**2+dy**2+dz**2)
+    phi = np.arctan2(dy, dx)
+    theta = np.arccos(dz/dist)
+    angles = [
+        0,
+        np.degrees(theta).tolist(),
+        np.degrees(phi).tolist()
+    ]
+    return angles

write_PLY(triangles, savefn='output.ply')

Definition to generate a PLY file from given triangles.

Parameters:

  • triangles – List of triangles with the size of Mx3x3.
  • savefn – Filename for a PLY file.

Source code in odak/tools/asset.py:
def write_PLY(triangles, savefn = 'output.ply'):
+    """
+    Definition to generate a PLY file from given points.
+
+    Parameters
+    ----------
+    triangles   : ndarray
+                  List of triangles with the size of Mx3x3.
+    savefn      : string
+                  Filename for a PLY file.
+    """
+    tris = []
+    pnts = []
+    color = [255, 255, 255]
+    for tri_id in range(triangles.shape[0]):
+        tris.append(
+            (
+                [3*tri_id, 3*tri_id+1, 3*tri_id+2],
+                color[0],
+                color[1],
+                color[2]
+            )
+        )
+        for i in range(0, 3):
+            pnts.append(
+                (
+                    float(triangles[tri_id][i][0]),
+                    float(triangles[tri_id][i][1]),
+                    float(triangles[tri_id][i][2])
+                )
+            )
+    tris = np.asarray(tris, dtype=[
+                          ('vertex_indices', 'i4', (3,)), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')])
+    pnts = np.asarray(pnts, dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])
+    # Save mesh.
+    el1 = PlyElement.describe(pnts, 'vertex', comments=['Vertex data'])
+    el2 = PlyElement.describe(tris, 'face', comments=['Face data'])
+    PlyData([el1, el2], text="True").write(savefn)
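A minimal sketch of writing a mesh, assuming write_PLY is exposed at odak.tools and that plyfile is installed; the single-triangle array is a hypothetical input shaped Mx3x3 as the parameter description requires.

import numpy as np
from odak.tools import write_PLY

# One triangle: M = 1 triangle, 3 vertices, 3 coordinates each.
triangles = np.array([[[0., 0., 0.], [1., 0., 0.], [0., 1., 0.]]])
write_PLY(triangles, savefn='triangle.ply')  # 'triangle.ply' is an illustrative filename.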
write_PLY_from_points(points, savefn='output.ply')

Definition to generate a PLY file from given points.

Parameters:

  • points (ndarray) – List of points with the size of MxNx3.
  • savefn (str)     – Filename for a PLY file.

Source code in odak/tools/asset.py
def write_PLY_from_points(points, savefn='output.ply'):
+    """
+    Definition to generate a PLY file from given points.
+
+    Parameters
+    ----------
+    points      : ndarray
+                  List of points with the size of MxNx3.
+    savefn      : string
+                  Filename for a PLY file.
+
+    """
+    if np.__name__ != 'numpy':
+        import numpy as np_ply
+    else:
+        np_ply = np
+    # Generate equation
+    samples = [points.shape[0], points.shape[1]]
+    # Generate vertices.
+    pnts = []
+    tris = []
+    for idx in range(0, samples[0]):
+        for idy in range(0, samples[1]):
+            pnt = (points[idx, idy, 0],
+                   points[idx, idy, 1], points[idx, idy, 2])
+            pnts.append(pnt)
+    color = [255, 255, 255]
+    for idx in range(0, samples[0]-1):
+        for idy in range(0, samples[1]-1):
+            tris.append(([idy+(idx+1)*samples[0], idy+idx*samples[0],
+                        idy+1+idx*samples[0]], color[0], color[1], color[2]))
+            tris.append(([idy+(idx+1)*samples[0], idy+1+idx*samples[0],
+                        idy+1+(idx+1)*samples[0]], color[0], color[1], color[2]))
+    tris = np_ply.asarray(tris, dtype=[(
+        'vertex_indices', 'i4', (3,)), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')])
+    pnts = np_ply.asarray(pnts, dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])
+    # Save mesh.
+    el1 = PlyElement.describe(pnts, 'vertex', comments=['Vertex data'])
+    el2 = PlyElement.describe(tris, 'face', comments=['Face data'])
+    PlyData([el1, el2], text="True").write(savefn)
write_to_text_file(content, filename, write_flag='w')

Definition to write a Pythonic list to a text file.

Parameters:

  • content    (list) – Pythonic string list to be written to a file.
  • filename   (str)  – Destination filename (i.e. test.txt).
  • write_flag (str)  – Defines the interaction with the file. The default is "w" (overwrite any existing content). For more see: https://docs.python.org/3/tutorial/inputoutput.html#reading-and-writing-files

Source code in odak/tools/file.py
def write_to_text_file(content, filename, write_flag = 'w'):
+    """
+    Definition to write a Pythonic list to a text file.
+
+
+    Parameters
+    ----------
+    content         : list
+                      Pythonic string list to be written to a file.
+    filename        : str
+                      Destination filename (i.e. test.txt).
+    write_flag      : str
+                      Defines the interaction with the file. 
+                      The default is "w" (overwrite any existing content).
+                      For more see: https://docs.python.org/3/tutorial/inputoutput.html#reading-and-writing-files
+    """
+    with open(expanduser(filename), write_flag) as f:
+        for line in content:
+            f.write('{}\n'.format(line))
+    return True
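A small round-trip sketch, assuming write_to_text_file and read_text_file (documented further below) are both importable from odak.tools; 'notes.txt' is an illustrative filename.

from odak.tools import write_to_text_file, read_text_file

write_to_text_file(['first line', 'second line'], 'notes.txt')  # Overwrites with the default 'w' flag.
content = read_text_file('notes.txt')
print(content)  # ['first line', 'second line']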
zero_pad(field, size=None, method='center')

Definition to zero pad an MxN array to a 2Mx2N array.

Parameters:

  • field  (ndarray) – Input field, an MxN array.
  • size   (list)    – Size to be zero padded to.
  • method (str)     – Zero pad either by placing the content at the center or to the left.

Returns:

  • field_zero_padded (ndarray) – Zero padded version of the input field.

Source code in odak/tools/matrix.py
def zero_pad(field, size=None, method='center'):
+    """
+    Definition to zero pad a MxN array to 2Mx2N array.
+
+    Parameters
+    ----------
+    field             : ndarray
+                        Input field MxN array.
+    size              : list
+                        Size to be zeropadded.
+    method            : str
+                        Zeropad either by placing the content to center or to the left.
+
+    Returns
+    ----------
+    field_zero_padded : ndarray
+                        Zeropadded version of the input field.
+    """
+    if type(size) == type(None):
+        hx = int(np.ceil(field.shape[0])/2)
+        hy = int(np.ceil(field.shape[1])/2)
+    else:
+        hx = int(np.ceil((size[0]-field.shape[0])/2))
+        hy = int(np.ceil((size[1]-field.shape[1])/2))
+    if method == 'center':
+        field_zero_padded = np.pad(
+            field, ([hx, hx], [hy, hy]), constant_values=(0, 0))
+    elif method == 'left aligned':
+        field_zero_padded = np.pad(
+            field, ([0, 2*hx], [0, 2*hy]), constant_values=(0, 0))
+    if type(size) != type(None):
+        field_zero_padded = field_zero_padded[0:size[0], 0:size[1]]
+    return field_zero_padded
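A minimal sketch pairing zero_pad with crop_center (documented further below), assuming both are importable from odak.tools: padding with the defaults doubles each dimension around the centre, and cropping recovers the original extent.

import numpy as np
from odak.tools import zero_pad, crop_center

field = np.ones((4, 4))
padded = zero_pad(field)         # Default behaviour: 4 x 4 becomes 8 x 8.
recovered = crop_center(padded)  # Back to the central 4 x 4 region.
print(padded.shape, recovered.shape)  # (8, 8) (4, 4)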
read_PLY(fn, offset=[0, 0, 0], angles=[0.0, 0.0, 0.0], mode='XYZ')

Definition to read a PLY file and extract meshes from it. Note that rotation is always with respect to 0,0,0.

Parameters:

  • fn     (str)     – Filename of a PLY file.
  • offset (ndarray) – Offset in X,Y,Z.
  • angles (list)    – Rotation angles in degrees.
  • mode   (str)     – Rotation mode determines the ordering of the rotations at each axis. There are XYZ, YXZ, ZXY and ZYX modes.

Returns:

  • triangles (ndarray) – Triangles from the given PLY file. Note that the triangles coming out of this function aren't always structured in the right order and with the size of (MxN)x3. You can use numpy's reshape to restructure them to MxNx3 if you know what you are doing.

Source code in odak/tools/asset.py
def read_PLY(fn, offset=[0, 0, 0], angles=[0., 0., 0.], mode='XYZ'):
+    """
+    Definition to read a PLY file and extract meshes from a given PLY file. Note that rotation is always with respect to 0,0,0.
+
+    Parameters
+    ----------
+    fn           : string
+                   Filename of a PLY file.
+    offset       : ndarray
+                   Offset in X,Y,Z.
+    angles       : list
+                   Rotation angles in degrees.
+    mode         : str
+                   Rotation mode determines ordering of the rotations at each axis. There are XYZ,YXZ,ZXY and ZYX modes. 
+
+    Returns
+    ----------
+    triangles    : ndarray
+                  Triangles from a given PLY file. Note that the triangles coming out of this function aren't always structured in the right order and with the size of (MxN)x3. You can use numpy's reshape to restructure them to MxNx3 if you know what you are doing.
+    """
+    if np.__name__ != 'numpy':
+        import numpy as np_ply
+    else:
+        np_ply = np
+    with open(fn, 'rb') as f:
+        plydata = PlyData.read(f)
+    triangle_ids = np_ply.vstack(plydata['face'].data['vertex_indices'])
+    triangles = []
+    for vertex_ids in triangle_ids:
+        triangle = [
+            rotate_point(plydata['vertex'][int(vertex_ids[0])
+                                           ].tolist(), angles=angles, offset=offset)[0],
+            rotate_point(plydata['vertex'][int(vertex_ids[1])
+                                           ].tolist(), angles=angles, offset=offset)[0],
+            rotate_point(plydata['vertex'][int(vertex_ids[2])
+                                           ].tolist(), angles=angles, offset=offset)[0]
+        ]
+        triangle = np_ply.asarray(triangle)
+        triangles.append(triangle)
+    triangles = np_ply.array(triangles)
+    triangles = np.asarray(triangles, dtype=np.float32)
+    return triangles
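A hedged usage sketch, assuming read_PLY is importable from odak.tools and that an input mesh exists at the hypothetical path 'input.ply': load the mesh while offsetting it along Z and rotating it about Y.

from odak.tools import read_PLY

triangles = read_PLY('input.ply', offset=[0., 0., 5.], angles=[0., 90., 0.], mode='XYZ')
print(triangles.shape)  # (number of faces, 3, 3) as float32.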
read_PLY_point_cloud(filename)

Definition to read a PLY file as a point cloud.

Parameters:

  • filename (str) – Filename of a PLY file.

Returns:

  • point_cloud (ndarray) – An array filled with points from the PLY file.

Source code in odak/tools/asset.py
def read_PLY_point_cloud(filename):
+    """
+    Definition to read a PLY file as a point cloud.
+
+    Parameters
+    ----------
+    filename     : str
+                   Filename of a PLY file.
+
+    Returns
+    ----------
+    point_cloud  : ndarray
+                   An array filled with points from the PLY file.
+    """
+    plydata = PlyData.read(filename)
+    if np.__name__ != 'numpy':
+        import numpy as np_ply
+        point_cloud = np_ply.zeros((plydata['vertex'][:].shape[0], 3))
+        point_cloud[:, 0] = np_ply.asarray(plydata['vertex']['x'][:])
+        point_cloud[:, 1] = np_ply.asarray(plydata['vertex']['y'][:])
+        point_cloud[:, 2] = np_ply.asarray(plydata['vertex']['z'][:])
+        point_cloud = np.asarray(point_cloud)
+    else:
+        point_cloud = np.zeros((plydata['vertex'][:].shape[0], 3))
+        point_cloud[:, 0] = np.asarray(plydata['vertex']['x'][:])
+        point_cloud[:, 1] = np.asarray(plydata['vertex']['y'][:])
+        point_cloud[:, 2] = np.asarray(plydata['vertex']['z'][:])
+    return point_cloud
convert_to_numpy(a)

A definition to convert Torch to Numpy.

Parameters:

  • a (torch.Tensor) – Input Torch array.

Returns:

  • b (numpy.ndarray) – Converted array.

Source code in odak/tools/conversions.py
def convert_to_numpy(a):
+    """
+    A definition to convert Torch to Numpy.
+
+    Parameters
+    ----------
+    a          : torch.Tensor
+                 Input Torch array.
+
+    Returns
+    ----------
+    b          : numpy.ndarray
+                 Converted array.
+    """
+    b = a.to('cpu').detach().numpy()
+    return b
convert_to_torch(a, grad=True)

A definition to convert Numpy arrays to Torch.

Parameters:

  • a    (ndarray) – Input Numpy array.
  • grad (bool)    – Set if the converted array requires gradient.

Returns:

  • c (torch.Tensor) – Converted array.

Source code in odak/tools/conversions.py
def convert_to_torch(a, grad=True):
+    """
+    A definition to convert Numpy arrays to Torch.
+
+    Parameters
+    ----------
+    a          : ndarray
+                 Input Numpy array.
+    grad       : bool
+                 Set if the converted array requires gradient.
+
+    Returns
+    ----------
+    c          : torch.Tensor
+                 Converted array.
+    """
+    b = np.copy(a)
+    c = torch.from_numpy(b)
+    c.requires_grad_(grad)
+    return c
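A round-trip sketch between Numpy and Torch, assuming both converters are importable from odak.tools and that torch is installed.

import numpy as np
from odak.tools import convert_to_torch, convert_to_numpy

a = np.random.rand(4, 4)
t = convert_to_torch(a, grad=True)  # torch.Tensor with requires_grad enabled.
b = convert_to_numpy(t)             # Detached CPU copy as numpy.ndarray.
print(np.allclose(a, b))            # True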
check_directory(directory)

Definition to check if a directory exists. If it doesn't exist, this definition will create one.

Parameters:

  • directory (str) – Full directory path.

Source code in odak/tools/file.py
def check_directory(directory):
+    """
+    Definition to check if a directory exist. If it doesn't exist, this definition will create one.
+
+
+    Parameters
+    ----------
+    directory     : str
+                    Full directory path.
+    """
+    if not os.path.exists(expanduser(directory)):
+        os.makedirs(expanduser(directory))
+        return False
+    return True
convert_bytes(num)

A definition to convert a size in bytes to a human-readable unit (KB, MB, GB or alike). Inspired from https://stackoverflow.com/questions/2104080/how-can-i-check-file-size-in-python#2104083.

Parameters:

  • num (float) – Size in bytes.

Returns:

  • num (float) – Size in the new unit.
  • x   (str)   – New unit: bytes, KB, MB, GB or TB.

Source code in odak/tools/file.py
def convert_bytes(num):
+    """
+    A definition to convert bytes to semantic scheme (MB,GB or alike). Inspired from https://stackoverflow.com/questions/2104080/how-can-i-check-file-size-in-python#2104083.
+
+
+    Parameters
+    ----------
+    num        : float
+                 Size in bytes
+
+
+    Returns
+    ----------
+    num        : float
+                 Size in new unit.
+    x          : str
+                 New unit bytes, KB, MB, GB or TB.
+    """
+    for x in ['bytes', 'KB', 'MB', 'GB', 'TB']:
+        if num < 1024.0:
+            return num, x
+        num /= 1024.0
+    return None, None
copy_file(source, destination, follow_symlinks=True)

Definition to copy a file from one location to another.

Parameters:

  • source          (str)                 – Source filename.
  • destination     (str)                 – Destination filename.
  • follow_symlinks (bool, default: True) – Set to True to follow the source of symbolic links.

Source code in odak/tools/file.py
def copy_file(source, destination, follow_symlinks = True):
+    """
+    Definition to copy a file from one location to another.
+
+
+
+    Parameters
+    ----------
+    source          : str
+                      Source filename.
+    destination     : str
+                      Destination filename.
+    follow_symlinks : bool
+                      Set to True to follow the source of symbolic links.
+    """
+    return shutil.copyfile(
+                           expanduser(source),
+                           expanduser(destination),
+                           follow_symlinks = follow_symlinks
+                          )
expanduser(filename)

Definition to decode a filename using namespaces and shortcuts.

Parameters:

  • filename (str) – Filename.

Returns:

  • new_filename (str) – Expanded filename.

Source code in odak/tools/file.py
def expanduser(filename):
+    """
+    Definition to decode filename using namespaces and shortcuts.
+
+
+    Parameters
+    ----------
+    filename      : str
+                    Filename.
+
+
+    Returns
+    -------
+    new_filename  : str
+                    Filename.
+    """
+    new_filename = os.path.expanduser(filename)
+    return new_filename
list_files(path, key='*.*', recursive=True)

Definition to list files in a given path with a given key.

Parameters:

  • path      (str)  – Path to a folder.
  • key       (str)  – Key used for scanning a path.
  • recursive (bool) – If set True, scan the path recursively.

Returns:

  • files_list (list) – Sorted list of files found in the given path.

Source code in odak/tools/file.py
def list_files(path, key = '*.*', recursive = True):
+    """
+    Definition to list files in a given path with a given key.
+
+
+    Parameters
+    ----------
+    path        : str
+                  Path to a folder.
+    key         : str
+                  Key used for scanning a path.
+    recursive   : bool
+                  If set True, scan the path recursively.
+
+
+    Returns
+    ----------
+    files_list  : list
+                  Sorted list of files found in the given path.
+    """
+    if recursive == True:
+        search_result = pathlib.Path(expanduser(path)).rglob(key)
+    elif recursive == False:
+        search_result = pathlib.Path(expanduser(path)).glob(key)
+    files_list = []
+    for item in search_result:
+        files_list.append(str(item))
+    files_list = sorted(files_list)
+    return files_list
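A hedged usage sketch, assuming list_files is importable from odak.tools; './results' and the '*.png' key are illustrative values.

from odak.tools import list_files

png_files = list_files('./results', key='*.png', recursive=True)
for fn in png_files:
    print(fn)  # Paths returned in sorted order.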
load_dictionary(filename)

Definition to load a dictionary (JSON) file.

Parameters:

  • filename (str) – Filename.

Returns:

  • settings (dict) – Dictionary read from the file.

Source code in odak/tools/file.py
def load_dictionary(filename):
+    """
+    Definition to load a dictionary (JSON) file.
+
+
+    Parameters
+    ----------
+    filename      : str
+                    Filename.
+
+
+    Returns
+    ----------
+    settings      : dict
+                    Dictionary read from the file.
+
+    """
+    settings = json.load(open(expanduser(filename)))
+    return settings
load_image(fn, normalizeby=0.0, torch_style=False)

Definition to load an image from a given location as a Numpy array.

Parameters:

  • fn          (str)   – Filename.
  • normalizeby (float) – Value to normalize images with. The default value of zero will lead to no normalization.
  • torch_style (bool)  – If set True, it will load an mxnx3 image as 3xmxn.

Returns:

  • image (ndarray) – Image loaded as a Numpy array.

Source code in odak/tools/file.py
def load_image(fn, normalizeby = 0., torch_style = False):
+    """ 
+    Definition to load an image from a given location as a Numpy array.
+
+
+    Parameters
+    ----------
+    fn           : str
+                   Filename.
+    normalizeby  : float
+                   Value to normalize images with. Default value of zero will lead to no normalization.
+    torch_style  : bool
+                   If set True, it will load an image mxnx3 as 3xmxn.
+
+
+    Returns
+    ----------
+    image        :  ndarray
+                    Image loaded as a Numpy array.
+
+    """
+    image = cv2.imread(expanduser(fn), cv2.IMREAD_UNCHANGED)
+    if isinstance(image, type(None)):
+         logging.warning('Image not properly loaded. Check filename or image type.')    
+         sys.exit()
+    if len(image.shape) > 2:
+        new_image = np.copy(image)
+        new_image[:, :, 0] = image[:, :, 2]
+        new_image[:, :, 2] = image[:, :, 0]
+        image = new_image
+    if normalizeby != 0.:
+        image = image * 1. / normalizeby
+    if torch_style == True and len(image.shape) > 2:
+        image = np.moveaxis(image, -1, 0)
+    return image.astype(float)
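A minimal sketch, assuming load_image is importable from odak.tools, OpenCV is installed, and an 8-bit RGB image exists at the hypothetical path 'photo.png'.

from odak.tools import load_image

image = load_image('photo.png', normalizeby=255., torch_style=True)
print(image.shape)  # (3, height, width) for an RGB input, values scaled to 0-1.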
read_text_file(filename)

Definition to read a given text file and convert it into a Pythonic list.

Parameters:

  • filename (str) – Source filename (i.e. test.txt).

Returns:

  • content (list) – Pythonic string list containing the text from the file provided.

Source code in odak/tools/file.py
def read_text_file(filename):
+    """
+    Definition to read a given text file and convert it into a Pythonic list.
+
+
+    Parameters
+    ----------
+    filename        : str
+                      Source filename (i.e. test.txt).
+
+
+    Returns
+    -------
+    content         : list
+                      Pythonic string list containing the text from the file provided.
+    """
+    content = []
+    loaded_file = open(expanduser(filename))
+    while line := loaded_file.readline():
+        content.append(line.rstrip())
+    return content
resize_image(img, target_size)

Definition to resize a given image to a target shape.

Parameters:

  • img         (ndarray) – MxN image to be resized. Image must be normalized (0-1).
  • target_size (list)    – Target shape.

Returns:

  • img (ndarray) – Resized image.

Source code in odak/tools/file.py
def resize_image(img, target_size):
+    """
+    Definition to resize a given image to a target shape.
+
+
+    Parameters
+    ----------
+    img           : ndarray
+                    MxN image to be resized.
+                    Image must be normalized (0-1).
+    target_size   : list
+                    Target shape.
+
+
+    Returns
+    ----------
+    img           : ndarray
+                    Resized image.
+
+    """
+    img = cv2.resize(img, dsize=(target_size[0], target_size[1]), interpolation=cv2.INTER_AREA)
+    return img
save_dictionary(settings, filename)

Definition to save a dictionary as a JSON file.

Parameters:

  • settings (dict) – Dictionary to be saved.
  • filename (str)  – Destination filename.

Source code in odak/tools/file.py
def save_dictionary(settings, filename):
+    """
+    Definition to save a dictionary as a JSON file.
+
+
+    Parameters
+    ----------
+    settings      : dict
+                    Dictionary to be saved.
+    filename      : str
+                    Destination filename.
+    """
+    with open(expanduser(filename), 'w', encoding='utf-8') as f:
+        json.dump(settings, f, ensure_ascii=False, indent=4)
+    return settings
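A small round trip, assuming save_dictionary and load_dictionary are importable from odak.tools; the settings content and 'settings.json' filename are illustrative.

from odak.tools import save_dictionary, load_dictionary

settings = {'resolution': [1920, 1080], 'wavelength': 5.15e-07}
save_dictionary(settings, 'settings.json')
restored = load_dictionary('settings.json')
print(restored['resolution'])  # [1920, 1080]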
save_image(fn, img, cmin=0, cmax=255, color_depth=8)

Definition to save a Numpy array as an image.

Parameters:

  • fn          (str)     – Filename.
  • img         (ndarray) – A numpy array with NxMx3 or NxMx1 shapes.
  • cmin        (int)     – Minimum value that will be interpreted as 0 level in the final image.
  • cmax        (int)     – Maximum value that will be interpreted as 255 level in the final image.
  • color_depth (int)     – Pixel color depth in bits, default is eight bits.

Returns:

  • bool (bool) – True if successful.

Source code in odak/tools/file.py
def save_image(fn, img, cmin = 0, cmax = 255, color_depth = 8):
+    """
+    Definition to save a Numpy array as an image.
+
+
+    Parameters
+    ----------
+    fn           : str
+                   Filename.
+    img          : ndarray
+                   A numpy array with NxMx3 or NxMx1 shapes.
+    cmin         : int
+                   Minimum value that will be interpreted as 0 level in the final image.
+    cmax         : int
+                   Maximum value that will be interpreted as 255 level in the final image.
+    color_depth  : int
+                   Pixel color depth in bits, default is eight bits.
+
+
+    Returns
+    ----------
+    bool         :  bool
+                    True if successful.
+
+    """
+    input_img = np.copy(img).astype(np.float32)
+    cmin = float(cmin)
+    cmax = float(cmax)
+    input_img[input_img < cmin] = cmin
+    input_img[input_img > cmax] = cmax
+    input_img /= cmax
+    input_img = input_img * 1. * (2**color_depth - 1)
+    if color_depth == 8:
+        input_img = input_img.astype(np.uint8)
+    elif color_depth == 16:
+        input_img = input_img.astype(np.uint16)
+    if len(input_img.shape) > 2:
+        if input_img.shape[2] > 1:
+            cache_img = np.copy(input_img)
+            cache_img[:, :, 0] = input_img[:, :, 2]
+            cache_img[:, :, 2] = input_img[:, :, 0]
+            input_img = cache_img
+    cv2.imwrite(expanduser(fn), input_img)
+    return True
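A minimal sketch, assuming save_image is importable from odak.tools and OpenCV is installed: save a random field in the 0-1 range as a 16-bit image; 'field.png' is an illustrative filename.

import numpy as np
from odak.tools import save_image

field = np.random.rand(256, 256)
save_image('field.png', field, cmin=0, cmax=1, color_depth=16)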
shell_command(cmd, cwd='.', timeout=None, check=True)

Definition to initiate shell commands.

Parameters:

  • cmd     (list) – Command to be executed.
  • cwd     (str)  – Working directory.
  • timeout (int)  – Timeout if the process isn't complete in the given number of seconds.
  • check   (bool) – Set it to True to return the results and to enable timeout.

Returns:

  • proc (subprocess.Popen) – Generated process.
  • outs (str)              – Outputs of the executed command, returns None when check is set to False.
  • errs (str)              – Errors of the executed command, returns None when check is set to False.

Source code in odak/tools/file.py
def shell_command(cmd, cwd = '.', timeout = None, check = True):
+    """
+    Definition to initiate shell commands.
+
+
+    Parameters
+    ----------
+    cmd          : list
+                   Command to be executed. 
+    cwd          : str
+                   Working directory.
+    timeout      : int
+                   Timeout if the process isn't complete in the given number of seconds.
+    check        : bool
+                   Set it to True to return the results and to enable timeout.
+
+
+    Returns
+    ----------
+    proc         : subprocess.Popen
+                   Generated process.
+    outs         : str
+                   Outputs of the executed command, returns None when check is set to False.
+    errs         : str
+                   Errors of the executed command, returns None when check is set to False.
+
+    """
+    for item_id in range(len(cmd)):
+        cmd[item_id] = expanduser(cmd[item_id])
+    proc = subprocess.Popen(
+                            cmd,
+                            cwd = cwd,
+                            stdout = subprocess.PIPE
+                           )
+    if check == False:
+        return proc, None, None
+    try:
+        outs, errs = proc.communicate(timeout = timeout)
+    except subprocess.TimeoutExpired:
+        proc.kill()
+        outs, errs = proc.communicate()
+    return proc, outs, errs
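A hedged usage sketch, assuming shell_command is importable from odak.tools; the 'ls -l' command is an illustrative example of a short, quickly-terminating process.

from odak.tools import shell_command

proc, outs, errs = shell_command(['ls', '-l'], cwd='.', timeout=10, check=True)
print(outs.decode())  # Captured standard output as text.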
size_of_a_file(file_path)

A definition to get the size of a file with a relevant unit.

Parameters:

  • file_path (str) – Path of the file.

Returns:

  • a (float) – Size of the file.
  • b (str)   – Unit of the size (bytes, KB, MB, GB or TB).

Source code in odak/tools/file.py
def size_of_a_file(file_path):
+    """
+    A definition to get size of a file with a relevant unit.
+
+
+    Parameters
+    ----------
+    file_path  : str
+                 Path of the file.
+
+
+    Returns
+    ----------
+    a          : float
+                 Size of the file.
+    b          : str
+                 Unit of the size (bytes, KB, MB, GB or TB).
+    """
+    if os.path.isfile(file_path):
+        file_info = os.stat(file_path)
+        a, b = convert_bytes(file_info.st_size)
+        return a, b
+    return None, None
latex(filename)

A class to work with latex documents.

Source code in odak/tools/latex.py
class latex():
+    """
+    A class to work with latex documents.
+    """
+    def __init__(
+                 self,
+                 filename
+                ):
+        """
+        Parameters
+        ----------
+        filename     : str
+                       Source filename (i.e. sample.tex).
+        """
+        self.filename = filename
+        self.content = read_text_file(self.filename)
+        self.content_type = []
+        self.latex_dictionary = [
+                                 '\\documentclass',
+                                 '\\if',
+                                 '\\pdf',
+                                 '\\else',
+                                 '\\fi',
+                                 '\\vgtc',
+                                 '\\teaser',
+                                 '\\abstract',
+                                 '\\CCS',
+                                 '\\usepackage',
+                                 '\\PassOptionsToPackage',
+                                 '\\definecolor',
+                                 '\\AtBeginDocument',
+                                 '\\providecommand',
+                                 '\\setcopyright',
+                                 '\\copyrightyear',
+                                 '\\acmYear',
+                                 '\\citestyle',
+                                 '\\newcommand',
+                                 '\\acmDOI',
+                                 '\\newabbreviation',
+                                 '\\global',
+                                 '\\begin{document}',
+                                 '\\author',
+                                 '\\affiliation',
+                                 '\\email',
+                                 '\\institution',
+                                 '\\streetaddress',
+                                 '\\city',
+                                 '\\country',
+                                 '\\postcode',
+                                 '\\ccsdesc',
+                                 '\\received',
+                                 '\\includegraphics',
+                                 '\\caption',
+                                 '\\centering',
+                                 '\\label',
+                                 '\\maketitle',
+                                 '\\toprule',
+                                 '\\multirow',
+                                 '\\multicolumn',
+                                 '\\cmidrule',
+                                 '\\addlinespace',
+                                 '\\midrule',
+                                 '\\cellcolor',
+                                 '\\bibliography',
+                                 '}',
+                                 '\\title',
+                                 '</ccs2012>',
+                                 '\\bottomrule',
+                                 '<concept>',
+                                 '<concept',
+                                 '<ccs',
+                                 '\\item',
+                                 '</concept',
+                                 '\\begin{abstract}',
+                                 '\\end{abstract}',
+                                 '\\endinput',
+                                 '\\\\'
+                                ]
+        self.latex_begin_dictionary = [
+                                       '\\begin{figure}',
+                                       '\\begin{figure*}',
+                                       '\\begin{equation}',
+                                       '\\begin{CCSXML}',
+                                       '\\begin{teaserfigure}',
+                                       '\\begin{table*}',
+                                       '\\begin{table}',
+                                       '\\begin{gather}',
+                                       '\\begin{align}',
+                                      ]
+        self.latex_end_dictionary = [
+                                     '\\end{figure}',
+                                     '\\end{figure*}',
+                                     '\\end{equation}',
+                                     '\\end{CCSXML}',
+                                     '\\end{teaserfigure}',
+                                     '\\end{table*}',
+                                     '\\end{table}',
+                                     '\\end{gather}',
+                                     '\\end{align}',
+                                    ]
+        self._label_lines()
+
+
+    def set_latex_dictonaries(self, begin_dictionary, end_dictionary, syntax_dictionary):
+        """
+        Set document specific dictionaries so that the lines could be labelled in accordance.
+
+
+        Parameters
+        ----------
+        begin_dictionary     : list
+                               Pythonic list containing latex syntax for begin commands (i.e. \\begin{align}).
+        end_dictionary       : list
+                               Pythonic list containing latex syntax for end commands (i.e. \\end{table}).
+        syntax_dictionary    : list
+                               Pythonic list containing latex syntax (i.e. \\item).
+
+        """
+        self.latex_begin_dictionary = begin_dictionary
+        self.latex_end_dictionary = end_dictionary
+        self.latex_dictionary = syntax_dictionary
+        self.content_type = []
+        self._label_lines()
+
+
+    def _label_lines(self):
+        """
+        Internal function for labelling lines.
+        """
+        content_type_flag = False
+        for line_id, line in enumerate(self.content):
+            while len(line) > 0 and line[0] == ' ':
+                 line = line[1::]
+            self.content[line_id] = line
+            if len(line) == 0:
+                content_type = 'empty'
+            elif line[0] == '%':
+                content_type = 'comment'
+            else:
+                content_type = 'text'
+            for syntax in self.latex_begin_dictionary:
+                if line.find(syntax) != -1:
+                    content_type_flag = True
+                    content_type = 'latex'
+            for syntax in self.latex_dictionary:
+                if line.find(syntax) != -1:
+                    content_type = 'latex'
+            if content_type_flag == True:
+                content_type = 'latex'
+                for syntax in self.latex_end_dictionary:
+                    if line.find(syntax) != -1:
+                         content_type_flag = False
+            self.content_type.append(content_type)
+
+
+    def get_line_count(self):
+        """
+        Definition to get the line count.
+
+
+        Returns
+        -------
+        line_count     : int
+                         Number of lines in the loaded latex document.
+        """
+        self.line_count = len(self.content)
+        return self.line_count
+
+
+    def get_line(self, line_id = 0):
+        """
+        Definition to get a specific line by inputting a line number.
+
+
+        Returns
+        ----------
+        line           : str
+                         Requested line.
+        content_type   : str
+                         Line's content type (e.g., latex, comment, text).
+        """
+        line = self.content[line_id]
+        content_type = self.content_type[line_id]
+        return line, content_type
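A minimal sketch of the class above, assuming the latex class is importable from odak.tools and that a file exists at the hypothetical path 'manuscript.tex': label every line and count the ones classified as plain text.

from odak.tools import latex

document = latex('manuscript.tex')
text_lines = 0
for line_id in range(document.get_line_count()):
    line, content_type = document.get_line(line_id)
    if content_type == 'text':
        text_lines += 1
print(text_lines)  # Number of lines that are neither latex syntax, comments nor empty.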
__init__(filename)

Parameters:

  • filename (str) – Source filename (i.e. sample.tex).

Source code in odak/tools/latex.py
def __init__(
+             self,
+             filename
+            ):
+    """
+    Parameters
+    ----------
+    filename     : str
+                   Source filename (i.e. sample.tex).
+    """
+    self.filename = filename
+    self.content = read_text_file(self.filename)
+    self.content_type = []
+    self.latex_dictionary = [
+                             '\\documentclass',
+                             '\\if',
+                             '\\pdf',
+                             '\\else',
+                             '\\fi',
+                             '\\vgtc',
+                             '\\teaser',
+                             '\\abstract',
+                             '\\CCS',
+                             '\\usepackage',
+                             '\\PassOptionsToPackage',
+                             '\\definecolor',
+                             '\\AtBeginDocument',
+                             '\\providecommand',
+                             '\\setcopyright',
+                             '\\copyrightyear',
+                             '\\acmYear',
+                             '\\citestyle',
+                             '\\newcommand',
+                             '\\acmDOI',
+                             '\\newabbreviation',
+                             '\\global',
+                             '\\begin{document}',
+                             '\\author',
+                             '\\affiliation',
+                             '\\email',
+                             '\\institution',
+                             '\\streetaddress',
+                             '\\city',
+                             '\\country',
+                             '\\postcode',
+                             '\\ccsdesc',
+                             '\\received',
+                             '\\includegraphics',
+                             '\\caption',
+                             '\\centering',
+                             '\\label',
+                             '\\maketitle',
+                             '\\toprule',
+                             '\\multirow',
+                             '\\multicolumn',
+                             '\\cmidrule',
+                             '\\addlinespace',
+                             '\\midrule',
+                             '\\cellcolor',
+                             '\\bibliography',
+                             '}',
+                             '\\title',
+                             '</ccs2012>',
+                             '\\bottomrule',
+                             '<concept>',
+                             '<concept',
+                             '<ccs',
+                             '\\item',
+                             '</concept',
+                             '\\begin{abstract}',
+                             '\\end{abstract}',
+                             '\\endinput',
+                             '\\\\'
+                            ]
+    self.latex_begin_dictionary = [
+                                   '\\begin{figure}',
+                                   '\\begin{figure*}',
+                                   '\\begin{equation}',
+                                   '\\begin{CCSXML}',
+                                   '\\begin{teaserfigure}',
+                                   '\\begin{table*}',
+                                   '\\begin{table}',
+                                   '\\begin{gather}',
+                                   '\\begin{align}',
+                                  ]
+    self.latex_end_dictionary = [
+                                 '\\end{figure}',
+                                 '\\end{figure*}',
+                                 '\\end{equation}',
+                                 '\\end{CCSXML}',
+                                 '\\end{teaserfigure}',
+                                 '\\end{table*}',
+                                 '\\end{table}',
+                                 '\\end{gather}',
+                                 '\\end{align}',
+                                ]
+    self._label_lines()
get_line(line_id=0)

Definition to get a specific line by inputting a line number.

Returns:

  • line         (str) – Requested line.
  • content_type (str) – Line's content type (e.g., latex, comment, text).

Source code in odak/tools/latex.py
def get_line(self, line_id = 0):
+    """
+    Definition to get a specific line by inputting a line number.
+
+
+    Returns
+    ----------
+    line           : str
+                     Requested line.
+    content_type   : str
+                     Line's content type (e.g., latex, comment, text).
+    """
+    line = self.content[line_id]
+    content_type = self.content_type[line_id]
+    return line, content_type
get_line_count()

Definition to get the line count.

Returns:

  • line_count (int) – Number of lines in the loaded latex document.

Source code in odak/tools/latex.py
def get_line_count(self):
+    """
+    Definition to get the line count.
+
+
+    Returns
+    -------
+    line_count     : int
+                     Number of lines in the loaded latex document.
+    """
+    self.line_count = len(self.content)
+    return self.line_count
set_latex_dictonaries(begin_dictionary, end_dictionary, syntax_dictionary)

Set document specific dictionaries so that the lines can be labelled accordingly.

Parameters:

  • begin_dictionary  (list) – Pythonic list containing latex syntax for begin commands (i.e. \begin{align}).
  • end_dictionary    (list) – Pythonic list containing latex syntax for end commands (i.e. \end{table}).
  • syntax_dictionary (list) – Pythonic list containing latex syntax (i.e. \item).

Source code in odak/tools/latex.py
def set_latex_dictonaries(self, begin_dictionary, end_dictionary, syntax_dictionary):
+    """
+    Set document specific dictionaries so that the lines could be labelled in accordance.
+
+
+    Parameters
+    ----------
+    begin_dictionary     : list
+                           Pythonic list containing latex syntax for begin commands (i.e. \\begin{align}).
+    end_dictionary       : list
+                           Pythonic list containing latex syntax for end commands (i.e. \\end{table}).
+    syntax_dictionary    : list
+                           Pythonic list containing latex syntax (i.e. \\item).
+
+    """
+    self.latex_begin_dictionary = begin_dictionary
+    self.latex_end_dictionary = end_dictionary
+    self.latex_dictionary = syntax_dictionary
+    self.content_type = []
+    self._label_lines()
blur_gaussian(field, kernel_length=[21, 21], nsigma=[3, 3])

A definition to blur a field using a Gaussian kernel.

Parameters:

  • field         (ndarray)                 – MxN field.
  • kernel_length (list, default: [21, 21]) – Length of the Gaussian kernel along X and Y axes.
  • nsigma        (list, default: [3, 3])   – Sigma of the Gaussian kernel along X and Y axes.

Returns:

  • blurred_field (ndarray) – Blurred field.

Source code in odak/tools/matrix.py
def blur_gaussian(field, kernel_length=[21, 21], nsigma=[3, 3]):
+    """
+    A definition to blur a field using a Gaussian kernel.
+
+    Parameters
+    ----------
+    field         : ndarray
+                    MxN field.
+    kernel_length : list
+                    Length of the Gaussian kernel along X and Y axes.
+    nsigma        : list
+                    Sigma of the Gaussian kernel along X and Y axes.
+
+    Returns
+    ----------
+    blurred_field : ndarray
+                    Blurred field.
+    """
+    kernel = generate_2d_gaussian(kernel_length, nsigma)
+    kernel = zero_pad(kernel, field.shape)
+    blurred_field = convolve2d(field, kernel)
+    blurred_field = blurred_field/np.amax(blurred_field)
+    return blurred_field
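A minimal sketch, assuming blur_gaussian is importable from odak.tools: blur a random field with the default 21 x 21 kernel; the 512 x 512 size is illustrative.

import numpy as np
from odak.tools import blur_gaussian

field = np.random.rand(512, 512)
blurred = blur_gaussian(field, kernel_length=[21, 21], nsigma=[3, 3])
print(blurred.shape)  # (512, 512), peak normalised to 1.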
convolve2d(field, kernel)

Definition to convolve a field with a kernel by multiplying in frequency space.

Parameters:

  • field  (ndarray) – Input field with MxN shape.
  • kernel (ndarray) – Input kernel with MxN shape.

Returns:

  • new_field (ndarray) – Convolved field.

Source code in odak/tools/matrix.py
def convolve2d(field, kernel):
+    """
+    Definition to convolve a field with a kernel by multiplying in frequency space.
+
+    Parameters
+    ----------
+    field       : ndarray
+                  Input field with MxN shape.
+    kernel      : ndarray
+                  Input kernel with MxN shape.
+
+    Returns
+    ----------
+    new_field   : ndarray
+                  Convolved field.
+    """
+    fr = np.fft.fft2(field)
+    fr2 = np.fft.fft2(np.flipud(np.fliplr(kernel)))
+    m, n = fr.shape
+    new_field = np.real(np.fft.ifft2(fr*fr2))
+    new_field = np.roll(new_field, int(-m/2+1), axis=0)
+    new_field = np.roll(new_field, int(-n/2+1), axis=1)
+    return new_field
create_empty_list(dimensions=[1, 1])

A definition to create an empty Pythonic list.

Parameters:

  • dimensions (list) – Dimensions of the list to be created.

Returns:

  • new_list (list) – New empty list.

Source code in odak/tools/matrix.py
def create_empty_list(dimensions = [1, 1]):
+    """
+    A definition to create an empty Pythonic list.
+
+    Parameters
+    ----------
+    dimensions   : list
+                   Dimensions of the list to be created.
+
+    Returns
+    -------
+    new_list     : list
+                   New empty list.
+    """
+    new_list = 0
+    for n in reversed(dimensions):
+        new_list = [new_list] * n
+    return new_list
crop_center(field, size=None)

Definition to crop the center of a field with 2Mx2N size. The outcome is an MxN array.

Parameters:

  • field (ndarray) – Input field, a 2Mx2N array.
  • size  (list)    – Size of the cropped region; when None, half of the field size is used.

Returns:

  • cropped (ndarray) – Cropped version of the input field.

Source code in odak/tools/matrix.py
def crop_center(field, size=None):
+    """
+    Definition to crop the center of a field with 2Mx2N size. The outcome is a MxN array.
+
+    Parameters
+    ----------
+    field       : ndarray
+                  Input field 2Mx2N array.
+
+    Returns
+    ----------
+    cropped     : ndarray
+                  Cropped version of the input field.
+    """
+    if type(size) == type(None):
+        qx = int(np.ceil(field.shape[0])/4)
+        qy = int(np.ceil(field.shape[1])/4)
+        cropped = np.copy(field[qx:3*qx, qy:3*qy])
+    else:
+        cx = int(np.ceil(field.shape[0]/2))
+        cy = int(np.ceil(field.shape[1]/2))
+        hx = int(np.ceil(size[0]/2))
+        hy = int(np.ceil(size[1]/2))
+        cropped = np.copy(field[cx-hx:cx+hx, cy-hy:cy+hy])
+    return cropped
generate_2d_gaussian(kernel_length=[21, 21], nsigma=[3, 3])

Generate a 2D Gaussian kernel. Inspired from https://stackoverflow.com/questions/29731726/how-to-calculate-a-gaussian-kernel-matrix-efficiently-in-numpy

Parameters:

  • kernel_length (list, default: [21, 21]) – Length of the Gaussian kernel along X and Y axes.
  • nsigma        (list, default: [3, 3])   – Sigma of the Gaussian kernel along X and Y axes.

Returns:

  • kernel_2d (ndarray) – Generated Gaussian kernel.

Source code in odak/tools/matrix.py
def generate_2d_gaussian(kernel_length=[21, 21], nsigma=[3, 3]):
+    """
+    Generate 2D Gaussian kernel. Inspired from https://stackoverflow.com/questions/29731726/how-to-calculate-a-gaussian-kernel-matrix-efficiently-in-numpy
+
+    Parameters
+    ----------
+    kernel_length : list
+                    Length of the Gaussian kernel along X and Y axes.
+    nsigma        : list
+                    Sigma of the Gaussian kernel along X and Y axes.
+
+    Returns
+    ----------
+    kernel_2d     : ndarray
+                    Generated Gaussian kernel.
+    """
+    x = np.linspace(-nsigma[0], nsigma[0], kernel_length[0]+1)
+    y = np.linspace(-nsigma[1], nsigma[1], kernel_length[1]+1)
+    xx, yy = np.meshgrid(x, y)
+    kernel_2d = np.exp(-0.5*(np.square(xx) /
+                       np.square(nsigma[0]) + np.square(yy)/np.square(nsigma[1])))
+    kernel_2d = kernel_2d/kernel_2d.sum()
+    return kernel_2d
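A minimal usage sketch, assuming odak is installed:

from odak.tools.matrix import generate_2d_gaussian

kernel = generate_2d_gaussian(kernel_length = [21, 21], nsigma = [3, 3])
# kernel sums to one; note from the listing above that the returned array has
# kernel_length + 1 samples along each axis (22x22 for these defaults).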
+
+
+
+ +
+ +
+ + +

+ generate_bandlimits(size=[512, 512], levels=9) + +

+ + +
+ +

A definition to calculate octaves used in bandlimiting frequencies in the frequency domain.

+ + +

Parameters:

+
    +
  • + size + – +
    +
         Size of each mask in octaves.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +masks ( ndarray +) – +
    +

    Masks (Octaves).

    +
    +
  • +
+ +
+ Source code in odak/tools/matrix.py +
def generate_bandlimits(size=[512, 512], levels=9):
+    """
+    A definition to calculate octaves used in bandlimiting frequencies in the frequency domain.
+
+    Parameters
+    ----------
+    size       : list
+                 Size of each mask in octaves.
+
+    Returns
+    ----------
+    masks      : ndarray
+                 Masks (Octaves).
+    """
+    masks = np.zeros((levels, size[0], size[1]))
+    cx = int(size[0]/2)
+    cy = int(size[1]/2)
+    for i in range(0, masks.shape[0]):
+        deltax = int((size[0])/(2**(i+1)))
+        deltay = int((size[1])/(2**(i+1)))
+        masks[
+            i,
+            cx-deltax:cx+deltax,
+            cy-deltay:cy+deltay
+        ] = 1.
+        masks[
+            i,
+            int(cx-deltax/2.):int(cx+deltax/2.),
+            int(cy-deltay/2.):int(cy+deltay/2.)
+        ] = 0.
+    masks = np.asarray(masks)
+    return masks
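A minimal usage sketch, assuming odak is installed:

from odak.tools.matrix import generate_bandlimits

masks = generate_bandlimits(size = [512, 512], levels = 9)
# masks.shape == (9, 512, 512); each slice is a concentric square ring that
# selects one octave band in the frequency domain.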
+
+
+
+ +
+ +
+ + +

+ nufft2(field, fx, fy, size=None, sign=1, eps=10 ** -12) + +

+ + +
+ +

A definition to take a 2D Non-Uniform Fast Fourier Transform (NUFFT).

+ + +

Parameters:

+
    +
  • + field + – +
    +
          Input field.
    +
    +
    +
  • +
  • + fx + – +
    +
          Frequencies along x axis.
    +
    +
    +
  • +
  • + fy + – +
    +
          Frequencies along y axis.
    +
    +
    +
  • +
  • + size + – +
    +
          Size.
    +
    +
    +
  • +
  • + sign + – +
    +
          Sign of the exponential used in NUFFT kernel.
    +
    +
    +
  • +
  • + eps + – +
    +
          Accuracy of NUFFT.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( ndarray +) – +
    +

    Inverse NUFFT of the input field.

    +
    +
  • +
+ +
+ Source code in odak/tools/matrix.py +
def nufft2(field, fx, fy, size=None, sign=1, eps=10**(-12)):
+    """
+    A definition to take 2D Non-Uniform Fast Fourier Transform (NUFFT).
+
+    Parameters
+    ----------
+    field       : ndarray
+                  Input field.
+    fx          : ndarray
+                  Frequencies along x axis.
+    fy          : ndarray
+                  Frequencies along y axis.
+    size        : list
+                  Size.
+    sign        : float
+                  Sign of the exponential used in NUFFT kernel.
+    eps         : float
+                  Accuracy of NUFFT.
+
+    Returns
+    ----------
+    result      : ndarray
+                  Inverse NUFFT of the input field.
+    """
+    try:
+        import finufft
+    except:
+        print('odak.tools.nufft2 requires finufft to be installed: pip install finufft')
+    image = np.copy(field).astype(np.complex128)
+    result = finufft.nufft2d2(
+        fx.flatten(), fy.flatten(), image, eps=eps, isign=sign)
+    if type(size) == type(None):
+        result = result.reshape(field.shape)
+    else:
+        result = result.reshape(size)
+    return result
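A minimal usage sketch; it assumes the optional finufft dependency mentioned above is installed (pip install finufft) and uses a regular grid of frequencies purely for illustration:

import numpy as np
from odak.tools.matrix import nufft2

field = np.random.rand(64, 64)
fx, fy = np.meshgrid(np.linspace(-np.pi, np.pi, 64),
                     np.linspace(-np.pi, np.pi, 64))
result = nufft2(field, fx, fy, sign = -1)  # result.shape == field.shape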
+
+
+
+ +
+ +
+ + +

+ nuifft2(field, fx, fy, size=None, sign=1, eps=10 ** -12) + +

+ + +
+ +

A definition to take a 2D Adjoint Non-Uniform Fast Fourier Transform (NUFFT).

+ + +

Parameters:

+
    +
  • + field + – +
    +
          Input field.
    +
    +
    +
  • +
  • + fx + – +
    +
          Frequencies along x axis.
    +
    +
    +
  • +
  • + fy + – +
    +
          Frequencies along y axis.
    +
    +
    +
  • +
  • + size + – +
    +
          Shape of the NUFFT calculated for an input field.
    +
    +
    +
  • +
  • + sign + – +
    +
          Sign of the exponential used in NUFFT kernel.
    +
    +
    +
  • +
  • + eps + – +
    +
          Accuracy of NUFFT.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( ndarray +) – +
    +

    NUFFT of the input field.

    +
    +
  • +
+ +
+ Source code in odak/tools/matrix.py +
def nuifft2(field, fx, fy, size=None, sign=1, eps=10**(-12)):
+    """
+    A definition to take 2D Adjoint Non-Uniform Fast Fourier Transform (NUFFT).
+
+    Parameters
+    ----------
+    field       : ndarray
+                  Input field.
+    fx          : ndarray
+                  Frequencies along x axis.
+    fy          : ndarray
+                  Frequencies along y axis.
+    size        : list or ndarray
+                  Shape of the NUFFT calculated for an input field.
+    sign        : float
+                  Sign of the exponential used in NUFFT kernel.
+    eps         : float
+                  Accuracy of NUFFT.
+
+    Returns
+    ----------
+    result      : ndarray
+                  NUFFT of the input field.
+    """
+    try:
+        import finufft
+    except:
+        print('odak.tools.nuifft2 requires finufft to be installed: pip install finufft')
+    image = np.copy(field).astype(np.complex128)
+    if type(size) == type(None):
+        result = finufft.nufft2d1(
+            fx.flatten(),
+            fy.flatten(),
+            image.flatten(),
+            image.shape,
+            eps=eps,
+            isign=sign
+        )
+    else:
+        result = finufft.nufft2d1(
+            fx.flatten(),
+            fy.flatten(),
+            image.flatten(),
+            (size[0], size[1]),
+            eps=eps,
+            isign=sign
+        )
+    result = np.asarray(result)
+    return result
+
+
+
+ +
+ +
+ + +

+ quantize(image_field, bits=4) + +

+ + +
+ +

Definition to quantize an image field (0-255, 8 bit) to a certain bit level.

+ + +

Parameters:

+
    +
  • + image_field + (ndarray) + – +
    +
          Input image field.
    +
    +
    +
  • +
  • + bits + – +
    +
          A value between 1 and 8. Cannot be zero.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +new_field ( ndarray +) – +
    +

    Quantized image field.

    +
    +
  • +
+ +
+ Source code in odak/tools/matrix.py +
def quantize(image_field, bits=4):
+    """
+    Definition to quantize an image field (0-255, 8 bit) to a certain bit level.
+
+    Parameters
+    ----------
+    image_field : ndarray
+                  Input image field.
+    bits        : int
+                  A value in between 0 to 8. Can not be zero.
+
+    Returns
+    ----------
+    new_field   : ndarray
+                  Quantized image field.
+    """
+    divider = 2**(8-bits)
+    new_field = image_field/divider
+    new_field = new_field.astype(np.int64)
+    return new_field
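A minimal usage sketch, assuming odak and numpy are installed:

import numpy as np
from odak.tools.matrix import quantize

image = np.linspace(0, 255, 256).reshape(16, 16)
low_bit = quantize(image, bits = 4)  # integer values now span 0..15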
+
+
+
+ +
+ +
+ + +

+ zero_pad(field, size=None, method='center') + +

+ + +
+ +

Definition to zero pad a MxN array to 2Mx2N array.

+ + +

Parameters:

+
    +
  • + field + – +
    +
                Input field MxN array.
    +
    +
    +
  • +
  • + size + – +
    +
                Size to be zeropadded.
    +
    +
    +
  • +
  • + method + – +
    +
                Zeropad either by placing the content to center or to the left.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +field_zero_padded ( ndarray +) – +
    +

    Zeropadded version of the input field.

    +
    +
  • +
+ +
+ Source code in odak/tools/matrix.py +
def zero_pad(field, size=None, method='center'):
+    """
+    Definition to zero pad a MxN array to 2Mx2N array.
+
+    Parameters
+    ----------
+    field             : ndarray
+                        Input field MxN array.
+    size              : list
+                        Size to be zeropadded.
+    method            : str
+                        Zeropad either by placing the content to center or to the left.
+
+    Returns
+    ----------
+    field_zero_padded : ndarray
+                        Zeropadded version of the input field.
+    """
+    if type(size) == type(None):
+        hx = int(np.ceil(field.shape[0])/2)
+        hy = int(np.ceil(field.shape[1])/2)
+    else:
+        hx = int(np.ceil((size[0]-field.shape[0])/2))
+        hy = int(np.ceil((size[1]-field.shape[1])/2))
+    if method == 'center':
+        field_zero_padded = np.pad(
+            field, ([hx, hx], [hy, hy]), constant_values=(0, 0))
+    elif method == 'left aligned':
+        field_zero_padded = np.pad(
+            field, ([0, 2*hx], [0, 2*hy]), constant_values=(0, 0))
+    if type(size) != type(None):
+        field_zero_padded = field_zero_padded[0:size[0], 0:size[1]]
+    return field_zero_padded
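A minimal usage sketch, assuming odak and numpy are installed:

import numpy as np
from odak.tools.matrix import zero_pad

field = np.ones((4, 4))
padded = zero_pad(field)                  # (8, 8) array with the ones centered
resized = zero_pad(field, size = [6, 6])  # padded and trimmed to (6, 6)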
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + + +
+ + +

A class to work with markdown documents.

+ + + + + + +
+ Source code in odak/tools/markdown.py +
class markdown():
+    """
+    A class to work with markdown documents.
+    """
+    def __init__(
+                 self,
+                 filename
+                ):
+        """
+        Parameters
+        ----------
+        filename     : str
+                       Source filename (i.e. sample.md).
+        """
+        self.filename = filename
+        self.content = read_text_file(self.filename)
+        self.content_type = []
+        self.markdown_dictionary = [
+                                     '#',
+                                   ]
+        self.markdown_begin_dictionary = [
+                                          '```bash',
+                                          '```python',
+                                          '```',
+                                         ]
+        self.markdown_end_dictionary = [
+                                        '```',
+                                       ]
+        self._label_lines()
+
+
+    def set_dictonaries(self, begin_dictionary, end_dictionary, syntax_dictionary):
+        """
+        Set document specific dictionaries so that the lines could be labelled in accordance.
+
+
+        Parameters
+        ----------
+        begin_dictionary     : list
+                               Pythonic list containing markdown syntax for beginning of blocks (e.g., code, html).
+        end_dictionary       : list
+                               Pythonic list containing markdown syntax for end of blocks (e.g., code, html).
+        syntax_dictionary    : list
+                               Pythonic list containing markdown syntax (i.e. \\item).
+
+        """
+        self.markdown_begin_dictionary = begin_dictionary
+        self.markdown_end_dictionary = end_dictionary
+        self.markdown_dictionary = syntax_dictionary
+        self._label_lines()
+
+
+    def _label_lines(self):
+        """
+        Internal function for labelling lines.
+        """
+        content_type_flag = False
+        for line_id, line in enumerate(self.content):
+            while len(line) > 0 and line[0] == ' ':
+                 line = line[1::]
+            self.content[line_id] = line
+            if len(line) == 0:
+                content_type = 'empty'
+            elif line[0] == '%':
+                content_type = 'comment'
+            else:
+                content_type = 'text'
+            for syntax in self.markdown_begin_dictionary:
+                if line.find(syntax) != -1:
+                    content_type_flag = True
+                    content_type = 'markdown'
+            for syntax in self.markdown_dictionary:
+                if line.find(syntax) != -1:
+                    content_type = 'markdown'
+            if content_type_flag == True:
+                content_type = 'markdown'
+                for syntax in self.markdown_end_dictionary:
+                    if line.find(syntax) != -1:
+                         content_type_flag = False
+            self.content_type.append(content_type)
+
+
+    def get_line_count(self):
+        """
+        Definition to get the line count.
+
+
+        Returns
+        -------
+        line_count     : int
+                         Number of lines in the loaded markdown document.
+        """
+        self.line_count = len(self.content)
+        return self.line_count
+
+
+    def get_line(self, line_id = 0):
+        """
+        Definition to get a specific line by inputting a line number.
+
+
+        Returns
+        ----------
+        line           : str
+                         Requested line.
+        content_type   : str
+                         Line's content type (e.g., markdown, comment, text).
+        """
+        line = self.content[line_id]
+        content_type = self.content_type[line_id]
+        return line, content_type
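A minimal usage sketch; 'README.md' below is only a placeholder for any markdown file on disk:

from odak.tools.markdown import markdown

document = markdown('README.md')                     # placeholder path
print(document.get_line_count())                     # number of lines in the file
line, content_type = document.get_line(line_id = 0)  # first line and its label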
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(filename) + +

+ + +
+ + + +

Parameters:

+
    +
  • + filename + – +
    +
           Source filename (i.e. sample.md).
    +
    +
    +
  • +
+ +
+ Source code in odak/tools/markdown.py +
def __init__(
+             self,
+             filename
+            ):
+    """
+    Parameters
+    ----------
+    filename     : str
+                   Source filename (i.e. sample.md).
+    """
+    self.filename = filename
+    self.content = read_text_file(self.filename)
+    self.content_type = []
+    self.markdown_dictionary = [
+                                 '#',
+                               ]
+    self.markdown_begin_dictionary = [
+                                      '```bash',
+                                      '```python',
+                                      '```',
+                                     ]
+    self.markdown_end_dictionary = [
+                                    '```',
+                                   ]
+    self._label_lines()
+
+
+
+ +
+ +
+ + +

+ get_line(line_id=0) + +

+ + +
+ +

Definition to get a specific line by inputting a line number.

+ + +

Returns:

+
    +
  • +line ( str +) – +
    +

    Requested line.

    +
    +
  • +
  • +content_type ( str +) – +
    +

    Line's content type (e.g., markdown, comment, text).

    +
    +
  • +
+ +
+ Source code in odak/tools/markdown.py +
def get_line(self, line_id = 0):
+    """
+    Definition to get a specific line by inputting a line number.
+
+
+    Returns
+    ----------
+    line           : str
+                     Requested line.
+    content_type   : str
+                     Line's content type (e.g., markdown, comment, text).
+    """
+    line = self.content[line_id]
+    content_type = self.content_type[line_id]
+    return line, content_type
+
+
+
+ +
+ +
+ + +

+ get_line_count() + +

+ + +
+ +

Definition to get the line count.

+ + +

Returns:

+
    +
  • +line_count ( int +) – +
    +

    Number of lines in the loaded markdown document.

    +
    +
  • +
+ +
+ Source code in odak/tools/markdown.py +
def get_line_count(self):
+    """
+    Definition to get the line count.
+
+
+    Returns
+    -------
+    line_count     : int
+                     Number of lines in the loaded markdown document.
+    """
+    self.line_count = len(self.content)
+    return self.line_count
+
+
+
+ +
+ +
+ + +

+ set_dictonaries(begin_dictionary, end_dictionary, syntax_dictionary) + +

+ + +
+ +

Set document-specific dictionaries so that the lines can be labelled accordingly.

+ + +

Parameters:

+
    +
  • + begin_dictionary + – +
    +
                   Pythonic list containing markdown syntax for beginning of blocks (e.g., code, html).
    +
    +
    +
  • +
  • + end_dictionary + – +
    +
                   Pythonic list containing markdown syntax for end of blocks (e.g., code, html).
    +
    +
    +
  • +
  • + syntax_dictionary + – +
    +
                   Pythonic list containing markdown syntax (i.e. \item).
    +
    +
    +
  • +
+ +
+ Source code in odak/tools/markdown.py +
def set_dictonaries(self, begin_dictionary, end_dictionary, syntax_dictionary):
+    """
+    Set document specific dictionaries so that the lines could be labelled in accordance.
+
+
+    Parameters
+    ----------
+    begin_dictionary     : list
+                           Pythonic list containing markdown syntax for beginning of blocks (e.g., code, html).
+    end_dictionary       : list
+                           Pythonic list containing markdown syntax for end of blocks (e.g., code, html).
+    syntax_dictionary    : list
+                           Pythonic list containing markdown syntax (i.e. \\item).
+
+    """
+    self.markdown_begin_dictionary = begin_dictionary
+    self.markdown_end_dictionary = end_dictionary
+    self.markdown_dictionary = syntax_dictionary
+    self._label_lines()
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ batch_of_rays(entry, exit) + +

+ + +
+ +

Definition to generate a batch of rays with given entry point(s) and exit point(s). The mapping is one to one: the nth entry point pairs with the nth exit point to generate that particular ray. You can combine nx3 points on one side with a single point on the other, but if both entry and exit contain multiple points, the two counts have to match.

+ + +

Parameters:

+
    +
  • + entry + – +
    +
         Either a single point with size of 3 or multiple points with the size of nx3.
    +
    +
    +
  • +
  • + exit + – +
    +
         Either a single point with size of 3 or multiple points with the size of nx3.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +rays ( ndarray +) – +
    +

    Generated batch of rays.

    +
    +
  • +
+ +
+ Source code in odak/tools/sample.py +
def batch_of_rays(entry, exit):
+    """
+    Definition to generate a batch of rays with given entry point(s) and exit point(s). Note that the mapping is one to one, meaning nth item in your entry points list will exit from nth item in your exit list and generate that particular ray. Note that you can have a combination like nx3 points for entry or exit and 1 point for entry or exit. But if you have multiple points both for entry and exit, the number of points have to be same both for entry and exit.
+
+    Parameters
+    ----------
+    entry      : ndarray
+                 Either a single point with size of 3 or multiple points with the size of nx3.
+    exit       : ndarray
+                 Either a single point with size of 3 or multiple points with the size of nx3.
+
+    Returns
+    ----------
+    rays       : ndarray
+                 Generated batch of rays.
+    """
+    norays = np.array([0, 0])
+    if len(entry.shape) == 1:
+        entry = entry.reshape((1, 3))
+    if len(exit.shape) == 1:
+        exit = exit.reshape((1, 3))
+    norays = np.amax(np.asarray([entry.shape[0], exit.shape[0]]))
+    if norays > exit.shape[0]:
+        exit = np.repeat(exit, norays, axis=0)
+    elif norays > entry.shape[0]:
+        entry = np.repeat(entry, norays, axis=0)
+    rays = []
+    norays = int(norays)
+    for i in range(norays):
+        rays.append(
+            create_ray_from_two_points(
+                entry[i],
+                exit[i]
+            )
+        )
+    rays = np.asarray(rays)
+    return rays
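A minimal usage sketch, assuming odak and numpy are installed:

import numpy as np
from odak.tools.sample import batch_of_rays

entry = np.array([0., 0., 0.])
exits = np.array([[0., 0., 10.],
                  [1., 0., 10.],
                  [0., 1., 10.]])
rays = batch_of_rays(entry, exits)  # one ray per exit point, all sharing the same entry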
+
+
+
+ +
+ +
+ + +

+ box_volume_sample(no=[10, 10, 10], size=[100.0, 100.0, 100.0], center=[0.0, 0.0, 0.0], angles=[0.0, 0.0, 0.0]) + +

+ + +
+ +

Definition to generate samples in a box volume.

+ + +

Parameters:

+
    +
  • + no + – +
    +
          Number of samples.
    +
    +
    +
  • +
  • + size + – +
    +
          Physical size of the volume.
    +
    +
    +
  • +
  • + center + – +
    +
          Center location of the volume.
    +
    +
    +
  • +
  • + angles + – +
    +
          Tilt of the volume.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +samples ( ndarray +) – +
    +

    Samples generated.

    +
    +
  • +
+ +
+ Source code in odak/tools/sample.py +
def box_volume_sample(no=[10, 10, 10], size=[100., 100., 100.], center=[0., 0., 0.], angles=[0., 0., 0.]):
+    """
+    Definition to generate samples in a box volume.
+
+    Parameters
+    ----------
+    no          : list
+                  Number of samples.
+    size        : list
+                  Physical size of the volume.
+    center      : list
+                  Center location of the volume.
+    angles      : list
+                  Tilt of the volume.
+
+    Returns
+    ----------
+    samples     : ndarray
+                  Samples generated.
+    """
+    samples = np.zeros((no[0], no[1], no[2], 3))
+    x, y, z = np.mgrid[0:no[0], 0:no[1], 0:no[2]]
+    step = [
+        size[0]/no[0],
+        size[1]/no[1],
+        size[2]/no[2]
+    ]
+    samples[:, :, :, 0] = x*step[0]+step[0]/2.-size[0]/2.
+    samples[:, :, :, 1] = y*step[1]+step[1]/2.-size[1]/2.
+    samples[:, :, :, 2] = z*step[2]+step[2]/2.-size[2]/2.
+    samples = samples.reshape(
+        (samples.shape[0]*samples.shape[1]*samples.shape[2], samples.shape[3]))
+    samples = rotate_points(samples, angles=angles, offset=center)
+    return samples
+
+
+
+ +
+ +
+ + +

+ circular_sample(no=[10, 10], radius=10.0, center=[0.0, 0.0, 0.0], angles=[0.0, 0.0, 0.0]) + +

+ + +
+ +

Definition to generate samples inside a circle over a surface.

+ + +

Parameters:

+
    +
  • + no + – +
    +
          Number of samples.
    +
    +
    +
  • +
  • + radius + – +
    +
          Radius of the circle.
    +
    +
    +
  • +
  • + center + – +
    +
          Center location of the surface.
    +
    +
    +
  • +
  • + angles + – +
    +
          Tilt of the surface.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +samples ( ndarray +) – +
    +

    Samples generated.

    +
    +
  • +
+ +
+ Source code in odak/tools/sample.py +
def circular_sample(no=[10, 10], radius=10., center=[0., 0., 0.], angles=[0., 0., 0.]):
+    """
+    Definition to generate samples inside a circle over a surface.
+
+    Parameters
+    ----------
+    no          : list
+                  Number of samples.
+    radius      : float
+                  Radius of the circle.
+    center      : list
+                  Center location of the surface.
+    angles      : list
+                  Tilt of the surface.
+
+    Returns
+    ----------
+    samples     : ndarray
+                  Samples generated.
+    """
+    samples = np.zeros((no[0]+1, no[1]+1, 3))
+    r_angles, r = np.mgrid[0:no[0]+1, 0:no[1]+1]
+    r = r/np.amax(r)*radius
+    r_angles = r_angles/np.amax(r_angles)*np.pi*2
+    samples[:, :, 0] = r*np.cos(r_angles)
+    samples[:, :, 1] = r*np.sin(r_angles)
+    samples = samples[1:no[0]+1, 1:no[1]+1, :]
+    samples = samples.reshape(
+        (samples.shape[0]*samples.shape[1], samples.shape[2]))
+    samples = rotate_points(samples, angles=angles, offset=center)
+    return samples
+
+
+
+ +
+ +
+ + +

+ circular_uniform_random_sample(no=[10, 50], radius=10.0, center=[0.0, 0.0, 0.0], angles=[0.0, 0.0, 0.0]) + +

+ + +
+ +

Definition to generate samples inside a circle, uniformly but randomly.

+ + +

Parameters:

+
    +
  • + no + – +
    +
          Number of samples.
    +
    +
    +
  • +
  • + radius + – +
    +
          Radius of the circle.
    +
    +
    +
  • +
  • + center + – +
    +
          Center location of the surface.
    +
    +
    +
  • +
  • + angles + – +
    +
          Tilt of the surface.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +samples ( ndarray +) – +
    +

    Samples generated.

    +
    +
  • +
+ +
+ Source code in odak/tools/sample.py +
def circular_uniform_random_sample(no=[10, 50], radius=10., center=[0., 0., 0.], angles=[0., 0., 0.]):
+    """ 
+    Definition to generate sample inside a circle uniformly but randomly.
+
+    Parameters
+    ----------
+    no          : list
+                  Number of samples.
+    radius      : float
+                  Radius of the circle.
+    center      : list
+                  Center location of the surface.
+    angles      : list
+                  Tilt of the surface.
+
+    Returns
+    ----------
+    samples     : ndarray
+                  Samples generated.
+    """
+    samples = np.empty((0, 3))
+    rs = np.sqrt(np.random.uniform(0, 1, no[0]))
+    angs = np.random.uniform(0, 2*np.pi, no[1])
+    for i in rs:
+        for angle in angs:
+            r = radius*i
+            point = np.array(
+                [float(r*np.cos(angle)), float(r*np.sin(angle)), 0])
+            samples = np.vstack((samples, point))
+    samples = rotate_points(samples, angles=angles, offset=center)
+    return samples
+
+
+
+ +
+ +
+ + +

+ circular_uniform_sample(no=[10, 50], radius=10.0, center=[0.0, 0.0, 0.0], angles=[0.0, 0.0, 0.0]) + +

+ + +
+ +

Definition to generate samples inside a circle uniformly.

+ + +

Parameters:

+
    +
  • + no + – +
    +
          Number of samples.
    +
    +
    +
  • +
  • + radius + – +
    +
          Radius of the circle.
    +
    +
    +
  • +
  • + center + – +
    +
          Center location of the surface.
    +
    +
    +
  • +
  • + angles + – +
    +
          Tilt of the surface.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +samples ( ndarray +) – +
    +

    Samples generated.

    +
    +
  • +
+ +
+ Source code in odak/tools/sample.py +
def circular_uniform_sample(no=[10, 50], radius=10., center=[0., 0., 0.], angles=[0., 0., 0.]):
+    """
+    Definition to generate sample inside a circle uniformly.
+
+    Parameters
+    ----------
+    no          : list
+                  Number of samples.
+    radius      : float
+                  Radius of the circle.
+    center      : list
+                  Center location of the surface.
+    angles      : list
+                  Tilt of the surface.
+
+    Returns
+    ----------
+    samples     : ndarray
+                  Samples generated.
+    """
+    samples = np.empty((0, 3))
+    for i in range(0, no[0]):
+        r = i/no[0]*radius
+        ang_no = no[1]*i/no[0]
+        for j in range(0, int(no[1]*i/no[0])):
+            angle = j/ang_no*2*np.pi
+            point = np.array(
+                [float(r*np.cos(angle)), float(r*np.sin(angle)), 0])
+            samples = np.vstack((samples, point))
+    samples = rotate_points(samples, angles=angles, offset=center)
+    return samples
+
+
+
+ +
+ +
+ + +

+ grid_sample(no=[10, 10], size=[100.0, 100.0], center=[0.0, 0.0, 0.0], angles=[0.0, 0.0, 0.0]) + +

+ + +
+ +

Definition to generate samples over a surface.

+ + +

Parameters:

+
    +
  • + no + – +
    +
          Number of samples.
    +
    +
    +
  • +
  • + size + – +
    +
          Physical size of the surface.
    +
    +
    +
  • +
  • + center + – +
    +
          Center location of the surface.
    +
    +
    +
  • +
  • + angles + – +
    +
          Tilt of the surface.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +samples ( ndarray +) – +
    +

    Samples generated.

    +
    +
  • +
+ +
+ Source code in odak/tools/sample.py +
def grid_sample(no=[10, 10], size=[100., 100.], center=[0., 0., 0.], angles=[0., 0., 0.]):
+    """
+    Definition to generate samples over a surface.
+
+    Parameters
+    ----------
+    no          : list
+                  Number of samples.
+    size        : list
+                  Physical size of the surface.
+    center      : list
+                  Center location of the surface.
+    angles      : list
+                  Tilt of the surface.
+
+    Returns
+    ----------
+    samples     : ndarray
+                  Samples generated.
+    """
+    samples = np.zeros((no[0], no[1], 3))
+    step = [
+        size[0]/(no[0]-1),
+        size[1]/(no[1]-1)
+    ]
+    x, y = np.mgrid[0:no[0], 0:no[1]]
+    samples[:, :, 0] = x*step[0]-size[0]/2.
+    samples[:, :, 1] = y*step[1]-size[1]/2.
+    samples = samples.reshape(
+        (samples.shape[0]*samples.shape[1], samples.shape[2]))
+    samples = rotate_points(samples, angles=angles, offset=center)
+    return samples
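A minimal usage sketch, assuming odak is installed:

from odak.tools.sample import grid_sample

samples = grid_sample(
                      no = [5, 5],
                      size = [10., 10.],
                      center = [0., 0., 100.],
                      angles = [0., 0., 0.]
                     )
# samples.shape == (25, 3): a 5x5 grid spanning a 10x10 plane centered at z = 100.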
+
+
+
+ +
+ +
+ + +

+ random_sample_point_cloud(point_cloud, no, p=None) + +

+ + +
+ +

Definition to pull a subset of points from a point cloud with a given probability.

+ + +

Parameters:

+
    +
  • + point_cloud + – +
    +
           Point cloud array.
    +
    +
    +
  • +
  • + no + – +
    +
           Number of samples.
    +
    +
    +
  • +
  • + p + – +
    +
           Probability list in the same size as no.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +subset ( ndarray +) – +
    +

    Subset of the given point cloud.

    +
    +
  • +
+ +
+ Source code in odak/tools/sample.py +
def random_sample_point_cloud(point_cloud, no, p=None):
+    """
+    Definition to pull a subset of points from a point cloud with a given probability.
+
+    Parameters
+    ----------
+    point_cloud  : ndarray
+                   Point cloud array.
+    no           : list
+                   Number of samples.
+    p            : list
+                   Probability list in the same size as no.
+
+    Returns
+    ----------
+    subset       : ndarray
+                   Subset of the given point cloud.
+    """
+    choice = np.random.choice(point_cloud.shape[0], no, p)
+    subset = point_cloud[choice, :]
+    return subset
+
+
+
+ +
+ +
+ + +

+ sphere_sample(no=[10, 10], radius=1.0, center=[0.0, 0.0, 0.0], k=[1, 2]) + +

+ + +
+ +

Definition to generate a regular sample set on the surface of a sphere using polar coordinates.

+ + +

Parameters:

+
    +
  • + no + – +
    +
          Number of samples.
    +
    +
    +
  • +
  • + radius + – +
    +
          Radius of a sphere.
    +
    +
    +
  • +
  • + center + – +
    +
          Center of a sphere.
    +
    +
    +
  • +
  • + k + – +
    +
          Multipliers for gathering samples. If you set k=[1,2] it will draw samples from a perfect sphere.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +samples ( ndarray +) – +
    +

    Samples generated.

    +
    +
  • +
+ +
+ Source code in odak/tools/sample.py +
def sphere_sample(no=[10, 10], radius=1., center=[0., 0., 0.], k=[1, 2]):
+    """
+    Definition to generate a regular sample set on the surface of a sphere using polar coordinates.
+
+    Parameters
+    ----------
+    no          : list
+                  Number of samples.
+    radius      : float
+                  Radius of a sphere.
+    center      : list
+                  Center of a sphere.
+    k           : list
+                  Multipliers for gathering samples. If you set k=[1,2] it will draw samples from a perfect sphere.
+
+    Returns
+    ----------
+    samples     : ndarray
+                  Samples generated.
+    """
+    samples = np.zeros((no[0], no[1], 3))
+    psi, teta = np.mgrid[0:no[0], 0:no[1]]
+    psi = k[0]*np.pi/no[0]*psi
+    teta = k[1]*np.pi/no[1]*teta
+    samples[:, :, 0] = center[0]+radius*np.sin(psi)*np.cos(teta)
+    samples[:, :, 1] = center[0]+radius*np.sin(psi)*np.sin(teta)
+    samples[:, :, 2] = center[0]+radius*np.cos(psi)
+    samples = samples.reshape((no[0]*no[1], 3))
+    return samples
+
+
+
+ +
+ +
+ + +

+ sphere_sample_uniform(no=[10, 10], radius=1.0, center=[0.0, 0.0, 0.0], k=[1, 2]) + +

+ + +
+ +

Definition to generate a uniform sample set on the surface of a sphere using polar coordinates.

+ + +

Parameters:

+
    +
  • + no + – +
    +
          Number of samples.
    +
    +
    +
  • +
  • + radius + – +
    +
          Radius of a sphere.
    +
    +
    +
  • +
  • + center + – +
    +
          Center of a sphere.
    +
    +
    +
  • +
  • + k + – +
    +
          Multipliers for gathering samples. If you set k=[1,2] it will draw samples from a perfect sphere.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +samples ( ndarray +) – +
    +

    Samples generated.

    +
    +
  • +
+ +
+ Source code in odak/tools/sample.py +
def sphere_sample_uniform(no=[10, 10], radius=1., center=[0., 0., 0.], k=[1, 2]):
+    """
+    Definition to generate an uniform sample set on the surface of a sphere using polar coordinates.
+
+    Parameters
+    ----------
+    no          : list
+                  Number of samples.
+    radius      : float
+                  Radius of a sphere.
+    center      : list
+                  Center of a sphere.
+    k           : list
+                  Multipliers for gathering samples. If you set k=[1,2] it will draw samples from a perfect sphere.
+
+
+    Returns
+    ----------
+    samples     : ndarray
+                  Samples generated.
+
+    """
+    samples = np.zeros((no[0], no[1], 3))
+    row = np.arange(0, no[0])
+    psi, teta = np.mgrid[0:no[0], 0:no[1]]
+    for psi_id in range(0, no[0]):
+        psi[psi_id] = np.roll(row, psi_id, axis=0)
+        teta[psi_id] = np.roll(row, -psi_id, axis=0)
+    psi = k[0]*np.pi/no[0]*psi
+    teta = k[1]*np.pi/no[1]*teta
+    samples[:, :, 0] = center[0]+radius*np.sin(psi)*np.cos(teta)
+    samples[:, :, 1] = center[1]+radius*np.sin(psi)*np.sin(teta)
+    samples[:, :, 2] = center[2]+radius*np.cos(psi)
+    samples = samples.reshape((no[0]*no[1], 3))
+    return samples
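A minimal usage sketch, assuming odak is installed:

from odak.tools.sample import sphere_sample_uniform

points = sphere_sample_uniform(no = [10, 10], radius = 1., center = [0., 0., 0.])
# points.shape == (100, 3); every row lies on the unit sphere around the origin.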
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ closest_point_to_a_ray(point, ray) + +

+ + +
+ +

Definition to calculate the point on a ray that is closest to a given point.

+ + +

Parameters:

+
    +
  • + point + – +
    +
            Given point in X,Y,Z.
    +
    +
    +
  • +
  • + ray + – +
    +
            Given ray.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +closest_point ( ndarray +) – +
    +

    Calculated closest point.

    +
    +
  • +
+ +
+ Source code in odak/tools/vector.py +
def closest_point_to_a_ray(point, ray):
+    """
+    Definition to calculate the point on a ray that is closest to given point.
+
+    Parameters
+    ----------
+    point         : list
+                    Given point in X,Y,Z.
+    ray           : ndarray
+                    Given ray.
+
+    Returns
+    ---------
+    closest_point : ndarray
+                    Calculated closest point.
+    """
+    from odak.raytracing import propagate_a_ray
+    if len(ray.shape) == 2:
+        ray = ray.reshape((1, 2, 3))
+    p0 = ray[:, 0]
+    p1 = propagate_a_ray(ray, 1.)
+    if len(p1.shape) == 2:
+        p1 = p1.reshape((1, 2, 3))
+    p1 = p1[:, 0]
+    p1 = p1.reshape(3)
+    p0 = p0.reshape(3)
+    point = point.reshape(3)
+    closest_distance = -np.dot((p0-point), (p1-p0))/np.sum((p1-p0)**2)
+    closest_point = propagate_a_ray(ray, closest_distance)[0]
+    return closest_point
+
+
+
+ +
+ +
+ + +

+ cross_product(vector1, vector2) + +

+ + +
+ +

Definition to cross product two vectors and return the resultant vector. Used method described under: http://en.wikipedia.org/wiki/Cross_product

+ + +

Parameters:

+
    +
  • + vector1 + – +
    +
           A vector/ray.
    +
    +
    +
  • +
  • + vector2 + – +
    +
           A vector/ray.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +ray ( ndarray +) – +
    +

    Array that contains starting points and cosines of a created ray.

    +
    +
  • +
+ +
+ Source code in odak/tools/vector.py +
def cross_product(vector1, vector2):
+    """
+    Definition to cross product two vectors and return the resultant vector. Used method described under: http://en.wikipedia.org/wiki/Cross_product
+
+    Parameters
+    ----------
+    vector1      : ndarray
+                   A vector/ray.
+    vector2      : ndarray
+                   A vector/ray.
+
+    Returns
+    ----------
+    ray          : ndarray
+                   Array that contains starting points and cosines of a created ray.
+    """
+    angle = np.cross(vector1[1].T, vector2[1].T)
+    angle = np.asarray(angle)
+    ray = np.array([vector1[0], angle], dtype=np.float32)
+    return ray
+
+
+
+ +
+ +
+ + +

+ distance_between_point_clouds(points0, points1) + +

+ + +
+ +

A definition to find the distance from every point in one point cloud to every point in the other point cloud.

+ + +

Parameters:

+
    +
  • + points0 + – +
    +
          Mx3 points.
    +
    +
    +
  • +
  • + points1 + – +
    +
          Nx3 points.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +distances ( ndarray +) – +
    +

    MxN distances.

    +
    +
  • +
+ +
+ Source code in odak/tools/vector.py +
def distance_between_point_clouds(points0, points1):
+    """
+    A definition to find distance between every point in one cloud to other points in the other point cloud.
+    Parameters
+    ----------
+    points0     : ndarray
+                  Mx3 points.
+    points1     : ndarray
+                  Nx3 points.
+
+    Returns
+    ----------
+    distances   : ndarray
+                  MxN distances.
+    """
+    c = points1.reshape((1, points1.shape[0], points1.shape[1]))
+    a = np.repeat(c, points0.shape[0], axis=0)
+    b = points0.reshape((points0.shape[0], 1, points0.shape[1]))
+    b = np.repeat(b, a.shape[1], axis=1)
+    distances = np.sqrt(np.sum((a-b)**2, axis=2))
+    return distances
+
+
+
+ +
+ +
+ + +

+ distance_between_two_points(point1, point2) + +

+ + +
+ +

Definition to calculate distance between two given points.

+ + +

Parameters:

+
    +
  • + point1 + – +
    +
          First point in X,Y,Z.
    +
    +
    +
  • +
  • + point2 + – +
    +
          Second point in X,Y,Z.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +distance ( float +) – +
    +

    Distance in between given two points.

    +
    +
  • +
+ +
+ Source code in odak/tools/vector.py +
def distance_between_two_points(point1, point2):
+    """
+    Definition to calculate distance between two given points.
+
+    Parameters
+    ----------
+    point1      : list
+                  First point in X,Y,Z.
+    point2      : list
+                  Second point in X,Y,Z.
+
+    Returns
+    ----------
+    distance    : float
+                  Distance in between given two points.
+    """
+    point1 = np.asarray(point1)
+    point2 = np.asarray(point2)
+    if len(point1.shape) == 1 and len(point2.shape) == 1:
+        distance = np.sqrt(np.sum((point1-point2)**2))
+    elif len(point1.shape) == 2 or len(point2.shape) == 2:
+        distance = np.sqrt(np.sum((point1-point2)**2, axis=1))
+    return distance
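A minimal usage sketch, assuming odak and numpy are installed:

import numpy as np
from odak.tools.vector import distance_between_two_points

distance_between_two_points([0., 0., 0.], [3., 4., 0.])  # 5.0
points0 = np.zeros((4, 3))
points1 = np.ones((4, 3))
distance_between_two_points(points0, points1)            # four distances, each sqrt(3)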
+
+
+
+ +
+ +
+ + +

+ point_to_ray_distance(point, ray_point_0, ray_point_1) + +

+ + +
+ +

Definition to find a point's closest distance to a line represented by two points.

+ + +

Parameters:

+
    +
  • + point + – +
    +
          Point to be tested.
    +
    +
    +
  • +
  • + ray_point_0 + (ndarray) + – +
    +
          First point to represent a line.
    +
    +
    +
  • +
  • + ray_point_1 + (ndarray) + – +
    +
          Second point to represent a line.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +distance ( float +) – +
    +

    Calculated distance.

    +
    +
  • +
+ +
+ Source code in odak/tools/vector.py +
def point_to_ray_distance(point, ray_point_0, ray_point_1):
+    """
+    Definition to find point's closest distance to a line represented with two points.
+
+    Parameters
+    ----------
+    point       : ndarray
+                  Point to be tested.
+    ray_point_0 : ndarray
+                  First point to represent a line.
+    ray_point_1 : ndarray
+                  Second point to represent a line.
+
+    Returns
+    ----------
+    distance    : float
+                  Calculated distance.
+    """
+    distance = np.sum(np.cross((point-ray_point_0), (point-ray_point_1))
+                      ** 2)/np.sum((ray_point_1-ray_point_0)**2)
+    return distance
+
+
+
+ +
+ +
+ + +

+ same_side(p1, p2, a, b) + +

+ + +
+ +

Definition to figure out which side of a line a point is on, relative to a reference point. See http://www.blackpawn.com/texts/pointinpoly/ for more. If p1 and p2 are on the same side, this definition returns True.

+ + +

Parameters:

+
    +
  • + p1 + – +
    +
          Point(s) to check.
    +
    +
    +
  • +
  • + p2 + – +
    +
          This is the point to check against.
    +
    +
    +
  • +
  • + a + – +
    +
          First point that forms the line.
    +
    +
    +
  • +
  • + b + – +
    +
          Second point that forms the line.
    +
    +
    +
  • +
+ +
+ Source code in odak/tools/vector.py +
def same_side(p1, p2, a, b):
+    """
+    Definition to figure which side a point is on with respect to a line and a point. See http://www.blackpawn.com/texts/pointinpoly/ for more. If p1 and p2 are on the sameside, this definition returns True.
+
+    Parameters
+    ----------
+    p1          : list
+                  Point(s) to check.
+    p2          : list
+                  This is the point check against.
+    a           : list
+                  First point that forms the line.
+    b           : list
+                  Second point that forms the line.
+    """
+    ba = np.subtract(b, a)
+    p1a = np.subtract(p1, a)
+    p2a = np.subtract(p2, a)
+    cp1 = np.cross(ba, p1a)
+    cp2 = np.cross(ba, p2a)
+    test = np.dot(cp1, cp2)
+    if len(p1.shape) > 1:
+        return test >= 0
+    if test >= 0:
+        return True
+    return False
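A minimal usage sketch, assuming odak and numpy are installed:

import numpy as np
from odak.tools.vector import same_side

a = np.array([0., 0., 0.])
b = np.array([1., 0., 0.])
p1 = np.array([0.5, 1., 0.])
same_side(p1, np.array([0.5, 2., 0.]), a, b)   # True, both points above the line
same_side(p1, np.array([0.5, -2., 0.]), a, b)  # False, opposite sides of the line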
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ rotate_point(point, angles=[0, 0, 0], mode='XYZ', origin=[0, 0, 0], offset=[0, 0, 0]) + +

+ + +
+ +

Definition to rotate a given point. Note that rotation is always with respect to 0,0,0.

+ + +

Parameters:

+
    +
  • + point + – +
    +
           A point.
    +
    +
    +
  • +
  • + angles + – +
    +
           Rotation angles in degrees.
    +
    +
    +
  • +
  • + mode + – +
    +
           Rotation mode determines the ordering of the rotations about each axis. Supported modes are XYZ, XZY, YXZ, ZXY and ZYX.
    +
    +
    +
  • +
  • + origin + – +
    +
           Reference point for a rotation.
    +
    +
    +
  • +
  • + offset + – +
    +
           Shift with the given offset.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( ndarray +) – +
    +

    Result of the rotation

    +
    +
  • +
  • +rotx ( ndarray +) – +
    +

    Rotation matrix along X axis.

    +
    +
  • +
  • +roty ( ndarray +) – +
    +

    Rotation matrix along Y axis.

    +
    +
  • +
  • +rotz ( ndarray +) – +
    +

    Rotation matrix along Z axis.

    +
    +
  • +
+ +
+ Source code in odak/tools/transformation.py +
def rotate_point(point, angles = [0, 0, 0], mode = 'XYZ', origin = [0, 0, 0], offset = [0, 0, 0]):
+    """
+    Definition to rotate a given point. Note that rotation is always with respect to 0,0,0.
+
+    Parameters
+    ----------
+    point        : ndarray
+                   A point.
+    angles       : list
+                   Rotation angles in degrees. 
+    mode         : str
+                   Rotation mode determines ordering of the rotations at each axis. There are XYZ,YXZ,ZXY and ZYX modes.
+    origin       : list
+                   Reference point for a rotation.
+    offset       : list
+                   Shift with the given offset.
+
+    Returns
+    ----------
+    result       : ndarray
+                   Result of the rotation
+    rotx         : ndarray
+                   Rotation matrix along X axis.
+    roty         : ndarray
+                   Rotation matrix along Y axis.
+    rotz         : ndarray
+                   Rotation matrix along Z axis.
+    """
+    point = np.asarray(point)
+    point -= np.asarray(origin)
+    rotx = rotmatx(angles[0])
+    roty = rotmaty(angles[1])
+    rotz = rotmatz(angles[2])
+    if mode == 'XYZ':
+        result = np.dot(rotz, np.dot(roty, np.dot(rotx, point)))
+    elif mode == 'XZY':
+        result = np.dot(roty, np.dot(rotz, np.dot(rotx, point)))
+    elif mode == 'YXZ':
+        result = np.dot(rotz, np.dot(rotx, np.dot(roty, point)))
+    elif mode == 'ZXY':
+        result = np.dot(roty, np.dot(rotx, np.dot(rotz, point)))
+    elif mode == 'ZYX':
+        result = np.dot(rotx, np.dot(roty, np.dot(rotz, point)))
+    result += np.asarray(origin)
+    result += np.asarray(offset)
+    return result, rotx, roty, rotz
+
+
+
+ +
+ +
+ + +

+ rotate_points(points, angles=[0, 0, 0], mode='XYZ', origin=[0, 0, 0], offset=[0, 0, 0]) + +

+ + +
+ +

Definition to rotate points.

+ + +

Parameters:

+
    +
  • + points + – +
    +
           Points.
    +
    +
    +
  • +
  • + angles + – +
    +
           Rotation angles in degrees.
    +
    +
    +
  • +
  • + mode + – +
    +
           Rotation mode determines the ordering of the rotations about each axis. Supported modes are XYZ, XZY, YXZ, ZXY and ZYX.
    +
    +
    +
  • +
  • + origin + – +
    +
           Reference point for a rotation.
    +
    +
    +
  • +
  • + offset + – +
    +
           Shift with the given offset.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( ndarray +) – +
    +

    Result of the rotation

    +
    +
  • +
+ +
+ Source code in odak/tools/transformation.py +
def rotate_points(points, angles = [0, 0, 0], mode = 'XYZ', origin = [0, 0, 0], offset = [0, 0, 0]):
+    """
+    Definition to rotate points.
+
+    Parameters
+    ----------
+    points       : ndarray
+                   Points.
+    angles       : list
+                   Rotation angles in degrees. 
+    mode         : str
+                   Rotation mode determines ordering of the rotations at each axis. There are XYZ,YXZ,ZXY and ZYX modes.
+    origin       : list
+                   Reference point for a rotation.
+    offset       : list
+                   Shift with the given offset.
+
+    Returns
+    ----------
+    result       : ndarray
+                   Result of the rotation   
+    """
+    points = np.asarray(points)
+    if angles[0] == 0 and angles[1] == 0 and angles[2] == 0:
+        result = np.array(offset) + points
+        return result
+    points -= np.array(origin)
+    rotx = rotmatx(angles[0])
+    roty = rotmaty(angles[1])
+    rotz = rotmatz(angles[2])
+    if mode == 'XYZ':
+        result = np.dot(rotz, np.dot(roty, np.dot(rotx, points.T))).T
+    elif mode == 'XZY':
+        result = np.dot(roty, np.dot(rotz, np.dot(rotx, points.T))).T
+    elif mode == 'YXZ':
+        result = np.dot(rotz, np.dot(rotx, np.dot(roty, points.T))).T
+    elif mode == 'ZXY':
+        result = np.dot(roty, np.dot(rotx, np.dot(rotz, points.T))).T
+    elif mode == 'ZYX':
+        result = np.dot(rotx, np.dot(roty, np.dot(rotz, points.T))).T
+    result += np.array(origin)
+    result += np.array(offset)
+    return result
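A minimal usage sketch, assuming odak and numpy are installed:

import numpy as np
from odak.tools.transformation import rotate_points

points = np.array([[1., 0., 0.],
                   [0., 1., 0.]])
rotated = rotate_points(points, angles = [0., 0., 90.])
# approximately [[0., 1., 0.], [-1., 0., 0.]]: a 90 degree rotation about the Z axis.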
+
+
+
+ +
+ +
+ + +

+ rotmatx(angle) + +

+ + +
+ +

Definition to generate a rotation matrix along X axis.

+ + +

Parameters:

+
    +
  • + angle + – +
    +
           Rotation angles in degrees.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +rotx ( ndarray +) – +
    +

    Rotation matrix along X axis.

    +
    +
  • +
+ +
+ Source code in odak/tools/transformation.py +
def rotmatx(angle):
+    """
+    Definition to generate a rotation matrix along X axis.
+
+    Parameters
+    ----------
+    angle        : list
+                   Rotation angles in degrees.
+
+    Returns
+    -------
+    rotx         : ndarray
+                   Rotation matrix along X axis.
+    """
+    angle = np.float64(angle)
+    angle = np.radians(angle)
+    rotx = np.array([
+        [1.,               0.,               0.],
+        [0.,  math.cos(angle), -math.sin(angle)],
+        [0.,  math.sin(angle),  math.cos(angle)]
+    ], dtype=np.float64)
+    return rotx
+
+
+
+ +
+ +
+ + +

+ rotmaty(angle) + +

+ + +
+ +

Definition to generate a rotation matrix along Y axis.

+ + +

Parameters:

+
    +
  • + angle + – +
    +
           Rotation angle in degrees.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +roty ( ndarray +) – +
    +

    Rotation matrix along Y axis.

    +
    +
  • +
+ +
+ Source code in odak/tools/transformation.py +
def rotmaty(angle):
+    """
+    Definition to generate a rotation matrix along Y axis.
+
+    Parameters
+    ----------
+    angle        : list
+                   Rotation angles in degrees.
+
+    Returns
+    -------
+    roty         : ndarray
+                   Rotation matrix along Y axis.
+    """
+    angle = np.radians(angle)
+    roty = np.array([
+        [math.cos(angle),  0., math.sin(angle)],
+        [0.,               1.,              0.],
+        [-math.sin(angle), 0., math.cos(angle)]
+    ], dtype=np.float64)
+    return roty
+
+
+
+ +
+ +
+ + +

+ rotmatz(angle) + +

+ + +
+ +

Definition to generate a rotation matrix along Z axis.

+ + +

Parameters:

+
    +
  • + angle + – +
    +
           Rotation angle in degrees.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +rotz ( ndarray +) – +
    +

    Rotation matrix along Z axis.

    +
    +
  • +
+ +
+ Source code in odak/tools/transformation.py +
def rotmatz(angle):
+    """
+    Definition to generate a rotation matrix along Z axis.
+
+    Parameters
+    ----------
+    angle        : list
+                   Rotation angles in degrees.
+
+    Returns
+    -------
+    rotz         : ndarray
+                   Rotation matrix along Z axis.
+    """
+    angle = np.radians(angle)
+    rotz = np.array([
+        [math.cos(angle), -math.sin(angle), 0.],
+        [math.sin(angle),  math.cos(angle), 0.],
+        [0.,               0., 1.]
+    ], dtype=np.float64)
+
+    return rotz
+
+
+
+ +
+ +
+ + +

+ tilt_towards(location, lookat) + +

+ + +
+ +

Definition to tilt surface normal of a plane towards a point.

+ + +

Parameters:

+
    +
  • + location + – +
    +
           Center of the plane to be tilted.
    +
    +
    +
  • +
  • + lookat + – +
    +
           Tilt towards this point.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +angles ( list +) – +
    +

    Rotation angles in degrees.

    +
    +
  • +
+ +
+ Source code in odak/tools/transformation.py +
def tilt_towards(location, lookat):
+    """
+    Definition to tilt surface normal of a plane towards a point.
+
+    Parameters
+    ----------
+    location     : list
+                   Center of the plane to be tilted.
+    lookat       : list
+                   Tilt towards this point.
+
+    Returns
+    ----------
+    angles       : list
+                   Rotation angles in degrees.
+    """
+    dx = location[0]-lookat[0]
+    dy = location[1]-lookat[1]
+    dz = location[2]-lookat[2]
+    dist = np.sqrt(dx**2+dy**2+dz**2)
+    phi = np.arctan2(dy, dx)
+    theta = np.arccos(dz/dist)
+    angles = [
+        0,
+        np.degrees(theta).tolist(),
+        np.degrees(phi).tolist()
+    ]
+    return angles
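A minimal usage sketch, assuming odak is installed:

from odak.tools.transformation import tilt_towards

angles = tilt_towards([0., 0., 0.], [0., 0., 10.])
# [0, 180.0, 0.0]: rotation angles in degrees that orient a plane at the origin
# towards the look-at point on the positive z axis.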
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/odak/wave/index.html b/odak/wave/index.html new file mode 100644 index 00000000..ac7faffa --- /dev/null +++ b/odak/wave/index.html @@ -0,0 +1,11095 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + odak.wave - Odak + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

odak.wave

+ +
+ + + + +
+ +

odak.wave

+

Provides necessary definitions for merging geometric optics with wave theory, as well as classical approaches in wave theory. See "Introduction to Fourier Optics" by Joseph Goodman for the theoretical explanation.

+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ adaptive_sampling_angular_spectrum(field, k, distance, dx, wavelength) + +

+ + +
+ +

A definition to calculate adaptive-sampling angular spectrum based beam propagation. For more, see Zhang, Wenhui, Hao Zhang, and Guofan Jin. "Adaptive-sampling angular spectrum method with full utilization of space-bandwidth product." Optics Letters 45.16 (2020): 4416-4419.

+ + +

Parameters:

+
    +
  • + field + – +
    +
               Complex field (MxN).
    +
    +
    +
  • +
  • + k + – +
    +
               Wave number of a wave, see odak.wave.wavenumber for more.
    +
    +
    +
  • +
  • + distance + – +
    +
               Propagation distance.
    +
    +
    +
  • +
  • + dx + – +
    +
               Size of one single pixel in the field grid (in meters).
    +
    +
    +
  • +
  • + wavelength + – +
    +
               Wavelength of the electric field.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( complex +) – +
    +

    Final complex field (MxN).

    +
    +
  • +
+ +
+ Source code in odak/wave/classical.py +
def adaptive_sampling_angular_spectrum(field, k, distance, dx, wavelength):
+    """
+    A definition to calculate adaptive sampling angular spectrum based beam propagation. For more Zhang, Wenhui, Hao Zhang, and Guofan Jin. "Adaptive-sampling angular spectrum method with full utilization of space-bandwidth product." Optics Letters 45.16 (2020): 4416-4419.
+
+    Parameters
+    ----------
+    field            : np.complex
+                       Complex field (MxN).
+    k                : odak.wave.wavenumber
+                       Wave number of a wave, see odak.wave.wavenumber for more.
+    distance         : float
+                       Propagation distance.
+    dx               : float
+                       Size of one single pixel in the field grid (in meters).
+    wavelength       : float
+                       Wavelength of the electric field.
+
+    Returns
+    -------
+    result           : np.complex
+                       Final complex field (MxN).
+    """
+    iflag = -1
+    eps = 10**(-12)
+    nv, nu = field.shape
+    l = nu*dx
+    x = np.linspace(-l/2, l/2, nu)
+    y = np.linspace(-l/2, l/2, nv)
+    X, Y = np.meshgrid(x, y)
+    fx = np.linspace(-1./2./dx, 1./2./dx, nu)
+    fy = np.linspace(-1./2./dx, 1./2./dx, nv)
+    FX, FY = np.meshgrid(fx, fy)
+    forig = 1./2./dx
+    fc2 = 1./2*(nu/wavelength/np.abs(distance))**0.5
+    ss = np.abs(fc2)/forig
+    zc = nu*dx**2/wavelength
+    K = nu/2/np.amax(np.abs(fx))
+    m = 2
+    nnu2 = m*nu
+    nnv2 = m*nv
+    fxn = np.linspace(-1./2./dx, 1./2./dx, nnu2)
+    fyn = np.linspace(-1./2./dx, 1./2./dx, nnv2)
+    if np.abs(distance) > zc*2:
+        fxn = fxn*ss
+        fyn = fyn*ss
+    FXN, FYN = np.meshgrid(fxn, fyn)
+    Hn = np.exp(1j*k*distance*(1-(FXN*wavelength)**2-(FYN*wavelength)**2)**0.5)
+    FX = FXN/np.amax(FXN)*np.pi
+    FY = FYN/np.amax(FYN)*np.pi
+    t_2 = nufft2(field, FX*ss, FY*ss, size=[nnv2, nnu2], sign=iflag, eps=eps)
+    FX = FX/np.amax(FX)*np.pi
+    FY = FY/np.amax(FY)*np.pi
+    result = nuifft2(Hn*t_2, FX*ss, FY*ss, size=[nv, nu], sign=-iflag, eps=eps)
+    return result
+
+
+
+ +
+ +
+ + +

+ add_phase(field, new_phase) + +

+ + +
+ +

Definition for adding a phase to a given complex field.

+ + +

Parameters:

+
    +
  • + field + – +
    +
           Complex field.
    +
    +
    +
  • +
  • + new_phase + – +
    +
           Complex phase.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +new_field ( complex64 +) – +
    +

    Complex field.

    +
    +
  • +
+ +
+ Source code in odak/wave/__init__.py +
def add_phase(field, new_phase):
+    """
+    Definition for adding a phase to a given complex field.
+
+    Parameters
+    ----------
+    field        : np.complex64
+                   Complex field.
+    new_phase    : np.complex64
+                   Complex phase.
+
+    Returns
+    -------
+    new_field    : np.complex64
+                   Complex field.
+    """
+    phase = calculate_phase(field)
+    amplitude = calculate_amplitude(field)
+    new_field = amplitude*np.cos(phase+new_phase) + \
+        1j*amplitude*np.sin(phase+new_phase)
+    return new_field
+
+
+
+ +
+ +
+ + +

+ add_random_phase(field) + +

+ + +
+ +

Definition for adding a random phase to a given complex field.

+ + +

Parameters:

+
    +
  • + field + – +
    +
           Complex field.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +new_field ( complex64 +) – +
    +

    Complex field.

    +
    +
  • +
+ +
+ Source code in odak/wave/__init__.py +
def add_random_phase(field):
+    """
+    Definition for adding a random phase to a given complex field.
+
+    Parameters
+    ----------
+    field        : np.complex64
+                   Complex field.
+
+    Returns
+    -------
+    new_field    : np.complex64
+                   Complex field.
+    """
+    random_phase = np.pi*np.random.random(field.shape)
+    new_field = add_phase(field, random_phase)
+    return new_field
+
+
+
+ +
+ +
+ + +

+ adjust_phase_only_slm_range(native_range, working_wavelength, native_wavelength) + +

+ + +
+ +

Definition for calculating the phase range of the Spatial Light Modulator (SLM) for a given wavelength. Here you provide the maximum angle, as the lower bound is typically zero. If the lower bound isn't zero in angles, you can use this very same definition for calculating the lower angular bound as well.

+ + +

Parameters:

+
    +
  • + native_range + – +
    +
                 Native range of the phase only SLM in radians (i.e. two pi).
    +
    +
    +
  • +
  • + working_wavelength + (float) + – +
    +
                 Wavelength of the illumination source or some working wavelength.
    +
    +
    +
  • +
  • + native_wavelength + – +
    +
                 Wavelength which the SLM is designed for.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +new_range ( float +) – +
    +

    Calculated phase range in radians.

    +
    +
  • +
+ +
+ Source code in odak/wave/__init__.py +
def adjust_phase_only_slm_range(native_range, working_wavelength, native_wavelength):
+    """
+    Definition for calculating the phase range of the Spatial Light Modulator (SLM) for a given wavelength. Here you prove maximum angle as the lower bound is typically zero. If the lower bound isn't zero in angles, you can use this very same definition for calculating lower angular bound as well.
+
+    Parameters
+    ----------
+    native_range       : float
+                         Native range of the phase only SLM in radians (i.e. two pi).
+    working_wavelength : float
+                         Wavelength of the illumination source or some working wavelength.
+    native_wavelength  : float
+                         Wavelength which the SLM is designed for.
+
+    Returns
+    -------
+    new_range          : float
+                         Calculated phase range in radians.
+    """
+    new_range = native_range/working_wavelength*native_wavelength
+    return new_range
+
+
+
+ +
+ +
+ + +
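A minimal usage sketch for the range adjustment above. The wavelengths below are illustrative assumptions (an SLM calibrated for a full 2π range at 532 nm but illuminated at 633 nm), and the import assumes the definition is exposed at the odak.wave level as listed on this page.

import numpy as np
from odak.wave import adjust_phase_only_slm_range

# SLM calibrated to cover 2 pi radians at 532 nm, driven with 633 nm illumination (example values).
new_range = adjust_phase_only_slm_range(
    native_range = 2 * np.pi,
    working_wavelength = 633e-9,
    native_wavelength = 532e-9
    )
# new_range is roughly 5.28 radians; the usable phase range shrinks at the longer wavelength.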

+ angular_spectrum(field, k, distance, dx, wavelength) + +

+ + +
+ +

A definition to calculate angular spectrum based beam propagation.

+ + +

Parameters:

+
    +
  • + field + – +
    +
               Complex field (MxN).
    +
    +
    +
  • +
  • + k + – +
    +
               Wave number of a wave, see odak.wave.wavenumber for more.
    +
    +
    +
  • +
  • + distance + – +
    +
               Propagation distance.
    +
    +
    +
  • +
  • + dx + – +
    +
               Size of one single pixel in the field grid (in meters).
    +
    +
    +
  • +
  • + wavelength + – +
    +
               Wavelength of the electric field.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( complex +) – +
    +

    Final complex field (MxN).

    +
    +
  • +
+ +
+ Source code in odak/wave/classical.py +
def angular_spectrum(field, k, distance, dx, wavelength):
+    """
+    A definition to calculate angular spectrum based beam propagation.
+
+    Parameters
+    ----------
+    field            : np.complex
+                       Complex field (MxN).
+    k                : odak.wave.wavenumber
+                       Wave number of a wave, see odak.wave.wavenumber for more.
+    distance         : float
+                       Propagation distance.
+    dx               : float
+                       Size of one single pixel in the field grid (in meters).
+    wavelength       : float
+                       Wavelength of the electric field.
+
+    Returns
+    -------
+    result           : np.complex
+                       Final complex field (MxN).
+    """
+    nv, nu = field.shape
+    x = np.linspace(-nu/2*dx, nu/2*dx, nu)
+    y = np.linspace(-nv/2*dx, nv/2*dx, nv)
+    X, Y = np.meshgrid(x, y)
+    Z = X**2+Y**2
+    h = 1./(1j*wavelength*distance)*np.exp(1j*k*(distance+Z/2/distance))
+    h = np.fft.fft2(np.fft.fftshift(h))*dx**2
+    U1 = np.fft.fft2(np.fft.fftshift(field))
+    U2 = h*U1
+    result = np.fft.ifftshift(np.fft.ifft2(U2))
+    return result
+
+
+
+ +
+ +
+ + +
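A short propagation sketch using the definition above together with odak.wave.wavenumber (referenced throughout this page). The field, pixel pitch and distance are illustrative assumptions, and the imports assume these definitions are exposed at the odak.wave level.

import numpy as np
from odak.wave import wavenumber, angular_spectrum, calculate_intensity

wavelength = 532e-9                              # illustrative wavelength in meters
dx = 8e-6                                        # illustrative pixel pitch in meters
distance = 0.1                                   # illustrative propagation distance in meters
k = wavenumber(wavelength)
field = np.zeros((512, 512), dtype = np.complex64)
field[240:272, 240:272] = 1.                     # a simple square aperture as the input field
result = angular_spectrum(field, k, distance, dx, wavelength)
intensity = calculate_intensity(result)          # see calculate_intensity below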

+ band_extended_angular_spectrum(field, k, distance, dx, wavelength) + +

+ + +
+ +

A definition to calculate band-extended angular spectrum based beam propagation. For more, see Zhang, Wenhui, Hao Zhang, and Guofan Jin. "Band-extended angular spectrum method for accurate diffraction calculation in a wide propagation range." Optics Letters 45.6 (2020): 1543-1546.

+ + +

Parameters:

+
    +
  • + field + – +
    +
               Complex field (MxN).
    +
    +
    +
  • +
  • + k + – +
    +
               Wave number of a wave, see odak.wave.wavenumber for more.
    +
    +
    +
  • +
  • + distance + – +
    +
               Propagation distance.
    +
    +
    +
  • +
  • + dx + – +
    +
               Size of one single pixel in the field grid (in meters).
    +
    +
    +
  • +
  • + wavelength + – +
    +
               Wavelength of the electric field.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( complex +) – +
    +

    Final complex field (MxN).

    +
    +
  • +
+ +
+ Source code in odak/wave/classical.py +
def band_extended_angular_spectrum(field, k, distance, dx, wavelength):
+    """
+    A definition to calculate bandextended angular spectrum based beam propagation. For more Zhang, Wenhui, Hao Zhang, and Guofan Jin. "Band-extended angular spectrum method for accurate diffraction calculation in a wide propagation range." Optics Letters 45.6 (2020): 1543-1546.
+
+    Parameters
+    ----------
+    field            : np.complex
+                       Complex field (MxN).
+    k                : odak.wave.wavenumber
+                       Wave number of a wave, see odak.wave.wavenumber for more.
+    distance         : float
+                       Propagation distance.
+    dx               : float
+                       Size of one single pixel in the field grid (in meters).
+    wavelength       : float
+                       Wavelength of the electric field.
+
+    Returns
+    -------
+    result           : np.complex
+                       Final complex field (MxN).
+    """
+    iflag = -1
+    eps = 10**(-12)
+    nv, nu = field.shape
+    l = nu*dx
+    x = np.linspace(-l/2, l/2, nu)
+    y = np.linspace(-l/2, l/2, nv)
+    X, Y = np.meshgrid(x, y)
+    Z = X**2+Y**2
+    fx = np.linspace(-1./2./dx, 1./2./dx, nu)
+    fy = np.linspace(-1./2./dx, 1./2./dx, nv)
+    FX, FY = np.meshgrid(fx, fy)
+    K = nu/2/np.amax(fx)
+    fcn = 1./2*(nu/wavelength/np.abs(distance))**0.5
+    ss = np.abs(fcn)/np.amax(np.abs(fx))
+    zc = nu*dx**2/wavelength
+    if np.abs(distance) < zc:
+        fxn = fx
+        fyn = fy
+    else:
+        fxn = fx*ss
+        fyn = fy*ss
+    FXN, FYN = np.meshgrid(fxn, fyn)
+    Hn = np.exp(1j*k*distance*(1-(FXN*wavelength)**2-(FYN*wavelength)**2)**0.5)
+    X = X/np.amax(X)*np.pi
+    Y = Y/np.amax(Y)*np.pi
+    t_asmNUFT = nufft2(field, X*ss, Y*ss, sign=iflag, eps=eps)
+    result = nuifft2(Hn*t_asmNUFT, X*ss, Y*ss, sign=-iflag, eps=eps)
+    return result
+
+
+
+ +
+ +
+ + +

+ band_limited_angular_spectrum(field, k, distance, dx, wavelength) + +

+ + +
+ +

A definition to calculate band-limited angular spectrum based beam propagation. For more, see Matsushima, Kyoji, and Tomoyoshi Shimobaba. "Band-limited angular spectrum method for numerical simulation of free-space propagation in far and near fields." Optics Express 17.22 (2009): 19662-19673.

+ + +

Parameters:

+
    +
  • + field + – +
    +
               Complex field (MxN).
    +
    +
    +
  • +
  • + k + – +
    +
               Wave number of a wave, see odak.wave.wavenumber for more.
    +
    +
    +
  • +
  • + distance + – +
    +
               Propagation distance.
    +
    +
    +
  • +
  • + dx + – +
    +
               Size of one single pixel in the field grid (in meters).
    +
    +
    +
  • +
  • + wavelength + – +
    +
               Wavelength of the electric field.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( complex +) – +
    +

    Final complex field (MxN).

    +
    +
  • +
+ +
+ Source code in odak/wave/classical.py +
def band_limited_angular_spectrum(field, k, distance, dx, wavelength):
+    """
+    A definition to calculate bandlimited angular spectrum based beam propagation. For more Matsushima, Kyoji, and Tomoyoshi Shimobaba. "Band-limited angular spectrum method for numerical simulation of free-space propagation in far and near fields." Optics express 17.22 (2009): 19662-19673.
+
+    Parameters
+    ----------
+    field            : np.complex
+                       Complex field (MxN).
+    k                : odak.wave.wavenumber
+                       Wave number of a wave, see odak.wave.wavenumber for more.
+    distance         : float
+                       Propagation distance.
+    dx               : float
+                       Size of one single pixel in the field grid (in meters).
+    wavelength       : float
+                       Wavelength of the electric field.
+
+    Returns
+    -------
+    result           : np.complex
+                       Final complex field (MxN).
+    """
+    nv, nu = field.shape
+    x = np.linspace(-nu/2*dx, nu/2*dx, nu)
+    y = np.linspace(-nv/2*dx, nv/2*dx, nv)
+    X, Y = np.meshgrid(x, y)
+    Z = X**2+Y**2
+    h = 1./(1j*wavelength*distance)*np.exp(1j*k*(distance+Z/2/distance))
+    h = np.fft.fft2(np.fft.fftshift(h))*dx**2
+    flimx = np.ceil(1/(((2*distance*(1./(nu)))**2+1)**0.5*wavelength))
+    flimy = np.ceil(1/(((2*distance*(1./(nv)))**2+1)**0.5*wavelength))
+    mask = np.zeros((nu, nv), dtype=np.complex64)
+    mask = (np.abs(X) < flimx) & (np.abs(Y) < flimy)
+    mask = set_amplitude(h, mask)
+    U1 = np.fft.fft2(np.fft.fftshift(field))
+    U2 = mask*U1
+    result = np.fft.ifftshift(np.fft.ifft2(U2))
+    return result
+
+
+
+ +
+ +
+ + +

+ calculate_intensity(field) + +

+ + +
+ +

Definition to calculate intensity of a single or multiple given electric field(s).

+ + +

Parameters:

+
    +
  • + field + – +
    +
           Electric fields or an electric field.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +intensity ( float +) – +
    +

    Intensity or intensities of electric field(s).

    +
    +
  • +
+ +
+ Source code in odak/wave/__init__.py +
def calculate_intensity(field):
+    """
+    Definition to calculate intensity of a single or multiple given electric field(s).
+
+    Parameters
+    ----------
+    field        : ndarray.complex or complex
+                   Electric fields or an electric field.
+
+    Returns
+    -------
+    intensity    : float
+                   Intensity or intensities of electric field(s).
+    """
+    intensity = np.abs(field)**2
+    return intensity
+
+
+
+ +
+ +
+ + +
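A small sketch of the amplitude-phase-intensity relation using the definitions on this page; the values are purely illustrative.

import numpy as np
from odak.wave import generate_complex_field, calculate_intensity

field = generate_complex_field(np.ones((4, 4)), np.pi / 4.)
intensity = calculate_intensity(field)           # all ones, since |A exp(j phi)|^2 = A^2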

+ distance_between_two_points(point1, point2) + +

+ + +
+ +

Definition to calculate distance between two given points.

+ + +

Parameters:

+
    +
  • + point1 + – +
    +
          First point in X,Y,Z.
    +
    +
    +
  • +
  • + point2 + – +
    +
          Second point in X,Y,Z.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +distance ( float +) – +
    +

    Distance in between given two points.

    +
    +
  • +
+ +
+ Source code in odak/tools/vector.py +
def distance_between_two_points(point1, point2):
+    """
+    Definition to calculate distance between two given points.
+
+    Parameters
+    ----------
+    point1      : list
+                  First point in X,Y,Z.
+    point2      : list
+                  Second point in X,Y,Z.
+
+    Returns
+    ----------
+    distance    : float
+                  Distance in between given two points.
+    """
+    point1 = np.asarray(point1)
+    point2 = np.asarray(point2)
+    if len(point1.shape) == 1 and len(point2.shape) == 1:
+        distance = np.sqrt(np.sum((point1-point2)**2))
+    elif len(point1.shape) == 2 or len(point2.shape) == 2:
+        distance = np.sqrt(np.sum((point1-point2)**2, axis=1))
+    return distance
+
+
+
+ +
+ +
+ + +
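As the branch on the array shapes above shows, the definition accepts either a single pair of points or one point against a 2D array of points. A small sketch with illustrative coordinates; the import assumes the definition is exposed at the odak.tools level (its source lives in odak/tools/vector.py).

import numpy as np
from odak.tools import distance_between_two_points

d_single = distance_between_two_points([0., 0., 0.], [3., 4., 0.])          # 5.0
points = np.array([[1., 0., 0.], [0., 2., 0.], [0., 0., 3.]])
d_many = distance_between_two_points([0., 0., 0.], points)                  # array([1., 2., 3.])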

+ double_convergence(nx, ny, k, r, dx) + +

+ + +
+ +

A definition to generate initial phase for a Gerchberg-Saxton method. For more details consult Sun, Peng, et al. "Holographic near-eye display system based on double-convergence light Gerchberg-Saxton algorithm." Optics express 26.8 (2018): 10140-10151.

+ + +

Parameters:

+
    +
  • + nx + – +
    +
         Size of the output along X.
    +
    +
    +
  • +
  • + ny + – +
    +
         Size of the output along Y.
    +
    +
    +
  • +
  • + k + – +
    +
         See odak.wave.wavenumber for more.
    +
    +
    +
  • +
  • + r + – +
    +
         The distance between location of a light source and an image plane.
    +
    +
    +
  • +
  • + dx + – +
    +
         Pixel pitch.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +function ( ndarray +) – +
    +

    Generated phase pattern for a Gerchberg-Saxton method.

    +
    +
  • +
+ +
+ Source code in odak/wave/lens.py +
def double_convergence(nx, ny, k, r, dx):
+    """
+    A definition to generate initial phase for a Gerchberg-Saxton method. For more details consult Sun, Peng, et al. "Holographic near-eye display system based on double-convergence light Gerchberg-Saxton algorithm." Optics express 26.8 (2018): 10140-10151.
+
+    Parameters
+    ----------
+    nx         : int
+                 Size of the output along X.
+    ny         : int
+                 Size of the output along Y.
+    k          : odak.wave.wavenumber
+                 See odak.wave.wavenumber for more.
+    r          : float
+                 The distance between location of a light source and an image plane.
+    dx         : float
+                 Pixel pitch.
+
+    Returns
+    -------
+    function   : ndarray
+                 Generated phase pattern for a Gerchberg-Saxton method.
+    """
+    size = [ny, nx]
+    x = np.linspace(-size[0]*dx/2, size[0]*dx/2, size[0])
+    y = np.linspace(-size[1]*dx/2, size[1]*dx/2, size[1])
+    X, Y = np.meshgrid(x, y)
+    Z = X**2+Y**2
+    w = np.exp(1j*k*Z/r)
+    return w
+
+
+
+ +
+ +
+ + +
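A sketch of generating the initial phase described above; the grid size, source distance and pixel pitch are illustrative assumptions, and the import assumes the definition is exposed at the odak.wave level.

from odak.wave import wavenumber, double_convergence

wavelength = 532e-9
k = wavenumber(wavelength)
initial_phase = double_convergence(nx = 512, ny = 512, k = k, r = 0.3, dx = 8e-6)
# intended as the initial_phase input of the Gerchberg-Saxton routines documented on this page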

+ electric_field_per_plane_wave(amplitude, opd, k, phase=0, w=0, t=0) + +

+ + +
+ +

Definition to return state of a plane wave at a particular distance and time.

+ + +

Parameters:

+
    +
  • + amplitude + – +
    +
           Amplitude of a wave.
    +
    +
    +
  • +
  • + opd + – +
    +
           Optical path difference in mm.
    +
    +
    +
  • +
  • + k + – +
    +
           Wave number of a wave, see odak.wave.parameters.wavenumber for more.
    +
    +
    +
  • +
  • + phase + – +
    +
           Initial phase of a wave.
    +
    +
    +
  • +
  • + w + – +
    +
           Rotation speed of a wave, see odak.wave.parameters.rotationspeed for more.
    +
    +
    +
  • +
  • + t + – +
    +
           Time in seconds.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +field ( complex +) – +
    +

    A complex number that provides the resultant field in the complex form A*e^(j(wt+phi)).

    +
    +
  • +
+ +
+ Source code in odak/wave/vector.py +
def electric_field_per_plane_wave(amplitude, opd, k, phase=0, w=0, t=0):
+    """
+    Definition to return state of a plane wave at a particular distance and time.
+
+    Parameters
+    ----------
+    amplitude    : float
+                   Amplitude of a wave.
+    opd          : float
+                   Optical path difference in mm.
+    k            : float
+                   Wave number of a wave, see odak.wave.parameters.wavenumber for more.
+    phase        : float
+                   Initial phase of a wave.
+    w            : float
+                   Rotation speed of a wave, see odak.wave.parameters.rotationspeed for more.
+    t            : float
+                   Time in seconds.
+
+    Returns
+    -------
+    field        : complex
+                   A complex number that provides the resultant field in the complex form A*e^(j(wt+phi)).
+    """
+    field = amplitude*np.exp(1j*(-w*t+opd*k+phase))/opd**2
+    return field
+
+
+
+ +
+ +
+ + +

+ fraunhofer(field, k, distance, dx, wavelength) + +

+ + +
+ +

A definition to calculate Fraunhofer based beam propagation.

+ + +

Parameters:

+
    +
  • + field + – +
    +
               Complex field (MxN).
    +
    +
    +
  • +
  • + k + – +
    +
               Wave number of a wave, see odak.wave.wavenumber for more.
    +
    +
    +
  • +
  • + distance + – +
    +
               Propagation distance.
    +
    +
    +
  • +
  • + dx + – +
    +
               Size of one single pixel in the field grid (in meters).
    +
    +
    +
  • +
  • + wavelength + – +
    +
               Wavelength of the electric field.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( complex +) – +
    +

    Final complex field (MxN).

    +
    +
  • +
+ +
+ Source code in odak/wave/classical.py +
def fraunhofer(field, k, distance, dx, wavelength):
+    """
+    A definition to calculate Fraunhofer based beam propagation.
+
+    Parameters
+    ----------
+    field            : np.complex
+                       Complex field (MxN).
+    k                : odak.wave.wavenumber
+                       Wave number of a wave, see odak.wave.wavenumber for more.
+    distance         : float
+                       Propagation distance.
+    dx               : float
+                       Size of one single pixel in the field grid (in meters).
+    wavelength       : float
+                       Wavelength of the electric field.
+
+    Returns
+    -------
+    result           : np.complex
+                       Final complex field (MxN).
+    """
+    nv, nu = field.shape
+    l = nu*dx
+    l2 = wavelength*distance/dx
+    dx2 = wavelength*distance/l
+    fx = np.linspace(-l2/2., l2/2., nu)
+    fy = np.linspace(-l2/2., l2/2., nv)
+    FX, FY = np.meshgrid(fx, fy)
+    FZ = FX**2+FY**2
+    c = np.exp(1j*k*distance)/(1j*wavelength*distance) * \
+        np.exp(1j*k/(2*distance)*FZ)
+    result = c*np.fft.ifftshift(np.fft.fft2(np.fft.fftshift(field)))*dx**2
+    return result
+
+
+
+ +
+ +
+ + +

+ fraunhofer_equal_size_adjust(field, distance, dx, wavelength) + +

+ + +
+ +

A definition to match the physical size of the original field with the propagated field.

+ + +

Parameters:

+
    +
  • + field + – +
    +
               Complex field (MxN).
    +
    +
    +
  • +
  • + distance + – +
    +
               Propagation distance.
    +
    +
    +
  • +
  • + dx + – +
    +
               Size of one single pixel in the field grid (in meters).
    +
    +
    +
  • +
  • + wavelength + – +
    +
               Wavelength of the electric field.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +new_field ( complex +) – +
    +

    Final complex field (MxN).

    +
    +
  • +
+ +
+ Source code in odak/wave/classical.py +
def fraunhofer_equal_size_adjust(field, distance, dx, wavelength):
+    """
+    A definition to match the physical size of the original field with the propagated field.
+
+    Parameters
+    ----------
+    field            : np.complex
+                       Complex field (MxN).
+    distance         : float
+                       Propagation distance.
+    dx               : float
+                       Size of one single pixel in the field grid (in meters).
+    wavelength       : float
+                       Wavelength of the electric field.
+
+    Returns
+    -------
+    new_field        : np.complex
+                       Final complex field (MxN).
+    """
+    nv, nu = field.shape
+    l1 = nu*dx
+    l2 = wavelength*distance/dx
+    m = l1/l2
+    px = int(m*nu)
+    py = int(m*nv)
+    nx = int(field.shape[0]/2-px/2)
+    ny = int(field.shape[1]/2-py/2)
+    new_field = np.copy(field[nx:nx+px, ny:ny+py])
+    return new_field
+
+
+
+ +
+ +
+ + +

+ fraunhofer_inverse(field, k, distance, dx, wavelength) + +

+ + +
+ +

A definition to calculate Inverse Fraunhofer based beam propagation.

+ + +

Parameters:

+
    +
  • + field + – +
    +
               Complex field (MxN).
    +
    +
    +
  • +
  • + k + – +
    +
               Wave number of a wave, see odak.wave.wavenumber for more.
    +
    +
    +
  • +
  • + distance + – +
    +
               Propagation distance.
    +
    +
    +
  • +
  • + dx + – +
    +
               Size of one single pixel in the field grid (in meters).
    +
    +
    +
  • +
  • + wavelength + – +
    +
               Wavelength of the electric field.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( complex +) – +
    +

    Final complex field (MxN).

    +
    +
  • +
+ +
+ Source code in odak/wave/classical.py +
def fraunhofer_inverse(field, k, distance, dx, wavelength):
+    """
+    A definition to calculate Inverse Fraunhofer based beam propagation.
+
+    Parameters
+    ----------
+    field            : np.complex
+                       Complex field (MxN).
+    k                : odak.wave.wavenumber
+                       Wave number of a wave, see odak.wave.wavenumber for more.
+    distance         : float
+                       Propagation distance.
+    dx               : float
+                       Size of one single pixel in the field grid (in meters).
+    wavelength       : float
+                       Wavelength of the electric field.
+
+    Returns
+    -------
+    result           : np.complex
+                       Final complex field (MxN).
+    """
+    distance = np.abs(distance)
+    nv, nu = field.shape
+    l = nu*dx
+    l2 = wavelength*distance/dx
+    dx2 = wavelength*distance/l
+    fx = np.linspace(-l2/2., l2/2., nu)
+    fy = np.linspace(-l2/2., l2/2., nv)
+    FX, FY = np.meshgrid(fx, fy)
+    FZ = FX**2+FY**2
+    c = np.exp(1j*k*distance)/(1j*wavelength*distance) * \
+        np.exp(1j*k/(2*distance)*FZ)
+    result = np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(field/dx**2/c)))
+    return result
+
+
+
+ +
+ +
+ + +

+ generate_complex_field(amplitude, phase) + +

+ + +
+ +

Definition to generate a complex field with a given amplitude and phase.

+ + +

Parameters:

+
    +
  • + amplitude + – +
    +
                Amplitude of the field.
    +
    +
    +
  • +
  • + phase + – +
    +
                Phase of the field.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +field ( ndarray +) – +
    +

    Complex field.

    +
    +
  • +
+ +
+ Source code in odak/wave/__init__.py +
def generate_complex_field(amplitude, phase):
+    """
+    Definition to generate a complex field with a given amplitude and phase.
+
+    Parameters
+    ----------
+    amplitude         : ndarray
+                        Amplitude of the field.
+    phase             : ndarray
+                        Phase of the field.
+
+    Returns
+    -------
+    field             : ndarray
+                        Complex field.
+    """
+    field = amplitude*np.cos(phase)+1j*amplitude*np.sin(phase)
+    return field
+
+
+
+ +
+ +
+ + +
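A round-trip sketch showing how this definition pairs with calculate_amplitude and calculate_phase (both used throughout the listings above); the array sizes are illustrative.

import numpy as np
from odak.wave import generate_complex_field, calculate_amplitude, calculate_phase

amplitude = np.random.rand(256, 256)
phase = 2 * np.pi * np.random.rand(256, 256)
field = generate_complex_field(amplitude, phase)
# calculate_amplitude(field) recovers amplitude; calculate_phase(field) recovers phase modulo 2 pi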

+ gerchberg_saxton(field, n_iterations, distance, dx, wavelength, slm_range=6.28, propagation_type='IR Fresnel', initial_phase=None) + +

+ + +
+ +

Definition to compute a hologram using an iterative method called Gerchberg-Saxton phase retrieval algorithm. For more on the method, see: Gerchberg, Ralph W. "A practical algorithm for the determination of phase from image and diffraction plane pictures." Optik 35 (1972): 237-246.

+ + +

Parameters:

+
    +
  • + field + – +
    +
               Complex field (MxN).
    +
    +
    +
  • +
  • + distance + – +
    +
               Propagation distance.
    +
    +
    +
  • +
  • + dx + – +
    +
               Size of one single pixel in the field grid (in meters).
    +
    +
    +
  • +
  • + wavelength + – +
    +
               Wavelength of the electric field.
    +
    +
    +
  • +
  • + slm_range + – +
    +
               Typically this is equal to two pi. See odak.wave.adjust_phase_only_slm_range() for more.
    +
    +
    +
  • +
  • + propagation_type + (str, default: + 'IR Fresnel' +) + – +
    +
               Type of the propagation (IR Fresnel, TR Fresnel, Fraunhofer).
    +
    +
    +
  • +
  • + initial_phase + – +
    +
               Phase to be added to the initial value.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +hologram ( complex +) – +
    +

    Calculated complex hologram.

    +
    +
  • +
  • +reconstruction ( complex +) – +
    +

    Calculated reconstruction using calculated hologram.

    +
    +
  • +
+ +
+ Source code in odak/wave/classical.py +
def gerchberg_saxton(field, n_iterations, distance, dx, wavelength, slm_range=6.28, propagation_type='IR Fresnel', initial_phase=None):
+    """
+    Definition to compute a hologram using an iterative method called Gerchberg-Saxton phase retrieval algorithm. For more on the method, see: Gerchberg, Ralph W. "A practical algorithm for the determination of phase from image and diffraction plane pictures." Optik 35 (1972): 237-246.
+
+    Parameters
+    ----------
+    field            : np.complex64
+                       Complex field (MxN).
+    distance         : float
+                       Propagation distance.
+    dx               : float
+                       Size of one single pixel in the field grid (in meters).
+    wavelength       : float
+                       Wavelength of the electric field.
+    slm_range        : float
+                       Typically this is equal to two pi. See odak.wave.adjust_phase_only_slm_range() for more.
+    propagation_type : str
+                       Type of the propagation (IR Fresnel, TR Fresnel, Fraunhofer).
+    initial_phase    : np.complex64
+                       Phase to be added to the initial value.
+
+    Returns
+    -------
+    hologram         : np.complex
+                       Calculated complex hologram.
+    reconstruction   : np.complex
+                       Calculated reconstruction using calculated hologram. 
+    """
+    k = wavenumber(wavelength)
+    target = calculate_amplitude(field)
+    hologram = generate_complex_field(np.ones(field.shape), 0)
+    hologram = zero_pad(hologram)
+    if type(initial_phase) == type(None):
+        hologram = add_random_phase(hologram)
+    else:
+        initial_phase = zero_pad(initial_phase)
+        hologram = add_phase(hologram, initial_phase)
+    center = [int(hologram.shape[0]/2.), int(hologram.shape[1]/2.)]
+    orig_shape = [int(field.shape[0]/2.), int(field.shape[1]/2.)]
+    for i in tqdm(range(n_iterations), leave=False):
+        reconstruction = propagate_beam(
+            hologram, k, distance, dx, wavelength, propagation_type)
+        new_target = calculate_amplitude(reconstruction)
+        new_target[
+            center[0]-orig_shape[0]:center[0]+orig_shape[0],
+            center[1]-orig_shape[1]:center[1]+orig_shape[1]
+        ] = target
+        reconstruction = generate_complex_field(
+            new_target, calculate_phase(reconstruction))
+        hologram = propagate_beam(
+            reconstruction, k, -distance, dx, wavelength, propagation_type)
+        hologram = generate_complex_field(1, calculate_phase(hologram))
+        hologram = hologram[
+            center[0]-orig_shape[0]:center[0]+orig_shape[0],
+            center[1]-orig_shape[1]:center[1]+orig_shape[1],
+        ]
+        hologram = zero_pad(hologram)
+    reconstruction = propagate_beam(
+        hologram, k, distance, dx, wavelength, propagation_type)
+    hologram = hologram[
+        center[0]-orig_shape[0]:center[0]+orig_shape[0],
+        center[1]-orig_shape[1]:center[1]+orig_shape[1]
+    ]
+    reconstruction = reconstruction[
+        center[0]-orig_shape[0]:center[0]+orig_shape[0],
+        center[1]-orig_shape[1]:center[1]+orig_shape[1]
+    ]
+    return hologram, reconstruction
+
+
+
+ +
+ +
+ + +
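An end-to-end sketch of computing a hologram for a target amplitude with the definition above. The target, distance, pixel pitch and iteration count are illustrative assumptions. Note that the propagation_type string is handed to propagate_beam, so it has to be one of the options handled there (e.g. 'Bandlimited Angular Spectrum'); the 'IR Fresnel' label mentioned in the docstring is not among the branches listed in propagate_beam's source further down this page.

import numpy as np
from odak.wave import gerchberg_saxton, calculate_amplitude

wavelength = 532e-9
dx = 8e-6
distance = 0.2
target = np.zeros((512, 512), dtype = np.complex64)
target[200:312, 200:312] = 1.                    # illustrative target amplitude
hologram, reconstruction = gerchberg_saxton(
    target,
    n_iterations = 10,
    distance = distance,
    dx = dx,
    wavelength = wavelength,
    slm_range = 2 * np.pi,
    propagation_type = 'Bandlimited Angular Spectrum'
    )
error = np.mean((calculate_amplitude(reconstruction) - calculate_amplitude(target)) ** 2)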

+ gerchberg_saxton_3d(fields, n_iterations, distances, dx, wavelength, slm_range=6.28, propagation_type='IR Fresnel', initial_phase=None, target_type='no constraint', coefficients=None) + +

+ + +
+ +

Definition to compute a multi plane hologram using an iterative method called Gerchberg-Saxton phase retrieval algorithm. For more on the method, see: Zhou, Pengcheng, et al. "30.4: Multi‐plane holographic display with a uniform 3D Gerchberg‐Saxton algorithm." SID Symposium Digest of Technical Papers. Vol. 46. No. 1. 2015.

+ + +

Parameters:

+
    +
  • + fields + – +
    +
               Complex fields (MxN).
    +
    +
    +
  • +
  • + distances + – +
    +
               Propagation distances.
    +
    +
    +
  • +
  • + dx + – +
    +
               Size of one single pixel in the field grid (in meters).
    +
    +
    +
  • +
  • + wavelength + – +
    +
               Wavelength of the electric field.
    +
    +
    +
  • +
  • + slm_range + – +
    +
               Typically this is equal to two pi. See odak.wave.adjust_phase_only_slm_range() for more.
    +
    +
    +
  • +
  • + propagation_type + (str, default: + 'IR Fresnel' +) + – +
    +
               Type of the propagation (IR Fresnel, TR Fresnel, Fraunhofer).
    +
    +
    +
  • +
  • + initial_phase + – +
    +
               Phase to be added to the initial value.
    +
    +
    +
  • +
  • + target_type + – +
    +
               Target type. `No constraint` targets the input target as is. `Double constraint` follows the idea in this paper, which claims to suppress speckle: Chang, Chenliang, et al. "Speckle-suppressed phase-only holographic three-dimensional display based on double-constraint Gerchberg–Saxton algorithm." Applied optics 54.23 (2015): 6994-7001.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +hologram ( complex +) – +
    +

    Calculated complex hologram.

    +
    +
  • +
+ +
+ Source code in odak/wave/classical.py +
def gerchberg_saxton_3d(fields, n_iterations, distances, dx, wavelength, slm_range=6.28, propagation_type='IR Fresnel', initial_phase=None, target_type='no constraint', coefficients=None):
+    """
+    Definition to compute a multi plane hologram using an iterative method called Gerchberg-Saxton phase retrieval algorithm. For more on the method, see: Zhou, Pengcheng, et al. "30.4: Multi‐plane holographic display with a uniform 3D Gerchberg‐Saxton algorithm." SID Symposium Digest of Technical Papers. Vol. 46. No. 1. 2015.
+
+    Parameters
+    ----------
+    fields           : np.complex64
+                       Complex fields (MxN).
+    distances        : list
+                       Propagation distances.
+    dx               : float
+                       Size of one single pixel in the field grid (in meters).
+    wavelength       : float
+                       Wavelength of the electric field.
+    slm_range        : float
+                       Typically this is equal to two pi. See odak.wave.adjust_phase_only_slm_range() for more.
+    propagation_type : str
+                       Type of the propagation (IR Fresnel, TR Fresnel, Fraunhofer).
+    initial_phase    : np.complex64
+                       Phase to be added to the initial value.
+    target_type      : str
+                       Target type. `No constraint` targets the input target as is. `Double constraint` follows the idea in this paper, which claims to suppress speckle: Chang, Chenliang, et al. "Speckle-suppressed phase-only holographic three-dimensional display based on double-constraint Gerchberg–Saxton algorithm." Applied optics 54.23 (2015): 6994-7001. 
+
+    Returns
+    -------
+    hologram         : np.complex
+                       Calculated complex hologram.
+    """
+    k = wavenumber(wavelength)
+    targets = calculate_amplitude(np.asarray(fields)).astype(np.float64)
+    hologram = generate_complex_field(np.ones(targets[0].shape), 0)
+    hologram = zero_pad(hologram)
+    if type(initial_phase) == type(None):
+        hologram = add_random_phase(hologram)
+    else:
+        initial_phase = zero_pad(initial_phase)
+        hologram = add_phase(hologram, initial_phase)
+    center = [int(hologram.shape[0]/2.), int(hologram.shape[1]/2.)]
+    orig_shape = [int(fields[0].shape[0]/2.), int(fields[0].shape[1]/2.)]
+    holograms = np.zeros(
+        (len(distances), hologram.shape[0], hologram.shape[1]), dtype=np.complex64)
+    for i in tqdm(range(n_iterations), leave=False):
+        for distance_id in tqdm(range(len(distances)), leave=False):
+            distance = distances[distance_id]
+            reconstruction = propagate_beam(
+                hologram, k, distance, dx, wavelength, propagation_type)
+            if target_type == 'double constraint':
+                if type(coefficients) == type(None):
+                    raise Exception(
+                        "Provide coeeficients of alpha,beta and gamma for double constraint.")
+                alpha = coefficients[0]
+                beta = coefficients[1]
+                gamma = coefficients[2]
+                target_current = 2*alpha * \
+                    np.copy(targets[distance_id])-beta * \
+                    calculate_amplitude(reconstruction)
+                target_current[target_current == 0] = gamma * \
+                    np.abs(reconstruction[target_current == 0])
+            elif target_type == 'no constraint':
+                target_current = np.abs(targets[distance_id])
+            new_target = calculate_amplitude(reconstruction)
+            new_target[
+                center[0]-orig_shape[0]:center[0]+orig_shape[0],
+                center[1]-orig_shape[1]:center[1]+orig_shape[1]
+            ] = target_current
+            reconstruction = generate_complex_field(
+                new_target, calculate_phase(reconstruction))
+            hologram_layer = propagate_beam(
+                reconstruction, k, -distance, dx, wavelength, propagation_type)
+            hologram_layer = generate_complex_field(
+                1., calculate_phase(hologram_layer))
+            hologram_layer = hologram_layer[
+                center[0]-orig_shape[0]:center[0]+orig_shape[0],
+                center[1]-orig_shape[1]:center[1]+orig_shape[1]
+            ]
+            hologram_layer = zero_pad(hologram_layer)
+            holograms[distance_id] = hologram_layer
+        hologram = np.sum(holograms, axis=0)
+    hologram = hologram[
+        center[0]-orig_shape[0]:center[0]+orig_shape[0],
+        center[1]-orig_shape[1]:center[1]+orig_shape[1]
+    ]
+    return hologram
+
+
+
+ +
+ +
+ + +

+ impulse_response_fresnel(field, k, distance, dx, wavelength) + +

+ + +
+ +

A definition to calculate impulse response based Fresnel approximation for beam propagation.

+ + +

Parameters:

+
    +
  • + field + – +
    +
               Complex field (MxN).
    +
    +
    +
  • +
  • + k + – +
    +
               Wave number of a wave, see odak.wave.wavenumber for more.
    +
    +
    +
  • +
  • + distance + – +
    +
               Propagation distance.
    +
    +
    +
  • +
  • + dx + – +
    +
               Size of one single pixel in the field grid (in meters).
    +
    +
    +
  • +
  • + wavelength + – +
    +
               Wavelength of the electric field.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( complex +) – +
    +

    Final complex field (MxN).

    +
    +
  • +
+ +
+ Source code in odak/wave/classical.py +
def impulse_response_fresnel(field, k, distance, dx, wavelength):
+    """
+    A definition to calculate impulse response based Fresnel approximation for beam propagation.
+
+    Parameters
+    ----------
+    field            : np.complex
+                       Complex field (MxN).
+    k                : odak.wave.wavenumber
+                       Wave number of a wave, see odak.wave.wavenumber for more.
+    distance         : float
+                       Propagation distance.
+    dx               : float
+                       Size of one single pixel in the field grid (in meters).
+    wavelength       : float
+                       Wavelength of the electric field.
+
+    Returns
+    -------
+    result           : np.complex
+                       Final complex field (MxN).
+
+    """
+    nv, nu = field.shape
+    x = np.linspace(-nu / 2 * dx, nu / 2 * dx, nu)
+    y = np.linspace(-nv / 2 * dx, nv / 2 * dx, nv)
+    X, Y = np.meshgrid(x, y)
+    h = 1. / (1j * wavelength * distance) * np.exp(1j * k / (2 * distance) * (X ** 2 + Y ** 2))
+    H = np.fft.fft2(np.fft.fftshift(h))
+    U1 = np.fft.fft2(np.fft.fftshift(field))
+    U2 = H * U1
+    result = np.fft.ifftshift(np.fft.ifft2(U2))
+    result = np.roll(result, shift = (1, 1), axis = (0, 1))
+    return result
+
+
+
+ +
+ +
+ + +

+ linear_grating(nx, ny, every=2, add=3.14, axis='x') + +

+ + +
+ +

A definition to generate a linear grating.

+ + +

Parameters:

+
    +
  • + nx + – +
    +
         Size of the output along X.
    +
    +
    +
  • +
  • + ny + – +
    +
         Size of the output along Y.
    +
    +
    +
  • +
  • + every + – +
    +
         Apply the add value at every given number of pixels.
    +
    +
    +
  • +
  • + add + – +
    +
         Angle to be added.
    +
    +
    +
  • +
  • + axis + – +
    +
         Axis: either 'x', 'y' or 'xy' (both).
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +field ( ndarray +) – +
    +

    Linear grating term.

    +
    +
  • +
+ +
+ Source code in odak/wave/lens.py +
def linear_grating(nx, ny, every=2, add=3.14, axis='x'):
+    """
+    A definition to generate a linear grating.
+
+    Parameters
+    ----------
+    nx         : int
+                 Size of the output along X.
+    ny         : int
+                 Size of the output along Y.
+    every      : int
+                 Add the add value at every given number.
+    add        : float
+                 Angle to be added.
+    axis       : string
+                 Axis eiter X,Y or both.
+
+    Returns
+    -------
+    field      : ndarray
+                 Linear grating term.
+    """
+    grating = np.zeros((nx, ny), dtype=np.complex64)
+    if axis == 'x':
+        grating[::every, :] = np.exp(1j*add)
+    if axis == 'y':
+        grating[:, ::every] = np.exp(1j*add)
+    if axis == 'xy':
+        checker = np.indices((nx, ny)).sum(axis=0) % every
+        checker += 1
+        checker = checker % 2
+        grating = np.exp(1j*checker*add)
+    return grating
+
+
+
+ +
+ +
+ + +
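A sketch of applying the grating as an additional phase term on a field, reusing add_phase and calculate_phase from this page; the grid size and the π step are illustrative.

import numpy as np
from odak.wave import linear_grating, add_phase, calculate_phase, generate_complex_field

field = generate_complex_field(np.ones((512, 512)), 0.)            # illustrative uniform field
grating = linear_grating(512, 512, every = 2, add = np.pi, axis = 'xy')
field_with_grating = add_phase(field, calculate_phase(grating))    # checkerboard pi phase step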

+ nufft2(field, fx, fy, size=None, sign=1, eps=10 ** -12) + +

+ + +
+ +

A definition to take 2D Non-Uniform Fast Fourier Transform (NUFFT).

+ + +

Parameters:

+
    +
  • + field + – +
    +
          Input field.
    +
    +
    +
  • +
  • + fx + – +
    +
          Frequencies along x axis.
    +
    +
    +
  • +
  • + fy + – +
    +
          Frequencies along y axis.
    +
    +
    +
  • +
  • + size + – +
    +
          Size of the output; defaults to the shape of the input field.
    +
    +
    +
  • +
  • + sign + – +
    +
          Sign of the exponential used in NUFFT kernel.
    +
    +
    +
  • +
  • + eps + – +
    +
          Accuracy of NUFFT.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( ndarray +) – +
    +

    Inverse NUFFT of the input field.

    +
    +
  • +
+ +
+ Source code in odak/tools/matrix.py +
def nufft2(field, fx, fy, size=None, sign=1, eps=10**(-12)):
+    """
+    A definition to take 2D Non-Uniform Fast Fourier Transform (NUFFT).
+
+    Parameters
+    ----------
+    field       : ndarray
+                  Input field.
+    fx          : ndarray
+                  Frequencies along x axis.
+    fy          : ndarray
+                  Frequencies along y axis.
+    size        : list
+                  Size.
+    sign        : float
+                  Sign of the exponential used in NUFFT kernel.
+    eps         : float
+                  Accuracy of NUFFT.
+
+    Returns
+    ----------
+    result      : ndarray
+                  Inverse NUFFT of the input field.
+    """
+    try:
+        import finufft
+    except:
+        print('odak.tools.nufft2 requires finufft to be installed: pip install finufft')
+    image = np.copy(field).astype(np.complex128)
+    result = finufft.nufft2d2(
+        fx.flatten(), fy.flatten(), image, eps=eps, isign=sign)
+    if type(size) == type(None):
+        result = result.reshape(field.shape)
+    else:
+        result = result.reshape(size)
+    return result
+
+
+
+ +
+ +
+ + +
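A sketch mirroring how the propagators in odak/wave/classical.py call this definition: the sample locations are scaled into [-π, π] before the transform, the forward transform uses a negative sign and the adjoint (nuifft2, documented next) the opposite sign. It requires finufft to be installed, and the values are illustrative.

import numpy as np
from odak.tools import nufft2, nuifft2

field = np.random.rand(64, 64)
x = np.linspace(-np.pi, np.pi, 64)               # uniform locations scaled into [-pi, pi]
X, Y = np.meshgrid(x, x)
spectrum = nufft2(field, X, Y, sign = -1, eps = 10 ** (-12))
back = nuifft2(spectrum, X, Y, size = [64, 64], sign = 1, eps = 10 ** (-12))
# nuifft2 is the adjoint rather than an exact inverse, so back matches field only up to a constant scaling.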

+ nuifft2(field, fx, fy, size=None, sign=1, eps=10 ** -12) + +

+ + +
+ +

A definition to take 2D Adjoint Non-Uniform Fast Fourier Transform (NUFFT).

+ + +

Parameters:

+
    +
  • + field + – +
    +
          Input field.
    +
    +
    +
  • +
  • + fx + – +
    +
          Frequencies along x axis.
    +
    +
    +
  • +
  • + fy + – +
    +
          Frequencies along y axis.
    +
    +
    +
  • +
  • + size + – +
    +
          Shape of the NUFFT calculated for an input field.
    +
    +
    +
  • +
  • + sign + – +
    +
          Sign of the exponential used in NUFFT kernel.
    +
    +
    +
  • +
  • + eps + – +
    +
          Accuracy of NUFFT.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( ndarray +) – +
    +

    NUFFT of the input field.

    +
    +
  • +
+ +
+ Source code in odak/tools/matrix.py +
def nuifft2(field, fx, fy, size=None, sign=1, eps=10**(-12)):
+    """
+    A definition to take 2D Adjoint Non-Uniform Fast Fourier Transform (NUFFT).
+
+    Parameters
+    ----------
+    field       : ndarray
+                  Input field.
+    fx          : ndarray
+                  Frequencies along x axis.
+    fy          : ndarray
+                  Frequencies along y axis.
+    size        : list or ndarray
+                  Shape of the NUFFT calculated for an input field.
+    sign        : float
+                  Sign of the exponential used in NUFFT kernel.
+    eps         : float
+                  Accuracy of NUFFT.
+
+    Returns
+    ----------
+    result      : ndarray
+                  NUFFT of the input field.
+    """
+    try:
+        import finufft
+    except:
+        print('odak.tools.nuifft2 requires finufft to be installed: pip install finufft')
+    image = np.copy(field).astype(np.complex128)
+    if type(size) == type(None):
+        result = finufft.nufft2d1(
+            fx.flatten(),
+            fy.flatten(),
+            image.flatten(),
+            image.shape,
+            eps=eps,
+            isign=sign
+        )
+    else:
+        result = finufft.nufft2d1(
+            fx.flatten(),
+            fy.flatten(),
+            image.flatten(),
+            (size[0], size[1]),
+            eps=eps,
+            isign=sign
+        )
+    result = np.asarray(result)
+    return result
+
+
+
+ +
+ +
+ + +

+ prism_phase_function(nx, ny, k, angle, dx=0.001, axis='x') + +

+ + +
+ +

A definition to generate 2D phase function that represents a prism. See Goodman's Introduction to Fourier Optics book for more.

+ + +

Parameters:

+
    +
  • + nx + – +
    +
         Size of the output along X.
    +
    +
    +
  • +
  • + ny + – +
    +
         Size of the output along Y.
    +
    +
    +
  • +
  • + k + – +
    +
         See odak.wave.wavenumber for more.
    +
    +
    +
  • +
  • + angle + – +
    +
         Tilt angle of the prism in degrees.
    +
    +
    +
  • +
  • + dx + – +
    +
         Pixel pitch.
    +
    +
    +
  • +
  • + axis + – +
    +
         Axis of the prism.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +prism ( ndarray +) – +
    +

    Generated phase function for a prism.

    +
    +
  • +
+ +
+ Source code in odak/wave/lens.py +
def prism_phase_function(nx, ny, k, angle, dx=0.001, axis='x'):
+    """
+    A definition to generate 2D phase function that represents a prism. See Goodman's Introduction to Fourier Optics book for more.
+
+    Parameters
+    ----------
+    nx         : int
+                 Size of the output along X.
+    ny         : int
+                 Size of the output along Y.
+    k          : odak.wave.wavenumber
+                 See odak.wave.wavenumber for more.
+    angle      : float
+                 Tilt angle of the prism in degrees.
+    dx         : float
+                 Pixel pitch.
+    axis       : str
+                 Axis of the prism.
+
+    Returns
+    -------
+    prism      : ndarray
+                 Generated phase function for a prism.
+    """
+    angle = np.radians(angle)
+    size = [ny, nx]
+    x = np.linspace(-size[0]*dx/2, size[0]*dx/2, size[0])
+    y = np.linspace(-size[1]*dx/2, size[1]*dx/2, size[1])
+    X, Y = np.meshgrid(x, y)
+    if axis == 'y':
+        prism = np.exp(-1j*k*np.sin(angle)*Y)
+    elif axis == 'x':
+        prism = np.exp(-1j*k*np.sin(angle)*X)
+    return prism
+
+
+
+ +
+ +
+ + +
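A sketch of tilting a plane wave with the prism term above; the wavelength, tilt angle and grid size are illustrative.

import numpy as np
from odak.wave import wavenumber, prism_phase_function, generate_complex_field

k = wavenumber(532e-9)
prism = prism_phase_function(512, 512, k, angle = 0.1, dx = 8e-6, axis = 'x')
plane_wave = generate_complex_field(np.ones((512, 512)), 0.)
tilted_wave = plane_wave * prism                 # multiplying by the prism term tilts the wavefront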

+ produce_phase_only_slm_pattern(hologram, slm_range, filename=None, bits=8, default_range=6.28, illumination=None) + +

+ + +
+ +

Definition for producing a pattern for a phase only Spatial Light Modulator (SLM) using a given field.

+ + +

Parameters:

+
    +
  • + hologram + – +
    +
                 Input holographic field.
    +
    +
    +
  • +
  • + slm_range + – +
    +
                 Range of the phase only SLM in radians for a working wavelength (i.e. two pi). See odak.wave.adjust_phase_only_slm_range() for more.
    +
    +
    +
  • +
  • + filename + – +
    +
                 Optional variable; if provided, the pattern will be saved to the given location.
    +
    +
    +
  • +
  • + bits + – +
    +
                 Quantization bits.
    +
    +
    +
  • +
  • + default_range + – +
    +
                 Default range of phase only SLM.
    +
    +
    +
  • +
  • + illumination + – +
    +
                 Spatial illumination distribution.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +pattern ( complex64 +) – +
    +

    Adjusted phase only pattern.

    +
    +
  • +
  • +hologram_digital ( int +) – +
    +

    Digital representation of the hologram.

    +
    +
  • +
+ +
+ Source code in odak/wave/__init__.py +
def produce_phase_only_slm_pattern(hologram, slm_range, filename=None, bits=8, default_range=6.28, illumination=None):
+    """
+    Definition for producing a pattern for a phase only Spatial Light Modulator (SLM) using a given field.
+
+    Parameters
+    ----------
+    hologram           : np.complex64
+                         Input holographic field.
+    slm_range          : float
+                         Range of the phase only SLM in radians for a working wavelength (i.e. two pi). See odak.wave.adjust_phase_only_slm_range() for more.
+    filename           : str
+                         Optional variable, if provided the patterns will be save to given location.
+    bits               : int
+                         Quantization bits.
+    default_range      : float 
+                         Default range of phase only SLM.
+    illumination       : np.ndarray
+                         Spatial illumination distribution.
+
+    Returns
+    -------
+    pattern            : np.complex64
+                         Adjusted phase only pattern.
+    hologram_digital   : np.int
+                         Digital representation of the hologram.
+    """
+    #hologram_phase   = calculate_phase(hologram) % default_range
+    hologram_phase = calculate_phase(hologram)
+    hologram_phase = hologram_phase % slm_range
+    hologram_phase /= slm_range
+    hologram_phase *= 2**bits
+    hologram_phase = hologram_phase.astype(np.int32)
+    hologram_digital = np.copy(hologram_phase)
+    if type(filename) != type(None):
+        save_image(
+            filename,
+            hologram_phase,
+            cmin=0,
+            cmax=2**bits
+        )
+    hologram_phase = hologram_phase.astype(np.float64)
+    hologram_phase *= slm_range/2**bits
+    if type(illumination) == type(None):
+        A = 1.
+    else:
+        A = illumination
+    return A*np.cos(hologram_phase)+A*1j*np.sin(hologram_phase), hologram_digital
+
+
+
+ +
+ +
+ + +
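A sketch of quantizing a hologram for an 8-bit phase-only SLM; the random-phase hologram stands in for one computed with the Gerchberg-Saxton routines above, and the values are illustrative.

import numpy as np
from odak.wave import generate_complex_field, produce_phase_only_slm_pattern

hologram = generate_complex_field(np.ones((512, 512)), 2 * np.pi * np.random.rand(512, 512))
pattern, digital = produce_phase_only_slm_pattern(hologram, slm_range = 2 * np.pi, bits = 8)
# digital holds integer levels in [0, 2**bits); pattern is the corresponding quantized complex field.
# Pass filename = 'some_path.png' to also save the digital pattern as an image.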

+ propagate_beam(field, k, distance, dx, wavelength, propagation_type='IR Fresnel') + +

+ + +
+ +

Definitions for Fresnel Impulse Response (IR), Angular Spectrum (AS), Bandlimited Angular Spectrum (BAS), Fresnel Transfer Function (TF) and Fraunhofer diffraction in accordance with "Computational Fourier Optics" by David Voelz. For more on the Bandlimited Fresnel impulse response, also known as the Bandlimited Angular Spectrum method, see "Band-limited Angular Spectrum Method for Numerical Simulation of Free-Space Propagation in Far and Near Fields".

+ + +

Parameters:

+
    +
  • + field + – +
    +
               Complex field (MxN).
    +
    +
    +
  • +
  • + k + – +
    +
               Wave number of a wave, see odak.wave.wavenumber for more.
    +
    +
    +
  • +
  • + distance + – +
    +
               Propagation distance.
    +
    +
    +
  • +
  • + dx + – +
    +
               Size of one single pixel in the field grid (in meters).
    +
    +
    +
  • +
  • + wavelength + – +
    +
               Wavelength of the electric field.
    +
    +
    +
  • +
  • + propagation_type + (str, default: + 'IR Fresnel' +) + – +
    +
               Type of the propagation (IR Fresnel, Angular Spectrum, Bandlimited Angular Spectrum, TR Fresnel, Fraunhofer).
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +result ( complex +) – +
    +

    Final complex field (MxN).

    +
    +
  • +
+ +
+ Source code in odak/wave/classical.py +
def propagate_beam(field, k, distance, dx, wavelength, propagation_type='IR Fresnel'):
+    """
+    Definitions for Fresnel Impulse Response (IR), Angular Spectrum (AS), Bandlimited Angular Spectrum (BAS), Fresnel Transfer Function (TF), Fraunhofer diffraction in accordence with "Computational Fourier Optics" by David Vuelz. For more on Bandlimited Fresnel impulse response also known as Bandlimited Angular Spectrum method see "Band-limited Angular Spectrum Method for Numerical Simulation of Free-Space Propagation in Far and Near Fields".
+
+    Parameters
+    ----------
+    field            : np.complex
+                       Complex field (MxN).
+    k                : odak.wave.wavenumber
+                       Wave number of a wave, see odak.wave.wavenumber for more.
+    distance         : float
+                       Propagation distance.
+    dx               : float
+                       Size of one single pixel in the field grid (in meters).
+    wavelength       : float
+                       Wavelength of the electric field.
+    propagation_type : str
+                       Type of the propagation (IR Fresnel, Angular Spectrum, Bandlimited Angular Spectrum, TR Fresnel, Fraunhofer).
+
+    Returns
+    -------
+    result           : np.complex
+                       Final complex field (MxN).
+    """
+    if propagation_type == 'Rayleigh-Sommerfeld':
+        result = rayleigh_sommerfeld(field, k, distance, dx, wavelength)
+    elif propagation_type == 'Angular Spectrum':
+        result = angular_spectrum(field, k, distance, dx, wavelength)
+    elif propagation_type == 'Impulse Response Fresnel':
+        result = impulse_response_fresnel(field, k, distance, dx, wavelength)
+    elif propagation_type == 'Bandlimited Angular Spectrum':
+        result = band_limited_angular_spectrum(
+            field, k, distance, dx, wavelength)
+    elif propagation_type == 'Bandextended Angular Spectrum':
+        result = band_extended_angular_spectrum(
+            field, k, distance, dx, wavelength)
+    elif propagation_type == 'Adaptive Sampling Angular Spectrum':
+        result = adaptive_sampling_angular_spectrum(
+            field, k, distance, dx, wavelength)
+    elif propagation_type == 'Transfer Function Fresnel':
+        result = transfer_function_fresnel(field, k, distance, dx, wavelength)
+    elif propagation_type == 'Fraunhofer':
+        result = fraunhofer(field, k, distance, dx, wavelength)
+    elif propagation_type == 'Fraunhofer Inverse':
+        result = fraunhofer_inverse(field, k, distance, dx, wavelength)
+    else:
+        raise Exception("Unknown propagation type selected.")
+    return result
+
+
+
+ +
+ +
+ + +
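The dispatcher above makes it easy to compare propagation models by swapping a single string; only the strings listed in the branches of the source (e.g. 'Angular Spectrum', 'Bandlimited Angular Spectrum', 'Transfer Function Fresnel', 'Fraunhofer') are accepted, anything else raises an exception. A sketch with illustrative settings follows.

import numpy as np
from odak.wave import wavenumber, propagate_beam

wavelength = 532e-9
dx = 8e-6
distance = 0.15
k = wavenumber(wavelength)
field = np.zeros((512, 512), dtype = np.complex64)
field[240:272, 240:272] = 1.
results = {}
for method in ['Angular Spectrum', 'Bandlimited Angular Spectrum', 'Transfer Function Fresnel']:
    results[method] = propagate_beam(field, k, distance, dx, wavelength, method)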

+ propagate_field(points0, points1, field0, wave_number, direction=1) + +

+ + +
+ +

Definition to propagate a field from one set of points to another set of points in space: propagates a given array of spherical sources to a given set of points in space.

+ + +

Parameters:

+
    +
  • + points0 + – +
    +
            Start points (i.e. odak.tools.grid_sample).
    +
    +
    +
  • +
  • + points1 + – +
    +
             End points (i.e. odak.tools.grid_sample).
    +
    +
    +
  • +
  • + field0 + – +
    +
            Field for given starting points.
    +
    +
    +
  • +
  • + wave_number + – +
    +
            Wave number of a wave, see odak.wave.wavenumber for more.
    +
    +
    +
  • +
  • + direction + – +
    +
            For propagating in forward direction set as 1, otherwise -1.
    +
    +
    +
  • +
+ + +

Returns:

+
    +
  • +field1 ( ndarray +) – +
    +

    Field for given end points.

    +
    +
  • +
+ +
+ Source code in odak/wave/vector.py +
def propagate_field(points0, points1, field0, wave_number, direction=1):
+    """
+    Definition to propagate a field from points to an another points in space: propagate a given array of spherical sources to given set of points in space.
+
+    Parameters
+    ----------
+    points0       : ndarray
+                    Start points (i.e. odak.tools.grid_sample).
+    points1       : ndarray
+                    End points (ie. odak.tools.grid_sample).
+    field0        : ndarray
+                    Field for given starting points.
+    wave_number   : float
+                    Wave number of a wave, see odak.wave.wavenumber for more.
+    direction     : float
+                    For propagating in forward direction set as 1, otherwise -1.
+
+    Returns
+    -------
+    field1        : ndarray
+                    Field for given end points.
+    """
+    field1 = np.zeros(points1.shape[0], dtype=np.complex64)
+    for point_id in range(points0.shape[0]):
+        point = points0[point_id]
+        distances = distance_between_two_points(
+            point,
+            points1
+        )
+        field1 += electric_field_per_plane_wave(
+            calculate_amplitude(field0[point_id]),
+            distances*direction,
+            wave_number,
+            phase=calculate_phase(field0[point_id])
+        )
+    return field1
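
A minimal usage sketch, assuming propagate_field, wavenumber and grid_sample are exposed as odak.wave.propagate_field, odak.wave.wavenumber and odak.tools.grid_sample (as the docstring references suggest); the grid sizes, spacing and wavelength are arbitrary illustration values:

import numpy as np
from odak.wave import wavenumber, propagate_field
from odak.tools import grid_sample

wavelength = 0.5e-3                                                      # 500 nm expressed in mm, matching the docstrings
k = wavenumber(wavelength)
points0 = grid_sample(no=[10, 10], size=[5., 5.], center=[0., 0., 0.])   # assumed grid_sample arguments
points1 = grid_sample(no=[10, 10], size=[5., 5.], center=[0., 0., 10.])  # a parallel plane 10 mm away
field0 = np.ones(points0.shape[0], dtype=np.complex64)                   # unit-amplitude spherical sources
field1 = propagate_field(points0, points1, field0, k, direction=1)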


propagate_plane_waves(field, opd, k, w=0, t=0)

Definition to propagate a field representing a plane wave at a particular distance and time.

Parameters:

  • field – Complex field.
  • opd – Optical path difference in mm.
  • k – Wave number of a wave, see odak.wave.parameters.wavenumber for more.
  • w – Rotation speed of a wave, see odak.wave.parameters.rotationspeed for more.
  • t – Time in seconds.

Returns:

  • new_field (complex) – A complex number that provides the resultant field in the complex form A*e^(j(wt+phi)).

Source code in odak/wave/vector.py
def propagate_plane_waves(field, opd, k, w=0, t=0):
+    """
+    Definition to propagate a field representing a plane wave at a particular distance and time.
+
+    Parameters
+    ----------
+    field        : complex
+                   Complex field.
+    opd          : float
+                   Optical path difference in mm.
+    k            : float
+                   Wave number of a wave, see odak.wave.parameters.wavenumber for more.
+    w            : float
+                   Rotation speed of a wave, see odak.wave.parameters.rotationspeed for more.
+    t            : float
+                   Time in seconds.
+
+    Returns
+    -------
+    new_field     : complex
+                    A complex number that provides the resultant field in the complex form A*e^(j(wt+phi)).
+    """
+    new_field = field*np.exp(1j*(-w*t+opd*k))/opd**2
+    return new_field
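
A small sketch combining this with wavenumber and rotationspeed from the same module (imports assumed to be available under odak.wave, as the docstring cross-references suggest):

import numpy as np
from odak.wave import wavenumber, rotationspeed, propagate_plane_waves

wavelength = 0.5e-3                          # mm
k = wavenumber(wavelength)
w = rotationspeed(wavelength)
field = 1. + 0j                              # a single unit-amplitude plane wave sample
new_field = propagate_plane_waves(field, opd=10., k=k, w=w, t=0.)
print(abs(new_field))                        # amplitude falls off as 1/opd**2 in this implementation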


quadratic_phase_function(nx, ny, k, focal=0.4, dx=0.001, offset=[0, 0])

A definition to generate a 2D quadratic phase function, which is typically used to represent lenses.

Parameters:

  • nx – Size of the output along X.
  • ny – Size of the output along Y.
  • k – See odak.wave.wavenumber for more.
  • focal – Focal length of the quadratic phase function.
  • dx – Pixel pitch.
  • offset – Deviation from the center along X and Y axes.

Returns:

  • function (ndarray) – Generated quadratic phase function.

Source code in odak/wave/lens.py
def quadratic_phase_function(nx, ny, k, focal=0.4, dx=0.001, offset=[0, 0]):
+    """ 
+    A definition to generate 2D quadratic phase function, which is typically use to represent lenses.
+
+    Parameters
+    ----------
+    nx         : int
+                 Size of the output along X.
+    ny         : int
+                 Size of the output along Y.
+    k          : odak.wave.wavenumber
+                 See odak.wave.wavenumber for more.
+    focal      : float
+                 Focal length of the quadratic phase function.
+    dx         : float
+                 Pixel pitch.
+    offset     : list
+                 Deviation from the center along X and Y axes.
+
+    Returns
+    -------
+    function   : ndarray
+                 Generated quadratic phase function.
+    """
+    size = [nx, ny]
+    x = np.linspace(-size[0]*dx/2, size[0]*dx/2, size[0])-offset[1]*dx
+    y = np.linspace(-size[1]*dx/2, size[1]*dx/2, size[1])-offset[0]*dx
+    X, Y = np.meshgrid(x, y)
+    Z = X**2+Y**2
+    qwf = np.exp(1j*k*0.5*np.sin(Z/focal))
+    return qwf
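
A usage sketch that builds a lens-like phase pattern and applies it to a plane wave; the import path and sampling values are assumptions, and note that the textbook thin-lens term is exp(-j*k*(x^2+y^2)/(2*focal)), whereas the listing above uses a sine-based variant:

import numpy as np
from odak.wave import wavenumber, quadratic_phase_function

wavelength = 0.5e-3                                                        # mm
k = wavenumber(wavelength)
lens_phase = quadratic_phase_function(500, 500, k, focal=400., dx=0.001)   # focal length in mm, assumed to match dx units
plane_wave = np.ones((500, 500), dtype=np.complex64)
field_after_lens = plane_wave * lens_phase                                 # element-wise modulation by the lens term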


rayleigh_resolution(diameter, focal=None, wavelength=0.0005)

Definition to calculate the Rayleigh resolution limit of a lens with a certain focal length and aperture. The lens is assumed to focus a plane wave at its focal distance.

Parameters:

  • diameter – Diameter of a lens.
  • focal – Focal length of a lens. When a focal length is provided, the spatial resolution at the focal plane is returned; when it isn't, the angular resolution is returned.
  • wavelength – Wavelength of light.

Returns:

  • resolution (float) – Resolvable angular or spatial spot size; see focal in the parameters to know which to expect.

Source code in odak/wave/__init__.py
def rayleigh_resolution(diameter, focal=None, wavelength=0.0005):
+    """
+    Definition to calculate rayleigh resolution limit of a lens with a certain focal length and an aperture. Lens is assumed to be focusing a plane wave at a focal distance.
+
+    Parameter
+    ---------
+    diameter    : float
+                  Diameter of a lens.
+    focal       : float
+                  Focal length of a lens, when focal length is provided, spatial resolution is provided at the focal plane. When focal length isn't provided angular resolution is provided.
+    wavelength  : float
+                  Wavelength of light.
+
+    Returns
+    --------
+    resolution  : float
+                  Resolvable angular or spatial spot size, see focal in parameters to know what to expect.
+
+    """
+    resolution = 1.22*wavelength/diameter
+    if type(focal) != type(None):
+        resolution *= focal
+    return resolution
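
A quick worked example of the formula in the listing, resolution = 1.22 * wavelength / diameter, scaled by the focal length when one is given (numbers are arbitrary, units in mm as documented):

from odak.wave import rayleigh_resolution

wavelength = 0.0005                                                         # 500 nm in mm
diameter = 25.                                                              # lens aperture in mm
angular = rayleigh_resolution(diameter, wavelength=wavelength)              # 1.22*0.0005/25 = 2.44e-05 rad
spatial = rayleigh_resolution(diameter, focal=50., wavelength=wavelength)   # 2.44e-05*50 = 1.22e-03 mm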


rayleigh_sommerfeld(field, k, distance, dx, wavelength)

Definition to compute beam propagation using Rayleigh-Sommerfeld's diffraction formula (Huygens-Fresnel principle). For more see Section 3.5.2 in Goodman, Joseph W. Introduction to Fourier Optics. Roberts and Company Publishers, 2005.

Parameters:

  • field – Complex field (MxN).
  • k – Wave number of a wave, see odak.wave.wavenumber for more.
  • distance – Propagation distance.
  • dx – Size of one single pixel in the field grid (in meters).
  • wavelength – Wavelength of the electric field.

Returns:

  • result (complex) – Final complex field (MxN).

Source code in odak/wave/classical.py
def rayleigh_sommerfeld(field, k, distance, dx, wavelength):
+    """
+    Definition to compute beam propagation using Rayleigh-Sommerfeld's diffraction formula (Huygens-Fresnel Principle). For more see Section 3.5.2 in Goodman, Joseph W. Introduction to Fourier optics. Roberts and Company Publishers, 2005.
+
+    Parameters
+    ----------
+    field            : np.complex
+                       Complex field (MxN).
+    k                : odak.wave.wavenumber
+                       Wave number of a wave, see odak.wave.wavenumber for more.
+    distance         : float
+                       Propagation distance.
+    dx               : float
+                       Size of one single pixel in the field grid (in meters).
+    wavelength       : float
+                       Wavelength of the electric field.
+
+    Returns
+    -------
+    result           : np.complex
+                       Final complex field (MxN).
+    """
+    nv, nu = field.shape
+    x = np.linspace(-nv * dx / 2, nv * dx / 2, nv)
+    y = np.linspace(-nu * dx / 2, nu * dx / 2, nu)
+    X, Y = np.meshgrid(x, y)
+    Z = X ** 2 + Y ** 2
+    result = np.zeros(field.shape, dtype=np.complex64)
+    direction = int(distance/np.abs(distance))
+    for i in range(nu):
+        for j in range(nv):
+            if field[i, j] != 0:
+                r01 = np.sqrt(distance ** 2 + (X - X[i, j]) ** 2 + (Y - Y[i, j]) ** 2) * direction
+                cosnr01 = np.cos(distance / r01)
+                result += field[i, j] * np.exp(1j * k * r01) / r01 * cosnr01
+    result *= 1. / (1j * wavelength)
+    return result
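
Because this implementation loops over every non-zero input pixel, its cost grows quickly with field size; a sketch of how it would typically be called on a small aperture (imports and sampling values are assumptions):

import numpy as np
from odak.wave import wavenumber, rayleigh_sommerfeld

wavelength = 0.5e-6                          # keep every length in meters so k, dx and distance stay consistent
dx = 8e-6
k = wavenumber(wavelength)
aperture = np.zeros((64, 64), dtype=np.complex64)
aperture[24:40, 24:40] = 1.                  # a small square aperture keeps the double loop affordable
result = rayleigh_sommerfeld(aperture, k, 0.1, dx, wavelength)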


rotationspeed(wavelength, c=3 * 10 ** 11)

Definition for calculating the rotation speed of a wave (w in A*e^(j(wt+phi))).

Parameters:

  • wavelength – Wavelength of a wave in mm.
  • c – Speed of the wave in mm/s. Default is the speed of light in vacuum.

Returns:

  • w (float) – Rotation speed.

Source code in odak/wave/__init__.py
def rotationspeed(wavelength, c=3*10**11):
+    """
+    Definition for calculating rotation speed of a wave (w in A*e^(j(wt+phi))).
+
+    Parameters
+    ----------
+    wavelength   : float
+                   Wavelength of a wave in mm.
+    c            : float
+                   Speed of wave in mm/seconds. Default is the speed of light in the void!
+
+    Returns
+    -------
+    w            : float
+                   Rotation speed.
+
+    """
+    f = c/wavelength  # frequency of the wave: speed divided by wavelength
+    w = 2*np.pi*f
+    return w


set_amplitude(field, amplitude)

Definition to keep phase as is and change the amplitude of a given field.

Parameters:

  • field – Complex field.
  • amplitude – Amplitudes.

Returns:

  • new_field (complex64) – Complex field.

Source code in odak/wave/__init__.py
def set_amplitude(field, amplitude):
+    """
+    Definition to keep phase as is and change the amplitude of a given field.
+
+    Parameters
+    ----------
+    field        : np.complex64
+                   Complex field.
+    amplitude    : np.array or np.complex64
+                   Amplitudes.
+
+    Returns
+    -------
+    new_field    : np.complex64
+                   Complex field.
+    """
+    amplitude = calculate_amplitude(amplitude)
+    phase = calculate_phase(field)
+    new_field = amplitude*np.cos(phase)+1j*amplitude*np.sin(phase)
+    return new_field
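
A short sketch of the effect: the phase of the input field is preserved while its amplitude is replaced (set_amplitude is defined in odak/wave/__init__.py, so it is assumed to be importable from odak.wave):

import numpy as np
from odak.wave import set_amplitude

field = 2. * np.exp(1j * 0.3) * np.ones((4, 4))              # amplitude 2, phase 0.3 rad everywhere
new_field = set_amplitude(field, 5. * np.ones((4, 4)))
print(np.abs(new_field[0, 0]), np.angle(new_field[0, 0]))    # -> 5.0 and ~0.3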


transfer_function_fresnel(field, k, distance, dx, wavelength)

A definition to calculate convolution based Fresnel approximation for beam propagation.

Parameters:

  • field – Complex field (MxN).
  • k – Wave number of a wave, see odak.wave.wavenumber for more.
  • distance – Propagation distance.
  • dx – Size of one single pixel in the field grid (in meters).
  • wavelength – Wavelength of the electric field.

Returns:

  • result (complex) – Final complex field (MxN).

Source code in odak/wave/classical.py
def transfer_function_fresnel(field, k, distance, dx, wavelength):
+    """
+    A definition to calculate convolution based Fresnel approximation for beam propagation.
+
+    Parameters
+    ----------
+    field            : np.complex
+                       Complex field (MxN).
+    k                : odak.wave.wavenumber
+                       Wave number of a wave, see odak.wave.wavenumber for more.
+    distance         : float
+                       Propagation distance.
+    dx               : float
+                       Size of one single pixel in the field grid (in meters).
+    wavelength       : float
+                       Wavelength of the electric field.
+
+    Returns
+    -------
+    result           : np.complex
+                       Final complex field (MxN).
+
+    """
+    nv, nu = field.shape
+    fx = np.linspace(-1. / 2. /dx, 1. /2. /dx, nu)
+    fy = np.linspace(-1. / 2. /dx, 1. /2. /dx, nv)
+    FX, FY = np.meshgrid(fx, fy)
+    H = np.exp(1j * k * distance * (1 - (FX * wavelength) ** 2 - (FY * wavelength) ** 2) ** 0.5)
+    U1 = np.fft.fftshift(np.fft.fft2(np.fft.fftshift(field)))
+    U2 = H * U1
+    result = np.fft.ifftshift(np.fft.ifft2(np.fft.ifftshift(U2)))
+    return result


wavenumber(wavelength)

Definition for calculating the wavenumber of a plane wave.

Parameters:

  • wavelength – Wavelength of a wave in mm.

Returns:

  • k (float) – Wave number for a given wavelength.

Source code in odak/wave/__init__.py
def wavenumber(wavelength):
+    """
+    Definition for calculating the wavenumber of a plane wave.
+
+    Parameters
+    ----------
+    wavelength   : float
+                   Wavelength of a wave in mm.
+
+    Returns
+    -------
+    k            : float
+                   Wave number for a given wavelength.
+    """
+    k = 2*np.pi/wavelength
+    return k
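
A one-line check of k = 2*pi/wavelength with the millimetre units used throughout this module:

from odak.wave import wavenumber

k = wavenumber(0.0005)       # 500 nm expressed in mm
print(k)                     # 2*pi/0.0005, roughly 12566.37 rad/mm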


zero_pad(field, size=None, method='center')

Definition to zero pad an MxN array to a 2Mx2N array.

Parameters:

  • field – Input field, an MxN array.
  • size – Size to be zero padded.
  • method – Zero pad either by placing the content at the center or to the left.

Returns:

  • field_zero_padded (ndarray) – Zero padded version of the input field.

Source code in odak/tools/matrix.py
def zero_pad(field, size=None, method='center'):
+    """
+    Definition to zero pad a MxN array to 2Mx2N array.
+
+    Parameters
+    ----------
+    field             : ndarray
+                        Input field MxN array.
+    size              : list
+                        Size to be zeropadded.
+    method            : str
+                        Zeropad either by placing the content to center or to the left.
+
+    Returns
+    ----------
+    field_zero_padded : ndarray
+                        Zeropadded version of the input field.
+    """
+    if type(size) == type(None):
+        hx = int(np.ceil(field.shape[0])/2)
+        hy = int(np.ceil(field.shape[1])/2)
+    else:
+        hx = int(np.ceil((size[0]-field.shape[0])/2))
+        hy = int(np.ceil((size[1]-field.shape[1])/2))
+    if method == 'center':
+        field_zero_padded = np.pad(
+            field, ([hx, hx], [hy, hy]), constant_values=(0, 0))
+    elif method == 'left aligned':
+        field_zero_padded = np.pad(
+            field, ([0, 2*hx], [0, 2*hy]), constant_values=(0, 0))
+    if type(size) != type(None):
+        field_zero_padded = field_zero_padded[0:size[0], 0:size[1]]
+    return field_zero_padded
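
A short sketch of the default behaviour: a 100x100 field is padded to 200x200 with the original content centered (zero_pad lives in odak/tools/matrix.py, so it is assumed to be importable from odak.tools):

import numpy as np
from odak.tools import zero_pad

field = np.ones((100, 100), dtype=np.complex64)
padded = zero_pad(field)                    # -> shape (200, 200), ones centered, zeros around
cropped = padded[50:150, 50:150]            # recover the original region
print(padded.shape, np.allclose(cropped, field))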


adaptive_sampling_angular_spectrum(field, k, distance, dx, wavelength)

A definition to calculate adaptive sampling angular spectrum based beam propagation. For more, see Zhang, Wenhui, Hao Zhang, and Guofan Jin. "Adaptive-sampling angular spectrum method with full utilization of space-bandwidth product." Optics Letters 45.16 (2020): 4416-4419.

Parameters:

  • field – Complex field (MxN).
  • k – Wave number of a wave, see odak.wave.wavenumber for more.
  • distance – Propagation distance.
  • dx – Size of one single pixel in the field grid (in meters).
  • wavelength – Wavelength of the electric field.

Returns:

  • result (complex) – Final complex field (MxN).

Source code in odak/wave/classical.py
def adaptive_sampling_angular_spectrum(field, k, distance, dx, wavelength):
+    """
+    A definition to calculate adaptive sampling angular spectrum based beam propagation. For more Zhang, Wenhui, Hao Zhang, and Guofan Jin. "Adaptive-sampling angular spectrum method with full utilization of space-bandwidth product." Optics Letters 45.16 (2020): 4416-4419.
+
+    Parameters
+    ----------
+    field            : np.complex
+                       Complex field (MxN).
+    k                : odak.wave.wavenumber
+                       Wave number of a wave, see odak.wave.wavenumber for more.
+    distance         : float
+                       Propagation distance.
+    dx               : float
+                       Size of one single pixel in the field grid (in meters).
+    wavelength       : float
+                       Wavelength of the electric field.
+
+    Returns
+    -------
+    result           : np.complex
+                       Final complex field (MxN).
+    """
+    iflag = -1
+    eps = 10**(-12)
+    nv, nu = field.shape
+    l = nu*dx
+    x = np.linspace(-l/2, l/2, nu)
+    y = np.linspace(-l/2, l/2, nv)
+    X, Y = np.meshgrid(x, y)
+    fx = np.linspace(-1./2./dx, 1./2./dx, nu)
+    fy = np.linspace(-1./2./dx, 1./2./dx, nv)
+    FX, FY = np.meshgrid(fx, fy)
+    forig = 1./2./dx
+    fc2 = 1./2*(nu/wavelength/np.abs(distance))**0.5
+    ss = np.abs(fc2)/forig
+    zc = nu*dx**2/wavelength
+    K = nu/2/np.amax(np.abs(fx))
+    m = 2
+    nnu2 = m*nu
+    nnv2 = m*nv
+    fxn = np.linspace(-1./2./dx, 1./2./dx, nnu2)
+    fyn = np.linspace(-1./2./dx, 1./2./dx, nnv2)
+    if np.abs(distance) > zc*2:
+        fxn = fxn*ss
+        fyn = fyn*ss
+    FXN, FYN = np.meshgrid(fxn, fyn)
+    Hn = np.exp(1j*k*distance*(1-(FXN*wavelength)**2-(FYN*wavelength)**2)**0.5)
+    FX = FXN/np.amax(FXN)*np.pi
+    FY = FYN/np.amax(FYN)*np.pi
+    t_2 = nufft2(field, FX*ss, FY*ss, size=[nnv2, nnu2], sign=iflag, eps=eps)
+    FX = FX/np.amax(FX)*np.pi
+    FY = FY/np.amax(FY)*np.pi
+    result = nuifft2(Hn*t_2, FX*ss, FY*ss, size=[nv, nu], sign=-iflag, eps=eps)
+    return result


angular_spectrum(field, k, distance, dx, wavelength)

A definition to calculate angular spectrum based beam propagation.

Parameters:

  • field – Complex field (MxN).
  • k – Wave number of a wave, see odak.wave.wavenumber for more.
  • distance – Propagation distance.
  • dx – Size of one single pixel in the field grid (in meters).
  • wavelength – Wavelength of the electric field.

Returns:

  • result (complex) – Final complex field (MxN).

Source code in odak/wave/classical.py
def angular_spectrum(field, k, distance, dx, wavelength):
+    """
+    A definition to calculate angular spectrum based beam propagation.
+
+    Parameters
+    ----------
+    field            : np.complex
+                       Complex field (MxN).
+    k                : odak.wave.wavenumber
+                       Wave number of a wave, see odak.wave.wavenumber for more.
+    distance         : float
+                       Propagation distance.
+    dx               : float
+                       Size of one single pixel in the field grid (in meters).
+    wavelength       : float
+                       Wavelength of the electric field.
+
+    Returns
+    -------
+    result           : np.complex
+                       Final complex field (MxN).
+    """
+    nv, nu = field.shape
+    x = np.linspace(-nu/2*dx, nu/2*dx, nu)
+    y = np.linspace(-nv/2*dx, nv/2*dx, nv)
+    X, Y = np.meshgrid(x, y)
+    Z = X**2+Y**2
+    h = 1./(1j*wavelength*distance)*np.exp(1j*k*(distance+Z/2/distance))
+    h = np.fft.fft2(np.fft.fftshift(h))*dx**2
+    U1 = np.fft.fft2(np.fft.fftshift(field))
+    U2 = h*U1
+    result = np.fft.ifftshift(np.fft.ifft2(U2))
+    return result
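
A propagation sketch on a synthetic square aperture; the sampling values are arbitrary illustration choices and the import path is assumed to be odak.wave:

import numpy as np
from odak.wave import wavenumber, angular_spectrum

wavelength = 0.5e-6                          # keep every length in meters so k, dx and distance stay consistent
dx = 8e-6
distance = 5e-3
k = wavenumber(wavelength)
field = np.zeros((512, 512), dtype=np.complex64)
field[192:320, 192:320] = 1.                 # square aperture illuminated by a unit plane wave
result = angular_spectrum(field, k, distance, dx, wavelength)
intensity = np.abs(result) ** 2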


band_extended_angular_spectrum(field, k, distance, dx, wavelength)

A definition to calculate band-extended angular spectrum based beam propagation. For more, see Zhang, Wenhui, Hao Zhang, and Guofan Jin. "Band-extended angular spectrum method for accurate diffraction calculation in a wide propagation range." Optics Letters 45.6 (2020): 1543-1546.

Parameters:

  • field – Complex field (MxN).
  • k – Wave number of a wave, see odak.wave.wavenumber for more.
  • distance – Propagation distance.
  • dx – Size of one single pixel in the field grid (in meters).
  • wavelength – Wavelength of the electric field.

Returns:

  • result (complex) – Final complex field (MxN).

Source code in odak/wave/classical.py
def band_extended_angular_spectrum(field, k, distance, dx, wavelength):
+    """
+    A definition to calculate bandextended angular spectrum based beam propagation. For more Zhang, Wenhui, Hao Zhang, and Guofan Jin. "Band-extended angular spectrum method for accurate diffraction calculation in a wide propagation range." Optics Letters 45.6 (2020): 1543-1546.
+
+    Parameters
+    ----------
+    field            : np.complex
+                       Complex field (MxN).
+    k                : odak.wave.wavenumber
+                       Wave number of a wave, see odak.wave.wavenumber for more.
+    distance         : float
+                       Propagation distance.
+    dx               : float
+                       Size of one single pixel in the field grid (in meters).
+    wavelength       : float
+                       Wavelength of the electric field.
+
+    Returns
+    -------
+    result           : np.complex
+                       Final complex field (MxN).
+    """
+    iflag = -1
+    eps = 10**(-12)
+    nv, nu = field.shape
+    l = nu*dx
+    x = np.linspace(-l/2, l/2, nu)
+    y = np.linspace(-l/2, l/2, nv)
+    X, Y = np.meshgrid(x, y)
+    Z = X**2+Y**2
+    fx = np.linspace(-1./2./dx, 1./2./dx, nu)
+    fy = np.linspace(-1./2./dx, 1./2./dx, nv)
+    FX, FY = np.meshgrid(fx, fy)
+    K = nu/2/np.amax(fx)
+    fcn = 1./2*(nu/wavelength/np.abs(distance))**0.5
+    ss = np.abs(fcn)/np.amax(np.abs(fx))
+    zc = nu*dx**2/wavelength
+    if np.abs(distance) < zc:
+        fxn = fx
+        fyn = fy
+    else:
+        fxn = fx*ss
+        fyn = fy*ss
+    FXN, FYN = np.meshgrid(fxn, fyn)
+    Hn = np.exp(1j*k*distance*(1-(FXN*wavelength)**2-(FYN*wavelength)**2)**0.5)
+    X = X/np.amax(X)*np.pi
+    Y = Y/np.amax(Y)*np.pi
+    t_asmNUFT = nufft2(field, X*ss, Y*ss, sign=iflag, eps=eps)
+    result = nuifft2(Hn*t_asmNUFT, X*ss, Y*ss, sign=-iflag, eps=eps)
+    return result


band_limited_angular_spectrum(field, k, distance, dx, wavelength)

A definition to calculate band-limited angular spectrum based beam propagation. For more, see Matsushima, Kyoji, and Tomoyoshi Shimobaba. "Band-limited angular spectrum method for numerical simulation of free-space propagation in far and near fields." Optics Express 17.22 (2009): 19662-19673.

Parameters:

  • field – Complex field (MxN).
  • k – Wave number of a wave, see odak.wave.wavenumber for more.
  • distance – Propagation distance.
  • dx – Size of one single pixel in the field grid (in meters).
  • wavelength – Wavelength of the electric field.

Returns:

  • result (complex) – Final complex field (MxN).

Source code in odak/wave/classical.py
def band_limited_angular_spectrum(field, k, distance, dx, wavelength):
+    """
+    A definition to calculate bandlimited angular spectrum based beam propagation. For more Matsushima, Kyoji, and Tomoyoshi Shimobaba. "Band-limited angular spectrum method for numerical simulation of free-space propagation in far and near fields." Optics express 17.22 (2009): 19662-19673.
+
+    Parameters
+    ----------
+    field            : np.complex
+                       Complex field (MxN).
+    k                : odak.wave.wavenumber
+                       Wave number of a wave, see odak.wave.wavenumber for more.
+    distance         : float
+                       Propagation distance.
+    dx               : float
+                       Size of one single pixel in the field grid (in meters).
+    wavelength       : float
+                       Wavelength of the electric field.
+
+    Returns
+    -------
+    result           : np.complex
+                       Final complex field (MxN).
+    """
+    nv, nu = field.shape
+    x = np.linspace(-nu/2*dx, nu/2*dx, nu)
+    y = np.linspace(-nv/2*dx, nv/2*dx, nv)
+    X, Y = np.meshgrid(x, y)
+    Z = X**2+Y**2
+    h = 1./(1j*wavelength*distance)*np.exp(1j*k*(distance+Z/2/distance))
+    h = np.fft.fft2(np.fft.fftshift(h))*dx**2
+    flimx = np.ceil(1/(((2*distance*(1./(nu)))**2+1)**0.5*wavelength))
+    flimy = np.ceil(1/(((2*distance*(1./(nv)))**2+1)**0.5*wavelength))
+    mask = np.zeros((nu, nv), dtype=np.complex64)
+    mask = (np.abs(X) < flimx) & (np.abs(Y) < flimy)
+    mask = set_amplitude(h, mask)
+    U1 = np.fft.fft2(np.fft.fftshift(field))
+    U2 = mask*U1
+    result = np.fft.ifftshift(np.fft.ifft2(U2))
+    return result
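
The band-limited variant takes the same arguments, so it can be swapped in directly (or selected through propagate_beam, documented further below); a sketch with assumed sampling values:

import numpy as np
from odak.wave import wavenumber, band_limited_angular_spectrum

wavelength = 0.5e-6
dx = 8e-6
distance = 5e-3
k = wavenumber(wavelength)
field = np.zeros((512, 512), dtype=np.complex64)
field[192:320, 192:320] = 1.                 # the same square aperture as in the earlier sketch
result = band_limited_angular_spectrum(field, k, distance, dx, wavelength)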


fraunhofer(field, k, distance, dx, wavelength)

A definition to calculate Fraunhofer based beam propagation.

Parameters:

  • field – Complex field (MxN).
  • k – Wave number of a wave, see odak.wave.wavenumber for more.
  • distance – Propagation distance.
  • dx – Size of one single pixel in the field grid (in meters).
  • wavelength – Wavelength of the electric field.

Returns:

  • result (complex) – Final complex field (MxN).

Source code in odak/wave/classical.py
def fraunhofer(field, k, distance, dx, wavelength):
+    """
+    A definition to calculate Fraunhofer based beam propagation.
+
+    Parameters
+    ----------
+    field            : np.complex
+                       Complex field (MxN).
+    k                : odak.wave.wavenumber
+                       Wave number of a wave, see odak.wave.wavenumber for more.
+    distance         : float
+                       Propagation distance.
+    dx               : float
+                       Size of one single pixel in the field grid (in meters).
+    wavelength       : float
+                       Wavelength of the electric field.
+
+    Returns
+    -------
+    result           : np.complex
+                       Final complex field (MxN).
+    """
+    nv, nu = field.shape
+    l = nu*dx
+    l2 = wavelength*distance/dx
+    dx2 = wavelength*distance/l
+    fx = np.linspace(-l2/2., l2/2., nu)
+    fy = np.linspace(-l2/2., l2/2., nv)
+    FX, FY = np.meshgrid(fx, fy)
+    FZ = FX**2+FY**2
+    c = np.exp(1j*k*distance)/(1j*wavelength*distance) * \
+        np.exp(1j*k/(2*distance)*FZ)
+    result = c*np.fft.ifftshift(np.fft.fft2(np.fft.fftshift(field)))*dx**2
+    return result
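
Fraunhofer propagation rescales the output plane; fraunhofer_equal_size_adjust, documented next, crops the result back to the physical extent of the input. A sketch with assumed sampling values and import path:

import numpy as np
from odak.wave import wavenumber, fraunhofer, fraunhofer_equal_size_adjust

wavelength = 0.5e-6
dx = 8e-6
distance = 2.0                               # a long throw, in keeping with the Fraunhofer regime
k = wavenumber(wavelength)
field = np.zeros((512, 512), dtype=np.complex64)
field[240:272, 240:272] = 1.
far_field = fraunhofer(field, k, distance, dx, wavelength)
matched = fraunhofer_equal_size_adjust(far_field, distance, dx, wavelength)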


fraunhofer_equal_size_adjust(field, distance, dx, wavelength)

A definition to match the physical size of the original field with the propagated field.

Parameters:

  • field – Complex field (MxN).
  • distance – Propagation distance.
  • dx – Size of one single pixel in the field grid (in meters).
  • wavelength – Wavelength of the electric field.

Returns:

  • new_field (complex) – Final complex field (MxN).

Source code in odak/wave/classical.py
def fraunhofer_equal_size_adjust(field, distance, dx, wavelength):
+    """
+    A definition to match the physical size of the original field with the propagated field.
+
+    Parameters
+    ----------
+    field            : np.complex
+                       Complex field (MxN).
+    distance         : float
+                       Propagation distance.
+    dx               : float
+                       Size of one single pixel in the field grid (in meters).
+    wavelength       : float
+                       Wavelength of the electric field.
+
+    Returns
+    -------
+    new_field        : np.complex
+                       Final complex field (MxN).
+    """
+    nv, nu = field.shape
+    l1 = nu*dx
+    l2 = wavelength*distance/dx
+    m = l1/l2
+    px = int(m*nu)
+    py = int(m*nv)
+    nx = int(field.shape[0]/2-px/2)
+    ny = int(field.shape[1]/2-py/2)
+    new_field = np.copy(field[nx:nx+px, ny:ny+py])
+    return new_field


fraunhofer_inverse(field, k, distance, dx, wavelength)

A definition to calculate inverse Fraunhofer based beam propagation.

Parameters:

  • field – Complex field (MxN).
  • k – Wave number of a wave, see odak.wave.wavenumber for more.
  • distance – Propagation distance.
  • dx – Size of one single pixel in the field grid (in meters).
  • wavelength – Wavelength of the electric field.

Returns:

  • result (complex) – Final complex field (MxN).

Source code in odak/wave/classical.py
def fraunhofer_inverse(field, k, distance, dx, wavelength):
+    """
+    A definition to calculate Inverse Fraunhofer based beam propagation.
+
+    Parameters
+    ----------
+    field            : np.complex
+                       Complex field (MxN).
+    k                : odak.wave.wavenumber
+                       Wave number of a wave, see odak.wave.wavenumber for more.
+    distance         : float
+                       Propagation distance.
+    dx               : float
+                       Size of one single pixel in the field grid (in meters).
+    wavelength       : float
+                       Wavelength of the electric field.
+
+    Returns
+    -------
+    result           : np.complex
+                       Final complex field (MxN).
+    """
+    distance = np.abs(distance)
+    nv, nu = field.shape
+    l = nu*dx
+    l2 = wavelength*distance/dx
+    dx2 = wavelength*distance/l
+    fx = np.linspace(-l2/2., l2/2., nu)
+    fy = np.linspace(-l2/2., l2/2., nv)
+    FX, FY = np.meshgrid(fx, fy)
+    FZ = FX**2+FY**2
+    c = np.exp(1j*k*distance)/(1j*wavelength*distance) * \
+        np.exp(1j*k/(2*distance)*FZ)
+    result = np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(field/dx**2/c)))
+    return result


gerchberg_saxton(field, n_iterations, distance, dx, wavelength, slm_range=6.28, propagation_type='IR Fresnel', initial_phase=None)

Definition to compute a hologram using an iterative method called the Gerchberg-Saxton phase retrieval algorithm. For more on the method, see: Gerchberg, Ralph W. "A practical algorithm for the determination of phase from image and diffraction plane pictures." Optik 35 (1972): 237-246.

Parameters:

  • field – Complex field (MxN).
  • distance – Propagation distance.
  • dx – Size of one single pixel in the field grid (in meters).
  • wavelength – Wavelength of the electric field.
  • slm_range – Typically this is equal to two pi. See odak.wave.adjust_phase_only_slm_range() for more.
  • propagation_type (str, default: 'IR Fresnel') – Type of the propagation (IR Fresnel, TR Fresnel, Fraunhofer).
  • initial_phase – Phase to be added to the initial value.

Returns:

  • hologram (complex) – Calculated complex hologram.
  • reconstruction (complex) – Calculated reconstruction using the calculated hologram.

Source code in odak/wave/classical.py
def gerchberg_saxton(field, n_iterations, distance, dx, wavelength, slm_range=6.28, propagation_type='IR Fresnel', initial_phase=None):
+    """
+    Definition to compute a hologram using an iterative method called Gerchberg-Saxton phase retrieval algorithm. For more on the method, see: Gerchberg, Ralph W. "A practical algorithm for the determination of phase from image and diffraction plane pictures." Optik 35 (1972): 237-246.
+
+    Parameters
+    ----------
+    field            : np.complex64
+                       Complex field (MxN).
+    distance         : float
+                       Propagation distance.
+    dx               : float
+                       Size of one single pixel in the field grid (in meters).
+    wavelength       : float
+                       Wavelength of the electric field.
+    slm_range        : float
+                       Typically this is equal to two pi. See odak.wave.adjust_phase_only_slm_range() for more.
+    propagation_type : str
+                       Type of the propagation (IR Fresnel, TR Fresnel, Fraunhofer).
+    initial_phase    : np.complex64
+                       Phase to be added to the initial value.
+
+    Returns
+    -------
+    hologram         : np.complex
+                       Calculated complex hologram.
+    reconstruction   : np.complex
+                       Calculated reconstruction using calculated hologram. 
+    """
+    k = wavenumber(wavelength)
+    target = calculate_amplitude(field)
+    hologram = generate_complex_field(np.ones(field.shape), 0)
+    hologram = zero_pad(hologram)
+    if type(initial_phase) == type(None):
+        hologram = add_random_phase(hologram)
+    else:
+        initial_phase = zero_pad(initial_phase)
+        hologram = add_phase(hologram, initial_phase)
+    center = [int(hologram.shape[0]/2.), int(hologram.shape[1]/2.)]
+    orig_shape = [int(field.shape[0]/2.), int(field.shape[1]/2.)]
+    for i in tqdm(range(n_iterations), leave=False):
+        reconstruction = propagate_beam(
+            hologram, k, distance, dx, wavelength, propagation_type)
+        new_target = calculate_amplitude(reconstruction)
+        new_target[
+            center[0]-orig_shape[0]:center[0]+orig_shape[0],
+            center[1]-orig_shape[1]:center[1]+orig_shape[1]
+        ] = target
+        reconstruction = generate_complex_field(
+            new_target, calculate_phase(reconstruction))
+        hologram = propagate_beam(
+            reconstruction, k, -distance, dx, wavelength, propagation_type)
+        hologram = generate_complex_field(1, calculate_phase(hologram))
+        hologram = hologram[
+            center[0]-orig_shape[0]:center[0]+orig_shape[0],
+            center[1]-orig_shape[1]:center[1]+orig_shape[1],
+        ]
+        hologram = zero_pad(hologram)
+    reconstruction = propagate_beam(
+        hologram, k, distance, dx, wavelength, propagation_type)
+    hologram = hologram[
+        center[0]-orig_shape[0]:center[0]+orig_shape[0],
+        center[1]-orig_shape[1]:center[1]+orig_shape[1]
+    ]
+    reconstruction = reconstruction[
+        center[0]-orig_shape[0]:center[0]+orig_shape[0],
+        center[1]-orig_shape[1]:center[1]+orig_shape[1]
+    ]
+    return hologram, reconstruction
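
A sketch of computing a phase-only hologram for a synthetic target; the propagation settings are illustrative assumptions, and the propagation type is one of the names handled by propagate_beam (see its dispatch further below):

import numpy as np
from odak.wave import gerchberg_saxton

wavelength = 0.5e-6
dx = 8e-6
distance = 0.15
target = np.zeros((512, 512), dtype=np.complex64)
target[128:384, 128:384] = 1.                # stand-in for a target amplitude image
hologram, reconstruction = gerchberg_saxton(
    target, 10, distance, dx, wavelength,
    slm_range=2 * np.pi,
    propagation_type='Bandlimited Angular Spectrum'
)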


gerchberg_saxton_3d(fields, n_iterations, distances, dx, wavelength, slm_range=6.28, propagation_type='IR Fresnel', initial_phase=None, target_type='no constraint', coefficients=None)

Definition to compute a multi-plane hologram using an iterative method called the Gerchberg-Saxton phase retrieval algorithm. For more on the method, see: Zhou, Pengcheng, et al. "30.4: Multi-plane holographic display with a uniform 3D Gerchberg-Saxton algorithm." SID Symposium Digest of Technical Papers. Vol. 46. No. 1. 2015.

Parameters:

  • fields – Complex fields (MxN).
  • distances – Propagation distances.
  • dx – Size of one single pixel in the field grid (in meters).
  • wavelength – Wavelength of the electric field.
  • slm_range – Typically this is equal to two pi. See odak.wave.adjust_phase_only_slm_range() for more.
  • propagation_type (str, default: 'IR Fresnel') – Type of the propagation (IR Fresnel, TR Fresnel, Fraunhofer).
  • initial_phase – Phase to be added to the initial value.
  • target_type – Target type. `No constraint` targets the input target as is. `Double constraint` follows the idea in this paper, which claims to suppress speckle: Chang, Chenliang, et al. "Speckle-suppressed phase-only holographic three-dimensional display based on double-constraint Gerchberg-Saxton algorithm." Applied Optics 54.23 (2015): 6994-7001.

Returns:

  • hologram (complex) – Calculated complex hologram.

Source code in odak/wave/classical.py
def gerchberg_saxton_3d(fields, n_iterations, distances, dx, wavelength, slm_range=6.28, propagation_type='IR Fresnel', initial_phase=None, target_type='no constraint', coefficients=None):
+    """
+    Definition to compute a multi plane hologram using an iterative method called Gerchberg-Saxton phase retrieval algorithm. For more on the method, see: Zhou, Pengcheng, et al. "30.4: Multi‐plane holographic display with a uniform 3D Gerchberg‐Saxton algorithm." SID Symposium Digest of Technical Papers. Vol. 46. No. 1. 2015.
+
+    Parameters
+    ----------
+    fields           : np.complex64
+                       Complex fields (MxN).
+    distances        : list
+                       Propagation distances.
+    dx               : float
+                       Size of one single pixel in the field grid (in meters).
+    wavelength       : float
+                       Wavelength of the electric field.
+    slm_range        : float
+                       Typically this is equal to two pi. See odak.wave.adjust_phase_only_slm_range() for more.
+    propagation_type : str
+                       Type of the propagation (IR Fresnel, TR Fresnel, Fraunhofer).
+    initial_phase    : np.complex64
+                       Phase to be added to the initial value.
+    target_type      : str
+                       Target type. `No constraint` targets the input target as is. `Double constraint` follows the idea in this paper, which claims to suppress speckle: Chang, Chenliang, et al. "Speckle-suppressed phase-only holographic three-dimensional display based on double-constraint Gerchberg–Saxton algorithm." Applied optics 54.23 (2015): 6994-7001. 
+
+    Returns
+    -------
+    hologram         : np.complex
+                       Calculated complex hologram.
+    """
+    k = wavenumber(wavelength)
+    targets = calculate_amplitude(np.asarray(fields)).astype(np.float64)
+    hologram = generate_complex_field(np.ones(targets[0].shape), 0)
+    hologram = zero_pad(hologram)
+    if type(initial_phase) == type(None):
+        hologram = add_random_phase(hologram)
+    else:
+        initial_phase = zero_pad(initial_phase)
+        hologram = add_phase(hologram, initial_phase)
+    center = [int(hologram.shape[0]/2.), int(hologram.shape[1]/2.)]
+    orig_shape = [int(fields[0].shape[0]/2.), int(fields[0].shape[1]/2.)]
+    holograms = np.zeros(
+        (len(distances), hologram.shape[0], hologram.shape[1]), dtype=np.complex64)
+    for i in tqdm(range(n_iterations), leave=False):
+        for distance_id in tqdm(range(len(distances)), leave=False):
+            distance = distances[distance_id]
+            reconstruction = propagate_beam(
+                hologram, k, distance, dx, wavelength, propagation_type)
+            if target_type == 'double constraint':
+                if type(coefficients) == type(None):
+                    raise Exception(
+                        "Provide coefficients of alpha, beta and gamma for double constraint.")
+                alpha = coefficients[0]
+                beta = coefficients[1]
+                gamma = coefficients[2]
+                target_current = 2*alpha * \
+                    np.copy(targets[distance_id])-beta * \
+                    calculate_amplitude(reconstruction)
+                target_current[target_current == 0] = gamma * \
+                    np.abs(reconstruction[target_current == 0])
+            elif target_type == 'no constraint':
+                target_current = np.abs(targets[distance_id])
+            new_target = calculate_amplitude(reconstruction)
+            new_target[
+                center[0]-orig_shape[0]:center[0]+orig_shape[0],
+                center[1]-orig_shape[1]:center[1]+orig_shape[1]
+            ] = target_current
+            reconstruction = generate_complex_field(
+                new_target, calculate_phase(reconstruction))
+            hologram_layer = propagate_beam(
+                reconstruction, k, -distance, dx, wavelength, propagation_type)
+            hologram_layer = generate_complex_field(
+                1., calculate_phase(hologram_layer))
+            hologram_layer = hologram_layer[
+                center[0]-orig_shape[0]:center[0]+orig_shape[0],
+                center[1]-orig_shape[1]:center[1]+orig_shape[1]
+            ]
+            hologram_layer = zero_pad(hologram_layer)
+            holograms[distance_id] = hologram_layer
+        hologram = np.sum(holograms, axis=0)
+    hologram = hologram[
+        center[0]-orig_shape[0]:center[0]+orig_shape[0],
+        center[1]-orig_shape[1]:center[1]+orig_shape[1]
+    ]
+    return hologram


impulse_response_fresnel(field, k, distance, dx, wavelength)

A definition to calculate impulse response based Fresnel approximation for beam propagation.

Parameters:

  • field – Complex field (MxN).
  • k – Wave number of a wave, see odak.wave.wavenumber for more.
  • distance – Propagation distance.
  • dx – Size of one single pixel in the field grid (in meters).
  • wavelength – Wavelength of the electric field.

Returns:

  • result (complex) – Final complex field (MxN).

Source code in odak/wave/classical.py
def impulse_response_fresnel(field, k, distance, dx, wavelength):
+    """
+    A definition to calculate impulse response based Fresnel approximation for beam propagation.
+
+    Parameters
+    ----------
+    field            : np.complex
+                       Complex field (MxN).
+    k                : odak.wave.wavenumber
+                       Wave number of a wave, see odak.wave.wavenumber for more.
+    distance         : float
+                       Propagation distance.
+    dx               : float
+                       Size of one single pixel in the field grid (in meters).
+    wavelength       : float
+                       Wavelength of the electric field.
+
+    Returns
+    -------
+    result           : np.complex
+                       Final complex field (MxN).
+
+    """
+    nv, nu = field.shape
+    x = np.linspace(-nu / 2 * dx, nu / 2 * dx, nu)
+    y = np.linspace(-nv / 2 * dx, nv / 2 * dx, nv)
+    X, Y = np.meshgrid(x, y)
+    h = 1. / (1j * wavelength * distance) * np.exp(1j * k / (2 * distance) * (X ** 2 + Y ** 2))
+    H = np.fft.fft2(np.fft.fftshift(h))
+    U1 = np.fft.fft2(np.fft.fftshift(field))
+    U2 = H * U1
+    result = np.fft.ifftshift(np.fft.ifft2(U2))
+    result = np.roll(result, shift = (1, 1), axis = (0, 1))
+    return result


propagate_beam(field, k, distance, dx, wavelength, propagation_type='IR Fresnel')

Definitions for Fresnel Impulse Response (IR), Angular Spectrum (AS), Bandlimited Angular Spectrum (BAS), Fresnel Transfer Function (TF) and Fraunhofer diffraction in accordance with "Computational Fourier Optics" by David Voelz. For more on the bandlimited Fresnel impulse response, also known as the Bandlimited Angular Spectrum method, see "Band-limited Angular Spectrum Method for Numerical Simulation of Free-Space Propagation in Far and Near Fields".

Parameters:

  • field – Complex field (MxN).
  • k – Wave number of a wave, see odak.wave.wavenumber for more.
  • distance – Propagation distance.
  • dx – Size of one single pixel in the field grid (in meters).
  • wavelength – Wavelength of the electric field.
  • propagation_type (str, default: 'IR Fresnel') – Type of the propagation; the dispatch below accepts 'Rayleigh-Sommerfeld', 'Angular Spectrum', 'Impulse Response Fresnel', 'Bandlimited Angular Spectrum', 'Bandextended Angular Spectrum', 'Adaptive Sampling Angular Spectrum', 'Transfer Function Fresnel', 'Fraunhofer' and 'Fraunhofer Inverse'.

Returns:

  • result (complex) – Final complex field (MxN).

Source code in odak/wave/classical.py
def propagate_beam(field, k, distance, dx, wavelength, propagation_type='IR Fresnel'):
+    """
+    Definitions for Fresnel Impulse Response (IR), Angular Spectrum (AS), Bandlimited Angular Spectrum (BAS), Fresnel Transfer Function (TF), Fraunhofer diffraction in accordence with "Computational Fourier Optics" by David Vuelz. For more on Bandlimited Fresnel impulse response also known as Bandlimited Angular Spectrum method see "Band-limited Angular Spectrum Method for Numerical Simulation of Free-Space Propagation in Far and Near Fields".
+
+    Parameters
+    ----------
+    field            : np.complex
+                       Complex field (MxN).
+    k                : odak.wave.wavenumber
+                       Wave number of a wave, see odak.wave.wavenumber for more.
+    distance         : float
+                       Propagation distance.
+    dx               : float
+                       Size of one single pixel in the field grid (in meters).
+    wavelength       : float
+                       Wavelength of the electric field.
+    propagation_type : str
+                       Type of the propagation (IR Fresnel, Angular Spectrum, Bandlimited Angular Spectrum, TR Fresnel, Fraunhofer).
+
+    Returns
+    -------
+    result           : np.complex
+                       Final complex field (MxN).
+    """
+    if propagation_type == 'Rayleigh-Sommerfeld':
+        result = rayleigh_sommerfeld(field, k, distance, dx, wavelength)
+    elif propagation_type == 'Angular Spectrum':
+        result = angular_spectrum(field, k, distance, dx, wavelength)
+    elif propagation_type == 'Impulse Response Fresnel':
+        result = impulse_response_fresnel(field, k, distance, dx, wavelength)
+    elif propagation_type == 'Bandlimited Angular Spectrum':
+        result = band_limited_angular_spectrum(
+            field, k, distance, dx, wavelength)
+    elif propagation_type == 'Bandextended Angular Spectrum':
+        result = band_extended_angular_spectrum(
+            field, k, distance, dx, wavelength)
+    elif propagation_type == 'Adaptive Sampling Angular Spectrum':
+        result = adaptive_sampling_angular_spectrum(
+            field, k, distance, dx, wavelength)
+    elif propagation_type == 'Transfer Function Fresnel':
+        result = transfer_function_fresnel(field, k, distance, dx, wavelength)
+    elif propagation_type == 'Fraunhofer':
+        result = fraunhofer(field, k, distance, dx, wavelength)
+    elif propagation_type == 'Fraunhofer Inverse':
+        result = fraunhofer_inverse(field, k, distance, dx, wavelength)
+    else:
+        raise Exception("Unknown propagation type selected.")
+    return result
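
A round-trip sketch; note that the dispatch above handles the long names ('Impulse Response Fresnel', 'Transfer Function Fresnel') rather than the 'IR Fresnel' / 'TR Fresnel' shorthand mentioned in the docstring, so passing one of the strings listed in the elif chain is safest (sampling values and import path are assumptions):

import numpy as np
from odak.wave import wavenumber, propagate_beam

wavelength = 0.5e-6
dx = 8e-6
distance = 5e-3
k = wavenumber(wavelength)
field = np.zeros((512, 512), dtype=np.complex64)
field[192:320, 192:320] = 1.
forward = propagate_beam(field, k, distance, dx, wavelength, 'Bandlimited Angular Spectrum')
backward = propagate_beam(forward, k, -distance, dx, wavelength, 'Bandlimited Angular Spectrum')   # propagate back toward the source plane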


double_convergence(nx, ny, k, r, dx)

A definition to generate an initial phase for the Gerchberg-Saxton method. For more details, consult Sun, Peng, et al. "Holographic near-eye display system based on double-convergence light Gerchberg-Saxton algorithm." Optics Express 26.8 (2018): 10140-10151.

Parameters:

  • nx – Size of the output along X.
  • ny – Size of the output along Y.
  • k – See odak.wave.wavenumber for more.
  • r – The distance between the location of a light source and an image plane.
  • dx – Pixel pitch.

Returns:

  • function (ndarray) – Generated phase pattern for a Gerchberg-Saxton method.

Source code in odak/wave/lens.py
def double_convergence(nx, ny, k, r, dx):
+    """
+    A definition to generate initial phase for a Gerchberg-Saxton method. For more details consult Sun, Peng, et al. "Holographic near-eye display system based on double-convergence light Gerchberg-Saxton algorithm." Optics express 26.8 (2018): 10140-10151.
+
+    Parameters
+    ----------
+    nx         : int
+                 Size of the output along X.
+    ny         : int
+                 Size of the output along Y.
+    k          : odak.wave.wavenumber
+                 See odak.wave.wavenumber for more.
+    r          : float
+                 The distance between location of a light source and an image plane.
+    dx         : float
+                 Pixel pitch.
+
+    Returns
+    -------
+    function   : ndarray
+                 Generated phase pattern for a Gerchberg-Saxton method.
+    """
+    size = [ny, nx]
+    x = np.linspace(-size[0]*dx/2, size[0]*dx/2, size[0])
+    y = np.linspace(-size[1]*dx/2, size[1]*dx/2, size[1])
+    X, Y = np.meshgrid(x, y)
+    Z = X**2+Y**2
+    w = np.exp(1j*k*Z/r)
+    return w
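
Example: a minimal usage sketch, with assumed values and assuming double_convergence and wavenumber are importable from odak.wave.

# Hedged usage sketch; values and import path are assumptions for illustration.
import odak.wave as wave

wavelength = 515e-9                                  # wavelength in meters (assumed value)
k = wave.wavenumber(wavelength)
initial_phase = wave.double_convergence(nx=1920, ny=1080, k=k, r=0.3, dx=8e-6)  # r and dx are assumed
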
linear_grating(nx, ny, every=2, add=3.14, axis='x')

A definition to generate a linear grating.

Parameters:

  • nx – Size of the output along X.
  • ny – Size of the output along Y.
  • every – Add the add value at every given number.
  • add – Angle to be added.
  • axis – Axis, either X, Y or both.

Returns:

  • field ( ndarray ) – Linear grating term.

Source code in odak/wave/lens.py
def linear_grating(nx, ny, every=2, add=3.14, axis='x'):
+    """
+    A definition to generate a linear grating.
+
+    Parameters
+    ----------
+    nx         : int
+                 Size of the output along X.
+    ny         : int
+                 Size of the output along Y.
+    every      : int
+                 Add the add value at every given number.
+    add        : float
+                 Angle to be added.
+    axis       : string
+                 Axis, either X, Y or both.
+
+    Returns
+    -------
+    field      : ndarray
+                 Linear grating term.
+    """
+    grating = np.zeros((nx, ny), dtype=np.complex64)
+    if axis == 'x':
+        grating[::every, :] = np.exp(1j*add)
+    if axis == 'y':
+        grating[:, ::every] = np.exp(1j*add)
+    if axis == 'xy':
+        checker = np.indices((nx, ny)).sum(axis=0) % every
+        checker += 1
+        checker = checker % 2
+        grating = np.exp(1j*checker*add)
+    return grating
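
Example: a minimal usage sketch, assuming the grating term is applied to a complex field by element-wise multiplication and that linear_grating is importable from odak.wave.

# Hedged usage sketch; values and import path are assumptions for illustration.
import numpy as np
import odak.wave as wave

field = np.ones((512, 512), dtype=np.complex64)      # an illustrative complex field
grating = wave.linear_grating(512, 512, every=2, add=np.pi, axis='y')
field_with_grating = field * grating                 # element-wise application of the grating term
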
prism_phase_function(nx, ny, k, angle, dx=0.001, axis='x')

A definition to generate a 2D phase function that represents a prism. See Goodman's Introduction to Fourier Optics book for more.

Parameters:

  • nx – Size of the output along X.
  • ny – Size of the output along Y.
  • k – See odak.wave.wavenumber for more.
  • angle – Tilt angle of the prism in degrees.
  • dx – Pixel pitch.
  • axis – Axis of the prism.

Returns:

  • prism ( ndarray ) – Generated phase function for a prism.

Source code in odak/wave/lens.py
def prism_phase_function(nx, ny, k, angle, dx=0.001, axis='x'):
+    """
+    A definition to generate a 2D phase function that represents a prism. See Goodman's Introduction to Fourier Optics book for more.
+
+    Parameters
+    ----------
+    nx         : int
+                 Size of the output along X.
+    ny         : int
+                 Size of the output along Y.
+    k          : odak.wave.wavenumber
+                 See odak.wave.wavenumber for more.
+    angle      : float
+                 Tilt angle of the prism in degrees.
+    dx         : float
+                 Pixel pitch.
+    axis       : str
+                 Axis of the prism.
+
+    Returns
+    -------
+    prism      : ndarray
+                 Generated phase function for a prism.
+    """
+    angle = np.radians(angle)
+    size = [ny, nx]
+    x = np.linspace(-size[0]*dx/2, size[0]*dx/2, size[0])
+    y = np.linspace(-size[1]*dx/2, size[1]*dx/2, size[1])
+    X, Y = np.meshgrid(x, y)
+    if axis == 'y':
+        prism = np.exp(-1j*k*np.sin(angle)*Y)
+    elif axis == 'x':
+        prism = np.exp(-1j*k*np.sin(angle)*X)
+    return prism
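
Example: a minimal usage sketch, with assumed values and assuming prism_phase_function and wavenumber are importable from odak.wave. The returned term is typically multiplied with a complex field to tilt its propagation direction.

# Hedged usage sketch; values and import path are assumptions for illustration.
import odak.wave as wave

k = wave.wavenumber(515e-9)                          # wave number for an assumed 515 nm wavelength
prism = wave.prism_phase_function(512, 512, k, angle=0.1, dx=8e-6, axis='x')
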
quadratic_phase_function(nx, ny, k, focal=0.4, dx=0.001, offset=[0, 0])

A definition to generate a 2D quadratic phase function, which is typically used to represent lenses.

Parameters:

  • nx – Size of the output along X.
  • ny – Size of the output along Y.
  • k – See odak.wave.wavenumber for more.
  • focal – Focal length of the quadratic phase function.
  • dx – Pixel pitch.
  • offset – Deviation from the center along X and Y axes.

Returns:

  • function ( ndarray ) – Generated quadratic phase function.

Source code in odak/wave/lens.py
def quadratic_phase_function(nx, ny, k, focal=0.4, dx=0.001, offset=[0, 0]):
+    """ 
+    A definition to generate a 2D quadratic phase function, which is typically used to represent lenses.
+
+    Parameters
+    ----------
+    nx         : int
+                 Size of the output along X.
+    ny         : int
+                 Size of the output along Y.
+    k          : odak.wave.wavenumber
+                 See odak.wave.wavenumber for more.
+    focal      : float
+                 Focal length of the quadratic phase function.
+    dx         : float
+                 Pixel pitch.
+    offset     : list
+                 Deviation from the center along X and Y axes.
+
+    Returns
+    -------
+    function   : ndarray
+                 Generated quadratic phase function.
+    """
+    size = [nx, ny]
+    x = np.linspace(-size[0]*dx/2, size[0]*dx/2, size[0])-offset[1]*dx
+    y = np.linspace(-size[1]*dx/2, size[1]*dx/2, size[1])-offset[0]*dx
+    X, Y = np.meshgrid(x, y)
+    Z = X**2+Y**2
+    qwf = np.exp(1j*k*0.5*np.sin(Z/focal))
+    return qwf
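
Example: a minimal usage sketch, with assumed values and assuming quadratic_phase_function and wavenumber are importable from odak.wave. Multiplying a complex field with the returned term acts like passing the field through a thin lens.

# Hedged usage sketch; values and import path are assumptions for illustration.
import odak.wave as wave

k = wave.wavenumber(515e-9)                          # wave number for an assumed 515 nm wavelength
lens = wave.quadratic_phase_function(512, 512, k, focal=0.4, dx=8e-6, offset=[0, 0])
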
calculate_amplitude(field)

Definition to calculate amplitude of a single or multiple given electric field(s).

Parameters:

  • field – Electric fields or an electric field.

Returns:

  • amplitude ( float ) – Amplitude or amplitudes of electric field(s).

Source code in odak/wave/utils.py
def calculate_amplitude(field):
+    """ 
+    Definition to calculate amplitude of a single or multiple given electric field(s).
+
+    Parameters
+    ----------
+    field        : ndarray.complex or complex
+                   Electric fields or an electric field.
+
+    Returns
+    -------
+    amplitude    : float
+                   Amplitude or amplitudes of electric field(s).
+    """
+    amplitude = np.abs(field)
+    return amplitude
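
Example: a minimal usage sketch; the sample field values are assumptions for illustration.

# Hedged usage sketch; the import path assumes the definition is exposed under odak.wave.
import numpy as np
from odak.wave import calculate_amplitude

field = np.array([3. + 4.j, 1. + 0.j])               # two illustrative field samples
amplitude = calculate_amplitude(field)               # expected to return [5., 1.]
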
calculate_phase(field, deg=False)

Definition to calculate phase of a single or multiple given electric field(s).

Parameters:

  • field – Electric fields or an electric field.
  • deg – If set True, the angles will be returned in degrees.

Returns:

  • phase ( float ) – Phase or phases of electric field(s) in radians.

Source code in odak/wave/utils.py
def calculate_phase(field, deg=False):
+    """ 
+    Definition to calculate phase of a single or multiple given electric field(s).
+
+    Parameters
+    ----------
+    field        : ndarray.complex or complex
+                   Electric fields or an electric field.
+    deg          : bool
+                   If set True, the angles will be returned in degrees.
+
+    Returns
+    -------
+    phase        : float
+                   Phase or phases of electric field(s) in radians.
+    """
+    phase = np.angle(field)
+    if deg == True:
+        phase *= 180./np.pi
+    return phase
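
Example: a minimal usage sketch; the sample field values are assumptions for illustration.

# Hedged usage sketch; the import path assumes the definition is exposed under odak.wave.
import numpy as np
from odak.wave import calculate_phase

field = np.array([1. + 1.j, -1. + 0.j])              # two illustrative field samples
phase_radians = calculate_phase(field)               # expected to return [pi / 4, pi]
phase_degrees = calculate_phase(field, deg=True)     # expected to return [45., 180.]
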
electric_field_per_plane_wave(amplitude, opd, k, phase=0, w=0, t=0)

Definition to return state of a plane wave at a particular distance and time.

Parameters:

  • amplitude – Amplitude of a wave.
  • opd – Optical path difference in mm.
  • k – Wave number of a wave, see odak.wave.parameters.wavenumber for more.
  • phase – Initial phase of a wave.
  • w – Rotation speed of a wave, see odak.wave.parameters.rotationspeed for more.
  • t – Time in seconds.

Returns:

  • field ( complex ) – A complex number that provides the resultant field in the complex form A*e^(j(wt+phi)).

Source code in odak/wave/vector.py
def electric_field_per_plane_wave(amplitude, opd, k, phase=0, w=0, t=0):
+    """
+    Definition to return state of a plane wave at a particular distance and time.
+
+    Parameters
+    ----------
+    amplitude    : float
+                   Amplitude of a wave.
+    opd          : float
+                   Optical path difference in mm.
+    k            : float
+                   Wave number of a wave, see odak.wave.parameters.wavenumber for more.
+    phase        : float
+                   Initial phase of a wave.
+    w            : float
+                   Rotation speed of a wave, see odak.wave.parameters.rotationspeed for more.
+    t            : float
+                   Time in seconds.
+
+    Returns
+    -------
+    field        : complex
+                   A complex number that provides the resultant field in the complex form A*e^(j(wt+phi)).
+    """
+    field = amplitude*np.exp(1j*(-w*t+opd*k+phase))/opd**2
+    return field
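
Example: a minimal usage sketch, with assumed values and assuming electric_field_per_plane_wave and wavenumber are importable from odak.wave.

# Hedged usage sketch; values and import path are assumptions for illustration.
import odak.wave as wave

k = wave.wavenumber(515e-9)                          # wave number for an assumed 515 nm wavelength
field = wave.electric_field_per_plane_wave(
    amplitude=1.,                                    # assumed amplitude
    opd=10.,                                         # assumed optical path difference
    k=k,
    phase=0.
)
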
propagate_field(points0, points1, field0, wave_number, direction=1)

Definition to propagate a field from one set of points to another set of points in space: propagate a given array of spherical sources to a given set of points in space.

Parameters:

  • points0 – Start points (i.e. odak.tools.grid_sample).
  • points1 – End points (i.e. odak.tools.grid_sample).
  • field0 – Field for given starting points.
  • wave_number – Wave number of a wave, see odak.wave.wavenumber for more.
  • direction – For propagating in forward direction set as 1, otherwise -1.

Returns:

  • field1 ( ndarray ) – Field for given end points.

Source code in odak/wave/vector.py
def propagate_field(points0, points1, field0, wave_number, direction=1):
+    """
+    Definition to propagate a field from one set of points to another set of points in space: propagate a given array of spherical sources to a given set of points in space.
+
+    Parameters
+    ----------
+    points0       : ndarray
+                    Start points (i.e. odak.tools.grid_sample).
+    points1       : ndarray
+                    End points (ie. odak.tools.grid_sample).
+    field0        : ndarray
+                    Field for given starting points.
+    wave_number   : float
+                    Wave number of a wave, see odak.wave.wavenumber for more.
+    direction     : float
+                    For propagating in forward direction set as 1, otherwise -1.
+
+    Returns
+    -------
+    field1        : ndarray
+                    Field for given end points.
+    """
+    field1 = np.zeros(points1.shape[0], dtype=np.complex64)
+    for point_id in range(points0.shape[0]):
+        point = points0[point_id]
+        distances = distance_between_two_points(
+            point,
+            points1
+        )
+        field1 += electric_field_per_plane_wave(
+            calculate_amplitude(field0[point_id]),
+            distances*direction,
+            wave_number,
+            phase=calculate_phase(field0[point_id])
+        )
+    return field1
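
Example: a hedged usage sketch. The odak.tools.grid_sample call and its keyword arguments are assumptions based on the docstring reference above; all numeric values are illustrative.

# Hedged usage sketch; grid_sample arguments, values, and import paths are assumptions.
import numpy as np
import odak.wave as wave
from odak.tools import grid_sample

points0 = grid_sample(no=[5, 5], size=[1e-3, 1e-3], center=[0., 0., 0.])      # assumed source points
points1 = grid_sample(no=[5, 5], size=[1e-3, 1e-3], center=[0., 0., 1e-1])    # assumed target points
field0 = np.ones(points0.shape[0], dtype=np.complex64)                        # uniform field at the sources
k = wave.wavenumber(515e-9)
field1 = wave.propagate_field(points0, points1, field0, k, direction=1)
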
propagate_plane_waves(field, opd, k, w=0, t=0)

Definition to propagate a field representing a plane wave at a particular distance and time.

Parameters:

  • field – Complex field.
  • opd – Optical path difference in mm.
  • k – Wave number of a wave, see odak.wave.parameters.wavenumber for more.
  • w – Rotation speed of a wave, see odak.wave.parameters.rotationspeed for more.
  • t – Time in seconds.

Returns:

  • new_field ( complex ) – A complex number that provides the resultant field in the complex form A*e^(j(wt+phi)).

Source code in odak/wave/vector.py
def propagate_plane_waves(field, opd, k, w=0, t=0):
+    """
+    Definition to propagate a field representing a plane wave at a particular distance and time.
+
+    Parameters
+    ----------
+    field        : complex
+                   Complex field.
+    opd          : float
+                   Optical path difference in mm.
+    k            : float
+                   Wave number of a wave, see odak.wave.parameters.wavenumber for more.
+    w            : float
+                   Rotation speed of a wave, see odak.wave.parameters.rotationspeed for more.
+    t            : float
+                   Time in seconds.
+
+    Returns
+    -------
+    new_field     : complex
+                    A complex number that provides the resultant field in the complex form A*e^(j(wt+phi)).
+    """
+    new_field = field*np.exp(1j*(-w*t+opd*k))/opd**2
+    return new_field
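
Example: a minimal usage sketch, with assumed values and assuming propagate_plane_waves and wavenumber are importable from odak.wave.

# Hedged usage sketch; values and import path are assumptions for illustration.
import odak.wave as wave

k = wave.wavenumber(515e-9)                          # wave number for an assumed 515 nm wavelength
field = 1. + 0.j                                     # an illustrative plane wave state
new_field = wave.propagate_plane_waves(field, opd=10., k=k)
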
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/perception/index.html b/perception/index.html new file mode 100644 index 00000000..9270aa11 --- /dev/null +++ b/perception/index.html @@ -0,0 +1,1785 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Getting Started - Odak + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

Visual perception


The perception module of odak focuses on visual perception, and in particular on gaze-contingent perceptual loss functions.


Metamers


It contains an implementation of a metameric loss function. When used in optimisation tasks, this loss function enforces the optimised image to be a ventral metamer to the ground truth image.


This loss function is based on previous work on fast metamer generation. It uses the same statistical model and many of the same acceleration techniques (e.g. MIP map sampling) to enable the metameric loss to run efficiently.
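
Below is a hedged sketch of how such a loss could be used in an optimisation loop. The class name MetamericLoss under odak.learn.perception, its call signature, the tensor shapes, and the gaze argument are assumptions for illustration and are not defined on this page; consult the engineering note referenced below for the actual usage.

# Hedged sketch; class name, call signature, tensor shapes, and gaze argument are assumptions.
import torch
from odak.learn.perception import MetamericLoss

loss_function = MetamericLoss()
estimate = torch.rand(1, 3, 512, 512, requires_grad=True)       # image being optimised (assumed shape)
ground_truth = torch.rand(1, 3, 512, 512)                       # reference image (assumed)
loss = loss_function(estimate, ground_truth, gaze=[0.5, 0.5])   # gaze in normalised image coordinates
loss.backward()                                                 # gradients flow back to the estimate
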


Engineering notes

Note: Using metameric loss in Odak
Description: This engineering note will give you an idea about how to use the metameric perceptual loss in Odak.
+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/raytracing/index.html b/raytracing/index.html new file mode 100644 index 00000000..6af48d41 --- /dev/null +++ b/raytracing/index.html @@ -0,0 +1,1693 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Introduction - Odak + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

Raytracing


Odak provides a set of functions that implement methods used in raytracing. The ones implemented in Numpy, such as odak.raytracing, are not differentiable. However, the ones implemented in Torch, such as odak.learn.raytracing, are differentiable.

+ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/search/search_index.json b/search/search_index.json new file mode 100644 index 00000000..834db426 --- /dev/null +++ b/search/search_index.json @@ -0,0 +1 @@ +{"config":{"lang":["en"],"separator":"[\\s\\-]+","pipeline":["stopWordFilter"]},"docs":[{"location":"","title":"Getting started","text":"

Informative

Odak (pronounced \"O-dac\") is the fundamental library for scientific computing in optical sciences, computer graphics, and visual perception. We designed this page to help first-time users, new contributors, and existing users understand where to go within this documentation when they need help with certain aspects of Odak. If you think you need a refresher or are a beginner willing to learn more about light and computation, we created an entire course named Computational Light for you to get to pace with the computational aspects of light.

"},{"location":"#absolute-beginners","title":"Absolute Beginners","text":"

Informative \u00b7 Practical

Computational Light Course: Learn Odak and Physics of Light

"},{"location":"#new-users","title":"New Users","text":"

Informative

  • What is Odak?
  • Installation
"},{"location":"#use-cases","title":"Use cases","text":"

Informative

  • Computer-generated holography
  • General toolkit
  • Optical Raytracing
  • Machine Learning
  • Visual Perception
  • Lensless Cameras
"},{"location":"#new-contributors","title":"New contributors","text":"

Informative

  • Contributing to Odak
"},{"location":"#additional-information","title":"Additional information","text":"

Informative

  • Citing Odak in a scientific publication using Zenodo
  • License of Odak
  • Reporting bugs or requesting a feature

Reminder

We host a Slack group with more than 300 members. This Slack group focuses on the topics of rendering, perception, displays and cameras. The group is open to public and you can become a member by following this link. Readers can get in-touch with the wider community using this public group.

"},{"location":"beginning/","title":"What is Odak?","text":"

Odak (pronounced \"O-dac\") is the fundamental library for scientific computing in optical sciences, computer graphics and visual perception.

"},{"location":"beginning/#why-does-it-exist","title":"Why does it exist?","text":"

This question has two answers. One of them is related to the history of Odak, which is partially answered in the next section. The other answer lies in what kind of submodules Odak has in it. Depending on a need of a scientist at all levels or a professional from the industry, these submodules can help the design processes in optics and visual perception.

Odak includes modules for geometric 3D raytracing, Jones calculus, wave optics, and a set of tools to ease pain in measurement, exporting/importing CAD, and visualization during a design process. We have generated a set of recipes that go well with machine learning approaches compatible with the PyTorch learning framework as provided here. We have created many test scripts to inspire how you use Odak and helping your design process. Finally, we have created a distribution system to process tasks in parallel across multiple computing resources within the same network. Odak can either run using CPUs or automatically switch to NVIDIA GPUs.

"},{"location":"beginning/#history","title":"History","text":"

In the summer of 2011, I, Kaan Ak\u015fit, was a PhD student. At the time, I had some understanding of the Python programming language, and I created my first Python based computer game using pygame, a fantastic library, over a weekend in 2009. I was actively using Python to deploy packages for the Linux distribution that I supported at the time, Pardus. Meantime, that summer, I didn't have any internship or any vital task that I had to complete. I was super curious about the internals of the optical design software that I used at the time, ZEMAX. All of this lead to an exciting never-ending excursion that I still enjoy to this day, which I named Odak. Odak means focus in Turkish, and pronounced as O-dac.

The very first paper I read to build the pieces of Odak was General Ray tracing procedure\" from G.H. Spencer and M.V.R.K Murty, an article on routines for raytracing, published at the Journal of the Optical Society of America, Issue 6, Volume 52, Page 672. It helped to add reflection and refraction functions required in a raytracing routine. I continuously add to Odak over my entire professional life. That little raytracing program I wrote in 2011 is now a vital library for my research, and much more than a raytracer.

I can write pages and pages about what happened next. You can accurately estimate what happened next by checking my website and my cv. But I think the most critical part is always the beginning as it can inspire many other people to follow their thoughts and build their own thing! I used Odak in my all published papers. When I look back, I can only say that I am thankful to 2011 me spending a part of his summer in front of a computer to code a raytracer for optical design. Odak is now more than a raytracer, expanding on many other aspects of light, including vision science, polarization optics, computer-generated holography or machine learning routines for light sciences. Odak keeps on growing thanks to a body of people that contributed over time. I will keep it growing in the future and will continually transform into the tool that I need to innovate. All of it is free as in free-free, and all is sharable as I believe in people.

"},{"location":"cgh/","title":"Computer-Generated Holography","text":"

Odak contains essential ingredients for research and development targeting Computer-Generated Holography. We consult the beginners in this matter to Goodman's Introduction to Fourier Optics book (ISBN-13: 978-0974707723) and Principles of optics: electromagnetic theory of propagation, interference and diffraction of light from Max Born and Emil Wolf (ISBN 0-08-26482-4). In the rest of this document, you will find engineering notes and relevant functions in Odak that helps you describing complex nature of light on a computer. Note that, the creators of this documentation are from Computational Displays domain, however the provided submodules can potentially aid other lines of research as well, such as Computational Imaging or Computational Microscopy.

"},{"location":"cgh/#engineering-notes","title":"Engineering notes","text":"Note Description Holographic light transport This engineering note will give you an idea about how coherent light propagates in free space. Optimizing phase-only single plane holograms using Odak This engineering note will give you an idea about how to calculate phase-only holograms using Odak. Learning the model of a holographic display This link navigates to a project website that provides a codebase that can learn the model of a holographic display using a single complex kernel. Optimizing three-dimensional multiplane holograms using Odak This link navigates to a project website that provides a codebase that can help optimize a phase-only hologram representing multiplanar three-dimensional scenes."},{"location":"contributing/","title":"Contributing to Odak","text":"

Odak is in constant development. We shape Odak according to the most current needs in our scientific research. We welcome both users and developers in the open-source community as long as they have good intentions (e.g., scientific research). For the most recent description of Odak, please consult our description. If you are planning to use Odak for industrial purposes, please reach out to Kaan Ak\u015fit. All of the Odak contributors are listed in our THANKS.txt and added to CITATION.cff regardless of how much they contribute to the project. Their names are also included in our Digital Object Identifier (DOI) page.

"},{"location":"contributing/#contributing-process","title":"Contributing process","text":"

Contributions to Odak can come in different forms. It can either be code or documentation related contributions. Historically, Odak has evolved through scientific collaboration, in which authors of Odak identified a collaborative project with a new potential contributor. You can always reach out to Kaan Ak\u015fit to query your idea for potential collaborations in the future. Another potential place to identify likely means to improve odak is to address outstanding issues of Odak.

"},{"location":"contributing/#code","title":"Code","text":"

Odak's odak directory contains the source code. To add to it, please make sure that you can install and test Odak on your local computer. The installation documentation contains routines for installation and testing, please follow that page carefully.

We typically work with pull requests. If you want to add new code to Odak, please do not hesitate to fork Odak's git repository and have your modifications on your fork at first. Once you test the modified version, please do not hesitate to initiate a pull request. We will revise your code, and if found suitable, it will be merged to the master branch. Remember to follow numpy convention while adding documentation to your newly added functions to Odak. Another thing to mention is regarding to the code quality and standard. Although it hasn't been strictly followed since the start of Odak, note that Odak follows code conventions of flake8, which can be installed using:

pip3 install flake8\n

You can always check for code standard violations in Odak by running these two commands:

flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics\nflake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics\n

There are tools that can automatically fix code in terms of following standards. One primary tool that we are aware of is autopep8, which can be installed using:

pip3 install autopep8\n

Please once you are ready to have a pull request, make sure to add a unit test for your additions in test folder, and make sure to test all unit tests by running pytest. If your system do not have pytest installed, it can be installed using:

pip3 install pytest\n
"},{"location":"contributing/#documentation","title":"Documentation","text":"

Under Odak's source's root directory, you will find a folder named docs. This directory contains all the necessary information to generate the pages in this documentation. If you are interested in improving the documentation of Odak, this directory is the place where you will be adding things.

Odak's documentation is built using mkdocs. At this point, I assume that you have successfully installed Odak on your system. If you haven't yet, please follow installation documentation. To be able to run documentation locally, make sure to have the correct dependencies installed properly:

pip3 install plyfile\npip3 install Pillow\npip3 install tqdm\npip3 install mkdocs-material\npip3 install mkdocstrings\n

Once you have dependencies appropriately installed, navigate to the source directory of Odak in your hard drive and run a test server:

cd odak\nmkdocs serve\n

If all goes well, you should see a bunch of lines on your terminal, and the final lines should look similar to these:

INFO     -  Documentation built in 4.45 seconds\nINFO     -  [22:15:22] Serving on http://127.0.0.1:8000/odak/\nINFO     -  [22:15:23] Browser connected: http://127.0.0.1:8000/odak/\n

At this point, you can start your favourite browser and navigate to http://127.0.0.1:8000/odak to view documentation locally. This local viewing is essential as it can help you view your changes locally on the spot before actually committing. One last thing to mention here is the fact that Odak's docs folder's structure is self-explanatory. It follows markdown rules, and mkdocsstrings style is numpy.

"},{"location":"installation/","title":"Installation","text":"

We use odak with Linux operating systems. Therefore, we don't know if it can work with Windows or Mac operating systems. Odak can be installed in multiple ways. However, our recommended method for installing Odak is using pip distribution system. We update Odak within pip with each new version. Thus, the most straightforward way to install Odak is to use the below command in a Linux shell:

pip3 install odak\n
Note that Odak is in constant development. One may want to install the latest and greatest odak in the source repository for their reasons. In this case, our recommended method is to rely on pip for installing Odak from the source using:

pip3 install git+https://github.com/kaanaksit/odak\n

One can also install Odak without pip by first getting a local copy and installing using Python. Such an installation can be conducted using:

git clone git@github.com:kaanaksit/odak.git\ncd odak\npip3 install -r requirements.txt\npip3 install -e .\n
"},{"location":"installation/#uninstalling-the-development-version","title":"Uninstalling the Development version","text":"

If you have to remove the development version of odak, you can first try:

pip3 uninstall odak\nsudo pip3 uninstall odak\n

And if for some reason, you are still able to import odak after that, check easy-install.pth file which is typically found ~/.local/lib/pythonX/site-packages, where ~ refers to your home directory and X refers to your Python version. In that file, if you see odak's directory listed, delete it. This will help you remove development version of odak.

"},{"location":"installation/#notes-before-running","title":"Notes before running","text":"

Some notes should be highlighted to users, and these include:

  • Odak installs PyTorch that only uses CPU. To properly install PyTorch with GPU support, please consult PyTorch website.
"},{"location":"installation/#testing-an-installation","title":"Testing an installation","text":"

After installing Odak, one can test if Odak has been appropriately installed with its dependencies by running the unit tests. To be able to run unit tests, make sure to have pytest installed:

pip3 install -U pytest\n

Once pytest is installed, unit tests can be run by calling:

cd odak\npytest\n
The tests should return no error. However, if an error is encountered, please start a new issue to help us be aware of the issue.

"},{"location":"lensless/","title":"Lensless Imaging","text":"

Odak contains essential ingredients for research and development targeting Lensless Imaging.

"},{"location":"machine_learning/","title":"Machine learning","text":"

Odak provides a set of functions that implement classical methods in machine learning. Note that these functions are typically based on Numpy. Thus, they do not take advantage of the automatic differentiation found in Torch. The sole reason why these functions exist is that they stand as examples for implementing basic methods in machine learning.

"},{"location":"perception/","title":"Visual perception","text":"

The perception module of odak focuses on visual perception, and in particular on gaze-contingent perceptual loss functions.

"},{"location":"perception/#metamers","title":"Metamers","text":"

It contains an implementation of a metameric loss function. When used in optimisation tasks, this loss function enforces the optimised image to be a ventral metamer to the ground truth image.

This loss function is based on previous work on fast metamer generation. It uses the same statistical model and many of the same acceleration techniques (e.g. MIP map sampling) to enable the metameric loss to run efficiently.

"},{"location":"perception/#engineering-notes","title":"Engineering notes","text":"Note Description Using metameric loss in Odak This engineering note will give you an idea about how to use the metameric perceptual loss in Odak."},{"location":"raytracing/","title":"Raytracing","text":"

Odak provides a set of functions that implement methods used in raytracing. The ones implemented in Numpy, such as odak.raytracing, are not differentiable. However, the ones implemented in Torch, such as odak.learn.raytracing, are differentiable.

"},{"location":"toolkit/","title":"General toolkit.","text":"

Odak provides a set of functions that can be used for general purpose work, such as saving an image file or loading a three-dimensional point cloud of an object. These functions are helpful for general use and provide consistency across routine works in loading and saving routines. When working with odak, we strongly suggest sticking to the general toolkit to provide a coherent solution to your task.

"},{"location":"toolkit/#engineering-notes","title":"Engineering notes","text":"Note Description Working with images This engineering note will give you an idea about how read and write images using odak. Working with dictionaries This engineering note will give you an idea about how read and write dictionaries using odak."},{"location":"course/","title":"Prerequisites and general information","text":"Narrate section"},{"location":"course/#prerequisites-and-general-information","title":"Prerequisites and general information","text":"

You have reached the website for the Computational Light Course.

This page is the starting point for the Computational Light course. Readers can follow the course material found on these pages to learn more about the field of Computational Light. I encourage readers to carefully read this page to decide if they want to continue with the course.

"},{"location":"course/#brief-course-description","title":"Brief course description","text":"

Computational Light is a term that brings the concepts in computational methods with the characteristics of light. In other words, wherever we can program the qualities of light, such as its intensity or direction, this will get us into the topics of Computational Light. Some well-known subfields of Computational Light are Computer Graphics, Computational Displays, Computational Photography, Computational Imaging and Sensing, Computational Optics and Fabrication, Optical Communication, and All-optical Machine Learning.

Future is yet to be decided. Will you help me build it? A rendering from Telelife vision paper.

1

Computational Light Course bridges the gap between Computer Science and physics. In other words, Computational Light Course offers students a gateway to get familiar with various aspects of the physics of light, the human visual system, computational methods in designing light-based devices, and applications of light. Precisely, students will familiarize themselves with designing and implementing graphics, display, sensor, and camera systems using state-of-the-art deep learning and optimization methods. A deep understanding of these topics can help students become experts in the computational design of new graphics, displays, sensors, and camera systems.

"},{"location":"course/#prerequisites","title":"Prerequisites","text":"

These are the prerequisites of Computational Light course:

  • Background knowledge. First and foremost being fluent in programming with Python programming language and a graduate-level understanding of Linear Algebra, and Machine Learning are highly required.
  • Skills and abilities. Throughout the entire course, three libraries will be used, and these libraries include odak, numpy, and torch. Familiarity with these libraries is a big plus.
  • Required Resources. Readers need a computer with decent computational resources (e.g., GPU) when working on the provided materials, laboratory work, and projects. In case you do not have the right resources, consider using Google's Colab service as it is free to students. Note that at each section of the course, you will be provided with relevant reading materials on the spot.
  • Expectations. Readers also need sustainable motivation to learn new things related to the topics of Computational Light, and willing to advance the field by developing, innovating and researching. In other terms, you are someone motivated to create a positive impact in the society with light related innovations. You can also be someone eager to understand and learn physics behind light and how you can simulate light related phenomena.
"},{"location":"course/#questions-and-answers","title":"Questions and Answers","text":"

Here are some questions and answers related to the course that readers may ask:

What is the overarching rationale for the module?

Historically, physics and electronics departments in various universities study and teach the physics of light. This way, new light-based devices and equipment have been invented, such as displays, cameras, and fiber networks, in the past, and these devices continuously serve our societies. However, emerging topics from mathematics and computer science departments, such as deep learning and advanced optimization methods, unlocked new capabilities for existing light-based devices and started to play a crucial role in designing the next generation of these devices. The Computational Light Course aims to bridge this gap between Computer Science and physics by providing a fundamental understanding of light and computational methods that helps to explore new possibilities with light.

Who is the target audience of Computational Light course?

The Computational Light course is designed for individuals willing to learn how to develop and invent light-based practical systems for next-generation human-computer interfaces. This course targets a graduate-level audience in Computer Science, Physics and Electrical and Electronics Engineering departments. However, you do not have to be strictly from one of the highlighted targeted audiences. Simply put, if you think you can learn and are eager to learn, no one will stop you.

How can I learn Python programming, linear Algebra and machine learning?

There isn't a point in providing references on how to learn Python programming, Linear Algebra, and Machine Learning as there is a vast amount of resources online or in your previous university courses. Your favorite search engine is your friend in this case.

How do I install Python, numpy and torch?

The installation guide for python, numpy and torch is also available on their websites.

How do I install odak?

Odak's installation page and README provide the most up-to-date information on installing odak. But in a nutshell, all you need is to use the following command in a terminal pip3 install odak for the latest version, or if you want to install the latest code from the source, use pip3 install git+https://github.com/kaanaksit/odak.

Which Python environment and operating system should I use?

I use the Python distribution shipped with a traditional Linux distribution (e.g., Ubuntu). Again, there isn't one correct answer here for everyone. You can use any operating system (e.g., Windows, Mac) and Python distribution (e.g., conda).

Which text editor should I use for programming?

I use vim as my text editor. However, I understand that vim could be challenging to adopt, especially as a newcomer. The pattern I observe among collaborators and students is that they use Microsoft's Visual Studio, a competent text editor with artificial intelligence support through subscription and works across various operating systems. I encourage you to make your choice depending on how comfortable you are with sharing your data with companies. Please also remember that I am only making suggestions here. If another text editor works better for you, please use that one (e.g., nano, Sublime Text, Atom, Notepad++, Jupyter Notebooks).

Which terminal program to use?

You are highly encouraged to use the terminal that you feel most comfortable with. This terminal could be the default terminal in your operating system. I use terminator as it enables my workflow with incredible features and is open source.

What is the method of delivery?

The proposed course, Computational Light Course, comprises multiple elements in delivery. We list these elements as the followings:

  • Prerequisites and general information. Students will be provided with a written description of requirements related to the course as in this document.
  • Lectures. The students will attend two hours of classes each week, which will be in-person, virtual, or hybrid, depending on the circumstances (e.g., global pandemic, strikes).
  • Supplementary Lectures. Beyond weekly classes, students will be encouraged to follow several other sources through online video recordings.
  • Background review. Students often need a clear development guideline or a stable production pipeline. Thus, in every class and project, a phase of try-and-error causes the student to lose interest in the topic, and often students need help to pass the stage of getting ready for the course and finding the right recipe to complete their work. Thus, we formulate a special session to review the course's basics and requirements. This way, we hope to overcome the challenges related to the \"warming up\" stage of the class.
  • Lecture content. We will provide the students with a lecture book composed of chapters. These chapters will be discussed at each weekly lecture. The book chapters will be distributed online using Moodle (requires UCL access), and a free copy of this book will also be reachable without requiring UCL access.
  • Laboratory work. Students will be provided with questions about their weekly class topics. These questions will require them to code for a specific task. After each class, students will have an hour-long laboratory session to address these questions by coding. The teaching assistants of the lecture will support each laboratory session.
  • Supporting tools. We continuously develop new tools for the emerging fields of Computational Light. Our development tools will be used in the delivery. These tools are publicly available in Odak, our research toolkit with Mozilla Public License 2.0. Students will get a chance to use these tools in their laboratory works and projects. In the meantime, they will also get the opportunity to contribute to the next versions of the tool.
  • Project Assignments. Students will be evaluated on their projects. The lecturer will suggest projects related to the topics of Computational Light. However, the students will also be highly encouraged to propose projects for successfully finishing their course. These projects are expected to address a research question related to the topic discussed. Thus, there are multiple components of a project. These are implementation in coding, manuscript in a modern paper format, a website to promote the work to wider audiences, and presentation of the work to other students and the lecturer.
  • Office hours. There will be office hours for students willing to talk to the course lecturer, Kaan Ak\u015fit, in a one-on-one setting. Each week, the lecturer will schedule two hours for such cases.
What is the aim of this course?

Computational Light Course aims to train individuals that could potentially help invent and develop the next generation of light-based devices, systems and software. To achieve this goal, Computational Light Course, will aim:

  • To educate students on physics of light, human visual system and computational methods relevant to physics of light based on optimizations and machine learning techniques,
  • To enable students the right set of practical skills in coding and design for the next generation of light-based systems,
  • And to increase literacy on light-based technologies among students and professionals.
What are the intended learning outcomes of this course?

Students who have completed Computational Light Course successfully will have literacy and practical skills on the following items:

  • Physics of Light and applications of Computational Light,
  • Fundamental knowledge of managing a software project (e.g., version and authoring tools, unit tests, coding style, and grammar),
  • Fundamental knowledge of optimization methods and state-of-the-art libraries aiming at relevant topics,
  • Fundamental knowledge of visual perception and the human visual system,
  • Simulating light as geometric rays, continuous waves, and at the quantum level,
  • Simulating imaging and displays systems, including Computer-Generated Holography,
  • Designing and optimizing imaging and display systems,
  • Designing and optimizing all-optical machine learning systems.

Note that the above list is always subject to change in order or topic as society's needs move in various directions.

How to cite this course?

For citing using latex's bibtex bibliography system:

@book{aksit2024computationallight,\n  title = {Computational Light},\n  author = {Ak{\\c{s}}it, Kaan and Kam, Henry},\n  booktitle = {Computational Light Course Notes},\n  year = {2024}\n}\n
For plain text citation: Kaan Ak\u015fit, \"Computational Light Course\", 2024.

"},{"location":"course/#team","title":"Team","text":"

Kaan Ak\u015fit

Instructor

E-mail

Henry Kam

Contributor

E-mail

Contact Us

The preferred way of communication is through the discussions section of odak. Please only reach us through email if the thing you want to achieve, establish, or ask is not possible through the suggested route.

"},{"location":"course/#outreach","title":"Outreach","text":"

We host a Slack group with more than 300 members. This Slack group focuses on the topics of rendering, perception, displays and cameras. The group is open to public and you can become a member by following this link. Readers can get in-touch with the wider community using this public group.

"},{"location":"course/#acknowledgements","title":"Acknowledgements","text":"

Acknowledgements

We thank our readers. We also thank Yicheng Zhan for his feedback.

Interested in supporting?

Enjoyed our course material and want us to do better in the future? Please consider supporting us monetarily, citing our work in your next scientific work, or leaving us a star for odak.

  1. Jason Orlosky, Misha Sra, Kenan Bekta\u015f, Huaishu Peng, Jeeeun Kim, Nataliya Kos\u2019 myna, Tobias H\u00f6llerer, Anthony Steed, Kiyoshi Kiyokawa, and Kaan Ak\u015fit. Telelife: the future of remote living. Frontiers in Virtual Reality, 2:763340, 2021.\u00a0\u21a9

"},{"location":"course/computational_light/","title":"Light, Computation, and Computational Light","text":"Narrate section"},{"location":"course/computational_light/#light-computation-and-computational-light","title":"Light, Computation, and Computational Light","text":"

We can establish an understanding of the term Computational Light as we explore the term light and its relation to computation.

"},{"location":"course/computational_light/#what-is-light","title":"What is light?","text":"

Informative

Light surrounds us; we see the light and swim in the sea of light. It is indeed a daily matter that we interact by looking out of our window to see what is outside, turning on the lights of a room, looking at our displays, taking pictures of our loved ones, walking in a night lit by moonlight, or downloading media from the internet. Light is an eye-catching festival, reflecting, diffracting, interfering, and refracting. Is light a wave, a ray, or a quantum-related phenomenon? Is light heavy, or weightless? Is light the fastest thing in this universe? Which way does the light go? In a more general sense, how can we use light to trigger more innovations, positively impact our lives, and unlock the mysteries of life? We all experience light, but we must dig deep to describe it clearly.

In this introduction, my first intention here is to establish some basic scientific knowledge about light, which will help us understand why it is essential for the future of technology, especially computing. Note that we will cover more details of light as we make progress through different chapters of this course. But let's get this starting with the following statement. Light is electromagnetic radiation, often described as a bundle of photons, a term first coined by Gilbert Lewis in 1926.

Where can I learn more about electric and magnetic fields?

Beware that the topic of electric and magnetic fields deserves a stand-alone course and has many details to explore. As an undergraduate student, back in the day, I learned about electric and magnetic fields by following a dedicated class and reading this book: Cheng, David Keun. \"Fundamentals of engineering electromagnetics.\" (1993). 1

What is a photon?

Let me adjust this question a bit: What model is good for describing a photon? There is literature describing a photon as a single particle, and works show photons as a pack of particles. Suppose you want a more profound understanding than stating that it is a particle. In that case, you may want to dive deep into existing models in relation to the relativity concept: Roychoudhuri, C., Kracklauer, A. F., & Creath, K. (Eds.). (2017). The nature of light: What is a photon?. CRC Press. 2

Where can I learn more about history of research on light?

There is a website showing noticeable people researching on light since ancient times and their contributions to the research on light. To reach out to this website to get a crash course, click here.

Let me highlight that for anything to be electromagnetic, it must have electric and magnetic fields. Let us start with this simple drawing to explain the characteristics of this electromagnetic radiation, light. Note that this figure depicts a photon at the origin of XYZ axes. But bear in mind that a photon's shape, weight, and characteristics are yet to be fully discovered and remain an open research question. Beware that the figure depicts a photon as a sphere to provide ease of understanding. It does not mean that photons are spheres.

A sketch showing XYZ axes and a photon depicted as a sphere.

Let us imagine that our photon is traveling in the direction of the Z axes (notice \\(\\vec{r}\\), the direction vector). Let us also imagine that this photon has an electric field, \\(\\vec{E}(r,t)\\) oscillating along the Y axes. Typically this electric field is a sinusoidal oscillation following the equation,

\\[ \\vec{E}(r,t) = A cos(wt), \\]

where \\(A\\) is the amplitude of light, \\(t\\) is the time, \\(\\vec{r}\\) is the propagation direction, \\(w\\) is equal to \\(2\\pi f\\) and \\(f\\) represents the frequency of light.

A sketch highligting electric and magnetic fields of light.

A period of this sinusoidal oscillation, \\(\\vec{E}(r, t)\\), describes wavelength of light, \\(\\lambda\\). In the most simple terms, \\(\\lambda\\) is also known as the color of light. As light is electromagnetic, there is one more component than \\(\\vec{E}(r,t)\\) describing light. The next component is the magnetic field, \\(\\vec{B}(r, t)\\). The magnetic field of light, \\(\\vec{B}(r, t)\\), is always perpendicular to the electric field of light, \\(\\vec{E}(r, t)\\) (90 degrees along XY plane). Since only one \\(\\lambda\\) is involved in our example, we call our light monochromatic. This light would have been polychromatic if many other \\(\\lambda\\)s were superimposed to create \\(\\vec{E}(r, t)\\). In other words, monochromatic light is a single-color light, whereas polychromatic light contains many colors. The concept of color originated from how we sense various \\(\\lambda\\)s in nature.

A sketch showing electromagnetic spectrum with waves labelled in terms of their frequencies and temperatures.

But are all electromagnetic waves with various \\(\\lambda\\)s considered as light? The short answer is that we can not call all the electromagnetic radiation light. When we refer to light, we mainly talk about visible light, \\(\\lambda\\)s that our eyes could sense. These \\(\\lambda\\)s defining visible light fall into a tiny portion of the electromagnetic spectrum shown in the above sketch. Mainly, visible light falls into the spectrum covering wavelengths between 380 nm and 750 nm. The tails of visible light in the electromagnetic spectrum, such as near-infrared or ultraviolet regions, could also be referred to as light in some cases (e.g., for camera designers). In this course, although we will talk about visible light, we will also discuss the applications of these regions.

A sketch showing (left) electric and magnetic fields of light, and (right) polarization state of light.

Let us revisit our photon and its electromagnetic field one more time. As depicted in the above figure, the electric field, \\(\\vec{E}(r, t)\\), oscillates along only one axis: the Y axes. The direction of oscillation in \\(\\vec{E}(r, t)\\) is known as polarization of light. In the above example, the polarization of light is linear. In other words, the light is linearly polarized in the vertical axis. Note that when people talk about polarization of light, they always refer to the oscillation direction of the electric field, \\(\\vec{E}(r, t)\\). But are there any other polarization states of light? The light could be polarized in different directions along the X-axis, which would make the light polarized linearly in the horizontal axis, as depicted in the figure below on the left-hand side. If the light has a tilted electric field, \\(\\vec{E}(r, t)\\), with components both in the X and Y axes, light could still be linearly polarized but with some angle. However, if these two components have delays, \\(\\phi\\), in between in terms of oscillation, say one component is \\(\\vec{E_x}(r, t) = A_x cos(wt)\\) and the other component is \\(\\vec{E_y}(r, t) = A_y cos(wt + \\phi)\\), light could have a circular polarization if \\(A_x = A_y\\). But the light will be elliptically polarized if there is a delay, \\(\\phi\\), and \\(A_x \\neq A_y\\). Although we do not discuss this here in detail, note that the delay of \\(\\phi\\) will help steer the light's direction in the Computer-Generated Holography chapter.

A sketch showing (left) various components of polarization, and (right) a right-handed circular polarization as a sample case.

There are means to filter light with a specific polarization as well. Here, we provide a conceptual example. The below sketch depicts a polarization filter like a grid of lines letting the output light oscillate only in a specific direction.

A sketch showing a conceptual example of linear polarization filters.

We should also highlight that light could bounce off surfaces by reflecting or diffusing. If the material is proper (e.g., dielectric mirror), the light will perfectly reflect as depicted in the sketch below on the left-hand side. The light will perfectly diffuse at every angle if the material is proper (e.g., Lambertian diffuser), as depicted in the sketch below on the right-hand side. Though we will discuss these features of light in the Geometric Light chapter in detail, we should also highlight that light could refract through various mediums or diffract through a tiny hole or around a corner.

A sketch showing (left) light's reflection off a dielectric mirror (right) light's diffusion off a Lambertian's surface.

Existing knowledge on our understanding of our universe also states that light is the fastest thing in the universe, and no other material, thing or being could exceed lightspeed (\\(c = 299,792,458\\) metres per second).

\\[ c = \\lambda n f, \\]

where \\(n\\) represents the refractive index of the medium that light travels in, \\(\\lambda\\) represents the wavelength of light in that medium, and \\(f\\) represents the frequency of light.
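
As a quick sanity check of this relation, the short sketch below computes the frequency of green light in vacuum (\\(n = 1\\)); the numbers are assumed, illustrative values:

wavelength = 532e-9 # wavelength of green light in meters\nrefractive_index = 1. # vacuum\nc = 299792458. # speed of light in meters per second\n\n# Rearranging c = lambda * n * f to solve for the frequency, f.\nfrequency = c / (wavelength * refractive_index)\nprint('{:.2e} Hz'.format(frequency)) # roughly 5.64e+14 Hz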

Where can I find more basic information about optics and light?

As a graduate student, back in the day, I learned the basics of optics by reading this book without following any course: Hecht, E. (2012). Optics. Pearson Education India. 3

We have identified a bunch of different qualities of light so far. Let us summarize what we have identified in a nutshell.

  • Light is electromagnetic radiation.
  • Light has electric, \\(\\vec{E}(r,t) = A cos(wt)\\), and magnetic fields, \\(\\vec{B}(r,t)\\), that are always perpendicular to each other.
  • Light has color, also known as wavelength, \\(\\lambda\\).
  • When we say light, we typically refer to the colors we can see, visible light (380 - 750 nm).
  • The oscillation axis of light's electric field is light's polarization.
  • Light could have various brightness levels, the so-called amplitude of light, \\(A\\).
  • Light's polarization could be at various states with different \\(A\\)s and \\(\\phi\\)s.
  • Light could interfere by accumulating delays, \\(\\phi\\), and this could help change the direction of light.
  • Light could reflect off the surfaces.
  • Light could refract as it changes the medium.
  • Light could diffract around the corners.
  • Light is the fastest thing in our universe.

Remember that the description of light provided in this chapter is simplistic and misses many important details. The reason is simple: we aim to provide an entry point, a crash course at first glance. We will dive deep into focused topics in the following chapters. This way, you will arrive at them with a conceptual understanding of light.

Lab work: Are there any other light-related phenomena?

Please find more light-related phenomena not discussed in this chapter using your favorite search engine. Report back your findings.

Did you know?

Did you know there is an International Day of Light every 16th of May, recognized by the United Nations Educational, Scientific and Cultural Organization (UNESCO)? For more details, click here

"},{"location":"course/computational_light/#what-is-computational-light","title":"What is Computational Light?","text":"

Informative

Computational light is a term that brings together the concepts of computational methods and the characteristics of light. In other words, wherever we can program the qualities of light, we are within the topics of computational light. Programming light may sound unconventional, but I invite you to consider how we program current computers. These conventional computers interpret voltage levels in an electric signal as ones and zeros. Similarly, color, \\(\\lambda\\), propagation direction, \\(\\vec{r}\\), amplitude, \\(A\\), phase, \\(\\phi\\), polarization, diffraction, and interference are all qualities that could help us program light to achieve tasks for specific applications.

"},{"location":"course/computational_light/#applications-of-computational-light","title":"Applications of Computational Light","text":"

Informative \u00b7 Media

There is an enormous number of applications of light. Let us glance at some of the important ones to get a sense of the possibilities for people studying the topics of computational light. For each topic highlighted below, please click on the box to discover more about that specific subfield of computational light.

Computer Graphics

Computer Graphics deals with generating synthetic images using computers and simulations of light. Common examples of Computer Graphics are the video games we all play and are familiar with. In today's world, Computer Graphics is also often used as a tool to simulate and synthesize scenes for developing artificial intelligence, a trending topic.

  • Noticeable profiles. Like in any field, there are notable people in this field whose profiles you may want to explore to get a sense of who they are, what they have achieved, or what they have built for the development of modern Computer Graphics. Here are some people whose websites I would encourage you to explore: Peter Shirley and Morgan McGuire.
  • Successful products. Here are a few examples of successful outcomes from the field of Computer Graphics: Roblox, NVIDIA's DLSS, Apple's Metal, OpenGL and Vulkan.
  • Did you know? The lecturer of the Computational Light Course, Kaan Ak\u015fit, is actively researching topics of Computer Graphics (e.g., Beyond blur: Real-time ventral metamers for foveated rendering4).
  • Want to learn more? Although we will cover a great deal of Computer Graphics in this course, you may want to dig deeper with a dedicated course, which you can follow online:
Computational Displays

The Computational Displays topic deals with inventing next-generation display technology for the future of human-computer interaction. Common examples of emerging Computational Displays are near-eye displays such as Virtual Reality headsets and Augmented Reality glasses. Today, we all use displays as a core component for any visual task, such as working, entertainment, education, and many more.

  • Noticeable profiles. Like in any field, there are notable people in this field whose profiles you may want to explore to get a sense of who they are, what they have achieved, or what they have built for the development of Computational Displays. Here are some examples of such people; I would encourage you to explore their websites: Rafa\u0142 Mantiuk, and Andrew Maimone.
  • Successful products. Here are a few examples of successful outcomes from the field of Computational Displays: Nreal Augmented Reality glasses and Meta Quest Virtual Reality headsets.
  • Did you know? The lecturer of the Computational Light Course, Kaan Ak\u015fit, is actively researching topics of Computational Displays (e.g., Near-Eye Varifocal Augmented Reality Display using See-Through Screens 5). Kaan has made noticeable contributions to three-dimensional displays, virtual reality headsets, and augmented reality glasses.
  • Want to learn more? Although we will cover a great deal of Computational Displays in this course, you may want to dig deeper with a dedicated course, which you can follow online 6:
Computational Photography

The Computational Photography topic deals with digital image capture based on optical hardware such as cameras. Common examples of emerging Computational Photography are smartphone applications such as shooting in the dark or capturing selfies. Today, we all use products of Computational Photography to capture glimpses of our daily lives and store them as memories.

  • Noticeable profiles. Like in any field, there are notable people in this field whose profiles you may want to explore to get a sense of who they are, what they have achieved, or what they have built for the development of Computational Photography. Here are some examples of such people; I would encourage you to explore their websites: Diego Gutierrez and Jinwei Gu.
  • Successful products. Here are a few examples of successful outcomes from the field of Computational Photography: Google's Night Sight and Samsung Camera modes.
  • Want to learn more? Although we will cover relevant information for Computational Photography in this course, you may want to dig deeper with a dedicated course, which you can follow online:
Computational Imaging and Sensing

The Computational Imaging and Sensing topic deals with imaging and sensing certain scene qualities. Common examples of Computational Imaging and Sensing can be found in two other domains of Computational Light: Computational Astronomy and Computational Microscopy. Today, medical diagnoses of biological samples in hospitals, imaging stars and beyond, and sensing vital signals are all products of Computational Imaging and Sensing.

  • Noticeable profiles. Like in any field, there are notable people in this field whose profiles you may want to explore to get a sense of who they are, what they have achieved, or what they have built for the development of Computational Imaging and Sensing. Here are some examples of such people; I would encourage you to explore their websites: Laura Waller and Nick Antipa.
  • Successful products. Here are a few examples of successful outcomes from the field of Computational Imaging and Sensing: Zeiss Microscopes and Heart rate sensors on Apple's Smartwatch.
  • Did you know? The lecturer of the Computational Light Course, Kaan Ak\u015fit, is actively researching topics of Computational Imaging and Sensing (e.g., Unrolled Primal-Dual Networks for Lensless Cameras 7).
  • Want to learn more? Although we will cover a great deal of Computational Imaging and Sensing in this course, you may want to dig deeper with a dedicated course, which you can follow online:
Computational Optics and Fabrication

The Computational Optics and Fabrication topic deals with designing and fabricating optical components such as lenses, mirrors, diffraction gratings, holographic optical elements, and metasurfaces. There is a little bit of Computational Optics and Fabrication in every sector of Computational Light, especially when there is a need for custom optical design.

  • Noticeable profiles. Like in any field, there are notable people in this field whose profiles you may want to explore to get a sense of who they are, what they have achieved, or what they have built for the development of Computational Optics and Fabrication. Here are some examples of such people; I would encourage you to explore their websites: Jannick Rolland and Mark Pauly.
  • Did you know? The lecturer of the Computational Light Course, Kaan Ak\u015fit, is actively researching topics of Computational Optics and Fabrication (e.g., Manufacturing application-driven foveated near-eye displays 8).
  • Want to learn more? Although we will cover a great deal of Computational Optics and Fabrication in this course, you may want to dig deeper with a dedicated course, which you can follow online:
Optical Communication

Optical Communication deals with using light as a medium for telecommunication applications. Common examples of Optical Communication are the fiber cables and the satellites equipped with optical links in space that run our Internet. In today's world, Optical Communication underpins our entire modern life by making the Internet a reality.

  • Noticeable profiles. Like in any field, there are notable people in this field whose profiles you may want to explore to get a sense of who they are, what they have achieved, or what they have built for the development of modern Optical Communication. Here are some people whose websites I would encourage you to explore: Harald Haas and Anna Maria Vegni.
  • Did you know? The lecturer of the Computational Light Course, Kaan Ak\u015fit, was researching topics of Optical Communication (e.g., From sound to sight: Using audio processing to enable visible light communication 9).
  • Want to learn more? Although we will cover relevant information for Optical Communication in this course, you may want to dig deeper and could start with this online video:
All-optical Machine Learning

All-optical Machine Learning deals with building neural networks and computers that run solely on light. As this is an emerging field, there are not yet products from this field that we use in our daily lives. But this also means there are opportunities for newcomers and investors in this space.

  • Noticeable profiles. Like in any field, there are notable people in this field whose profiles you may want to explore to get a sense of who they are, what they have achieved, or what they have built for the development of All-optical Machine Learning. Here are some people whose websites I would encourage you to explore: Aydogan Ozcan and Ugur Tegin.
  • Want to learn more? Although we will cover a great deal of All-optical Machine Learning in this course, you may want to dig deeper with a dedicated course, which you can follow online:
Lab work: What are the other fields and interesting profiles out there?

Please explore other relevant fields to Computational Light, and explore interesting profiles out there. Please make a list of relevant fields and interesting profiles and report your top three.

Indeed, there are more topics related to computational light than the ones highlighted here. If you are up for a challenge in the next phase of your life, you could help the field identify new opportunities in light-based sciences. In addition, there are indeed more topics, more noticeable profiles, successful product examples, and dedicated courses focusing on every one of these topics. The examples are not limited to the ones I have provided above. Your favorite search engine is your friend for finding out more in this case.

Lab work: Where do we find good resources?

Please explore software projects on GitHub and papers on Google Scholar to find out about works that are relevant to the theme of Computational Light. Please make a list of these projects and report the top three projects that you feel most exciting and interesting.

Reminder

We host a Slack group with more than 300 members. This Slack group focuses on the topics of rendering, perception, displays and cameras. The group is open to the public, and you can become a member by following this link. Readers can get in touch with the wider community using this public group.

  1. David Keun Cheng and others. Fundamentals of engineering electromagnetics. Addison-Wesley Reading, MA, 1993.

  2. Chandra Roychoudhuri, Al F Kracklauer, and Kathy Creath. The nature of light: What is a photon? CRC Press, 2017.

  3. Eugene Hecht. Optics. Pearson Education India, 2012.

  4. David R Walton, Rafael Kuffner Dos Anjos, Sebastian Friston, David Swapp, Kaan Ak\u015fit, Anthony Steed, and Tobias Ritschel. Beyond blur: real-time ventral metamers for foveated rendering. ACM Transactions on Graphics, 40(4):1\u201314, 2021.

  5. Kaan Ak\u015fit, Ward Lopes, Jonghyun Kim, Peter Shirley, and David Luebke. Near-eye varifocal augmented reality display using see-through screens. ACM Transactions on Graphics (TOG), 36(6):1\u201313, 2017.

  6. Koray Kavakli, David Robert Walton, Nick Antipa, Rafa\u0142 Mantiuk, Douglas Lanman, and Kaan Ak\u015fit. Optimizing vision and visuals: lectures on cameras, displays and perception. In ACM SIGGRAPH 2022 Courses, pages 1\u201366. 2022.

  7. Oliver Kingshott, Nick Antipa, Emrah Bostan, and Kaan Ak\u015fit. Unrolled primal-dual networks for lensless cameras. Optics Express, 30(26):46324\u201346335, 2022.

  8. Kaan Ak\u015fit, Praneeth Chakravarthula, Kishore Rathinavel, Youngmo Jeong, Rachel Albert, Henry Fuchs, and David Luebke. Manufacturing application-driven foveated near-eye displays. IEEE Transactions on Visualization and Computer Graphics, 25(5):1928\u20131939, 2019.

  9. Stefan Schmid, Daniel Schwyn, Kaan Ak\u015fit, Giorgio Corbellini, Thomas R Gross, and Stefan Mangold. From sound to sight: using audio processing to enable visible light communication. In 2014 IEEE Globecom Workshops (GC Wkshps), 518\u2013523. IEEE, 2014.

"},{"location":"course/computer_generated_holography/","title":"Computer-Generated Holography","text":"Narrate section"},{"location":"course/computer_generated_holography/#computer-generated-holography","title":"Computer-Generated Holography","text":"

In this section, we introduce Computer-Generated Holography (CGH) 12 as another emerging method to simulate light. CGH offers an upgraded but more computationally expensive way of simulating light compared to the raytracing method described in the previous section. This section dives deep into CGH and will explain how CGH differs from raytracing as we go.

"},{"location":"course/computer_generated_holography/#what-is-holography","title":"What is holography?","text":"

Informative

Holography is a method in optical sciences to represent light distributions using the amplitude and phase of light. In much simpler terms, holography describes the light distribution emitted from an object, scene, or illumination source over a surface by treating the light as a wave. The primary difference between holography and raytracing is that holography accounts not only for the amplitude or intensity of light but also for its phase. Unlike classical raytracing, holography also includes the diffraction and interference phenomena. In raytracing, the smallest building block that defines light is a ray, whereas in holography, the building block is a light distribution over a surface. In other terms, while raytracing traces rays, holography deals with surface-to-surface light transfer.

Did you know this source?

There is an active repository on GitHub, where the latest CGH papers relevant to display technologies are listed. Visit GitHub:bchao1/awesome-holography for more.

"},{"location":"course/computer_generated_holography/#what-is-a-hologram","title":"What is a hologram?","text":"

Informative

A hologram is either a surface or a volume that modifies the light distribution of incoming light in terms of phase and amplitude. Diffraction gratings, Holographic Optical Elements, and Metasurfaces are good examples of holograms. Within this section, we also use the term hologram as a means to describe a lightfield or a slice of a lightfield.

"},{"location":"course/computer_generated_holography/#what-is-computer-generated-holography","title":"What is Computer-Generated Holography?","text":"

Informative

It is the computerized version (discrete sampling) of holography. In other terms, whenever you can program the phase or amplitude of light, this gets us into Computer-Generated Holography.

Where can I find an extensive summary on CGH?

You may be wondering about the greater physical details of CGH. In this case, we suggest our readers watch the video below, which provides an extensive summary on CGH 3.

"},{"location":"course/computer_generated_holography/#defining-a-slice-of-a-lightfield","title":"Defining a slice of a lightfield","text":"

Informative \u00b7 Practical

CGH deals with generating optical fields that capture light from various scenes. CGH often describes these optical fields (a.k.a. lightfields, holograms) as planes. So in CGH, light travels from plane to plane, as depicted below. Roughly speaking, CGH deals with the plane-to-plane interaction of light, whereas raytracing is a ray- or beam-oriented description of light.

A rendering showing how a slice (a.k.a. lightfield, optical field, hologram) propagates from one plane to another plane.

In other words, in CGH, you define everything as a \"lightfield,\" including light sources, materials, and objects. Thus, we must first determine how to describe the mentioned lightfield in a computer so that we can run CGH simulations effectively.

A lightfield is a planar slice in the context of CGH, as depicted in the above figure. This planar field is a pixelated 2D surface (which could be represented as a matrix). The pixels in this 2D slice hold values for the amplitude of light, \\(A\\), and the phase of light, \\(\\phi\\), at each pixel, whereas in classical raytracing, a ray only holds the amplitude or intensity of light. With a caveat, though: raytracing could also be made to care about the phase of light, but it will then arrive with all the complications of raytracing, like sampling enough rays or describing scenes accurately.

Each pixel in this planar lightfield slice encapsulates \\(A\\) and \\(\\phi\\) as \\(A cos(wt + \\phi)\\). If you recall our description of light, we explained that light is an electromagnetic phenomenon. Here, we model the oscillating electric field of light with the \\(A cos(wt + \\phi)\\) shown in our previous light description. Note that if we stick to \\(A cos(wt + \\phi)\\), each time two fields intersect, we have to deal with trigonometric conversion complexities, as sampled in this example:

\\[ A_0 cos(wt + \\phi_0) + A_1 cos(wt + \\phi_1), \\]

where the indices zero and one indicate the first and second fields, and we have to identify the right trigonometric conversion to deal with this sum.

Instead of complicated trigonometric conversions, what people do in CGH is rely on complex numbers as a proxy for these trigonometric conversions. In its proxy form, a pixel value in a field is converted into \\(A e^{-j \\phi}\\), where \\(j\\) represents the imaginary unit (\\(\\sqrt{-1}\\)). Thus, with this new proxy representation, the same intersection problem we dealt with using sophisticated trigonometry before turns into a simple addition of complex numbers, \\(A_0 e^{-j \\phi_0} + A_1 e^{-j \\phi_1}\\).

In the above summation of two fields, the resulting field follows the exact sum of the two collided fields. On the other hand, in raytracing, when a ray intersects with another ray, it will often be left unchanged and continue along its path. However, in the case of lightfields, the two fields form a new field. This feature is called the interference of light, a feature that raytracing typically omits. As you can also tell from the summation, two fields could enhance the resulting field (constructive interference) by converging to a brighter intensity, or these two fields could cancel each other out (destructive interference) and lead to the absence of light, total darkness.
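
As a small numerical sanity check of this proxy representation (with assumed amplitudes and delays), the sketch below verifies that summing two cosine waves directly matches the real part of summing their complex-number proxies; it is also a simple playground for observing constructive and destructive interference by changing \\(\\phi_1\\):

import torch\nimport odak\n\nA_0, phi_0 = 1.0, 0.\nA_1, phi_1 = 0.5, odak.pi / 3. # try phi_1 = odak.pi with A_1 = A_0 for destructive interference\nw = 2. * odak.pi # arbitrary angular frequency, for illustration only\nt = torch.linspace(0., 1., 1000)\n\nsum_trigonometric = A_0 * torch.cos(w * t + phi_0) + A_1 * torch.cos(w * t + phi_1)\nfield_0 = A_0 * torch.exp(-1j * (w * t + phi_0))\nfield_1 = A_1 * torch.exp(-1j * (w * t + phi_1))\nsum_complex = field_0 + field_1\n\nprint(torch.allclose(sum_trigonometric, sum_complex.real, atol = 1e-5)) # expected: True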

There are various examples of interference in nature. For example, the blue color of a butterfly wing results from interference, as biology typically does not produce blue-colored pigments in nature. More examples of light interference from daily lives are provided in the figure below.

Two photographs showing some examples of light interference: (left) a thin oil film creates rainbow interference patterns (CC BY-SA 2.5 by Wikipedia user John), and (right) a soap bubble interferes with light and creates vivid reflections (CC BY-SA 3.0 by Wikipedia user Brocken Inaglory).

We have established an easy way to describe a field with a proxy complex number form. This way, we avoided complicated trigonometric conversions. Let us look into how we use that in an actual simulation. Firstly, we can define two separate matrices to represent a field using real numbers:

import torch\n\namplitude = torch.zeros(100, 100, dtype = torch.float64)\nphase = torch.zeros(100, 100, dtype = torch.float64)\n

In this above example, we define two matrices with 100 x 100 dimensions. Each matrix holds floating point numbers, and they are real numbers. To convert the amplitude and phase into a field, we must define the field as suggested in our previous description. Instead of going through the same mathematical process for every piece of our future codes, we can rely on a utility function in odak to create fields consistently and coherently across all our future developments. The utility function we will review is odak.learn.wave.generate_complex_field():

odak.learn.wave.generate_complex_field

Definition to generate a complex field with a given amplitude and phase.

Parameters:

  • amplitude \u2013
                Amplitude of the field.\n            The expected size is [m x n] or [1 x m x n].\n
  • phase \u2013
                Phase of the field.\n            The expected size is [m x n] or [1 x m x n].\n

Returns:

  • field ( ndarray ) \u2013

    Complex field. Depending on the input, the expected size is [m x n] or [1 x m x n].

Source code in odak/learn/wave/util.py
def generate_complex_field(amplitude, phase):\n    \"\"\"\n    Definition to generate a complex field with a given amplitude and phase.\n\n    Parameters\n    ----------\n    amplitude         : torch.tensor\n                        Amplitude of the field.\n                        The expected size is [m x n] or [1 x m x n].\n    phase             : torch.tensor\n                        Phase of the field.\n                        The expected size is [m x n] or [1 x m x n].\n\n    Returns\n    -------\n    field             : ndarray\n                        Complex field.\n                        Depending on the input, the expected size is [m x n] or [1 x m x n].\n    \"\"\"\n    field = amplitude * torch.cos(phase) + 1j * amplitude * torch.sin(phase)\n    return field\n

Let us use this utility function to expand our previous code snippet and show how we can generate a complex field using that:

import torch\nimport odak # (1)\n\namplitude = torch.zeros(100, 100, dtype = torch.float64)\nphase = torch.zeros(100, 100, dtype = torch.float64)\nfield = odak.learn.wave.generate_complex_field(amplitude, phase) # (2)\n
  1. Adding odak to our imports.
  2. Generating a field using odak.learn.wave.generate_complex_field.
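
As a follow-up, once a complex field is generated, we can recover its amplitude and phase using odak.learn.wave.calculate_amplitude and odak.learn.wave.calculate_phase, which we will also rely on later in this section. The sketch below uses assumed random values purely for illustration:

import torch\nimport odak\n\namplitude = torch.rand(100, 100, dtype = torch.float64) # random amplitudes, for illustration only\nphase = torch.rand(100, 100, dtype = torch.float64) * 2 * odak.pi # random phases in [0, 2 pi)\nfield = odak.learn.wave.generate_complex_field(amplitude, phase)\n\nrecovered_amplitude = odak.learn.wave.calculate_amplitude(field)\nrecovered_phase = odak.learn.wave.calculate_phase(field) # note: phases come back wrapped to (-pi, pi]\nprint(torch.allclose(amplitude, recovered_amplitude)) # expected: True
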
"},{"location":"course/computer_generated_holography/#propagating-a-field-in-free-space","title":"Propagating a field in free space","text":"

Informative \u00b7 Practical

The next question we have to ask is related to the field we generated in our previous example. In raytracing, we propagate rays in space, whereas in CGH, we propagate a field described over a surface onto another target surface. So we need a transfer function that projects our field on another target surface. That is the point where free space beam propagation comes into play. As the name implies, free space beam propagation deals with propagating light in free space from one surface to another. This entire process of propagation is also referred to as light transport in the domains of Computer Graphics. In the rest of this section, we will explore means to simulate beam propagation on a computer.

Good news for MATLAB fans!

We will indeed use odak to explore beam propagation. However, there is also a book in the literature, [Numerical simulation of optical wave propagation: With examples in MATLAB by Jason D. Schmidt](https://www.spiedigitallibrary.org/ebooks/PM/Numerical-Simulation-of-Optical-Wave-Propagation-with-Examples-in-MATLAB/eISBN-9780819483270/10.1117/3.866274?SSO=1)4, that provides a crash course on beam propagation using MATLAB.

As we revisit the field we generated in the previous subsection, we remember that our field is a pixelated 2D surface. Each pixel in our fields, either on a hologram or an image plane, typically has a small size of a few micrometers (e.g., \\(8 \\mu m\\)). How light travels from each one of these pixels on one surface to the pixels on another is conceptually depicted as a figure at the beginning of this section (the green wolf image with two planes). We will name that figure's first plane on the left the hologram plane and the second one the image plane. In a nutshell, the contribution of a pixel on a hologram plane could be calculated by drawing rays to every pixel on the image plane. We draw rays from a point to a plane because in wave theory, which CGH follows, light can diffract (a small aperture creates spherical waves, as Huygens suggested). Each ray will travel a certain distance, thus causing a different delay in phase, \\(\\phi\\). As long as the distance between the planes is large enough, each ray will maintain an electric field that is in the same direction as the others (same polarization), and thus it is able to interfere with the rays emerging from the other pixels of the hologram plane. This description is a simplification of solving Maxwell's equations in electromagnetics.

A simplified result of solving Maxwell's equation is commonly described using Rayleigh-Sommerfeld diffraction integrals. For more on Rayleigh-Sommerfeld, consult Heurtley, J. C. (1973). Scalar Rayleigh\u2013Sommerfeld and Kirchhoff diffraction integrals: a comparison of exact evaluations for axial points. JOSA, 63(8), 1003-1008. 5. The first solution of the Rayleigh-Sommerfeld integral, also known as the Huygens-Fresnel principle, is expressed as follows:

\\[ u(x,y)=\\frac{1}{j\\lambda} \\int\\!\\!\\!\\!\\int u_0(x,y)\\frac{e^{jkr}}{r}cos(\\theta)dxdy, \\]

where the field at a target image plane, \\(u(x,y)\\), is calculated by integrating over every point of the hologram's area, \\(u_0(x,y)\\). Note that, for the above equation, \\(r\\) represents the optical path between a selected point on the hologram and a selected point on the image plane, \\(\\theta\\) represents the angle between the line connecting these two points and the hologram plane's surface normal, \\(k\\) represents the wavenumber (\\(\\frac{2\\pi}{\\lambda}\\)), and \\(\\lambda\\) represents the wavelength of light. In this described light transport model, the optical fields, \\(u_0(x,y)\\) and \\(u(x,y)\\), are represented with complex values,

\\[ u_0(x,y)=A(x,y)e^{j\\phi(x,y)}, \\]

where \\(A\\) represents the spatial distribution of amplitude and \\(\\phi\\) represents the spatial distribution of phase across a hologram plane. The described holographic light transport model is often simplified into a single convolution with a fixed spatially invariant complex kernel, \\(h(x,y)\\) 6.

\\[ u(x,y)=u_0(x,y) * h(x,y) =\\mathcal{F}^{-1}(\\mathcal{F}(u_0(x,y)) \\mathcal{F}(h(x,y))). \\]
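
To make the single-convolution model above concrete, here is a minimal sketch evaluating \\(u(x,y)=\\mathcal{F}^{-1}(\\mathcal{F}(u_0(x,y)) \\mathcal{F}(h(x,y)))\\) with torch, assuming a complex field and a complex kernel of the same size are already at hand. It deliberately omits the zero padding, shifting, and band limiting that practical implementations (such as the odak functions reviewed later in this section) take care of:

import torch\nimport odak\n\ndef convolve_with_kernel(u0, h):\n    # u0: complex field [m x n], h: complex kernel [m x n] (e.g., the common Fresnel kernel given below).\n    U0 = torch.fft.fft2(u0)\n    H = torch.fft.fft2(h)\n    u = torch.fft.ifft2(U0 * H)\n    return u\n\n\nu0 = torch.exp(1j * torch.rand(100, 100) * 2. * odak.pi) # a random phase-only field, for illustration only\nh = torch.ones(100, 100, dtype = torch.complex64) # placeholder kernel; replace with an actual h(x, y)\nu = convolve_with_kernel(u0, h)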

There are multiple variants of this simplified approach:

  • Matsushima, Kyoji, and Tomoyoshi Shimobaba. \"Band-limited angular spectrum method for numerical simulation of free-space propagation in far and near fields.\" Optics express 17.22 (2009): 19662-19673. 7,
  • Zhang, Wenhui, Hao Zhang, and Guofan Jin. \"Band-extended angular spectrum method for accurate diffraction calculation in a wide propagation range.\" Optics letters 45.6 (2020): 1543-1546. 8,
  • Zhang, Wenhui, Hao Zhang, and Guofan Jin. \"Adaptive-sampling angular spectrum method with full utilization of space-bandwidth product.\" Optics Letters 45.16 (2020): 4416-4419. 9.

In many cases, people choose to use the most common form of \\(h(x, y)\\) described as

\\[ h(x,y)=\\frac{e^{jkz}}{j\\lambda z} e^{\\frac{jk}{2z} (x^2+y^2)}, \\]

where \\(z\\) represents the distance between a hologram plane and a target image plane. Before we introduce how to use the existing beam propagation functions in our library, let us dive deep into compiling a beam propagation code following the Rayleigh-Sommerfeld integral, also known as the Huygens-Fresnel principle. In the rest of this subsection, I will walk you through the code below:

test_diffraction_integral.py
import sys\nimport odak # (1)\nimport torch\nfrom tqdm import tqdm\n\n\ndef main(): # (2)\n    length = [7e-6, 7e-6] # (3)\n    for fresnel_id, fresnel_number in enumerate(range(99)): # (4)\n        fresnel_number += 1\n        propagate(\n                  fresnel_number = fresnel_number,\n                  length = [length[0] + 1. / fresnel_number * 8e-6, length[1] + 1. / fresnel_number * 8e-6]\n                 )\n\n\ndef propagate(\n              wavelength = 532e-9, # (6)\n              pixel_pitch = 3.74e-6, # (7)\n              length = [15e-6, 15e-6],\n              image_samples = [2, 2], # Replace it with 1000 by 1000 (8)\n              aperture_samples = [2, 2], # Replace it with 1000 by 1000 (9)\n              device = torch.device('cpu'),\n              output_directory = 'test_output', \n              fresnel_number = 4,\n              save_flag = False\n             ): # (5)\n    distance = pixel_pitch ** 2 / wavelength / fresnel_number\n    distance = torch.as_tensor(distance, device = device)\n    k = odak.learn.wave.wavenumber(wavelength)\n    x = torch.linspace(- length[0] / 2, length[0] / 2, image_samples[0], device = device)\n    y = torch.linspace(- length[1] / 2, length[1] / 2, image_samples[1], device = device)\n    X, Y = torch.meshgrid(x, y, indexing = 'ij') # (10)\n    wxs = torch.linspace(- pixel_pitch / 2., pixel_pitch / 2., aperture_samples[0], device = device)\n    wys = torch.linspace(- pixel_pitch / 2., pixel_pitch / 2., aperture_samples[1], device = device) # (11)\n    h  = torch.zeros(image_samples[0], image_samples[1], dtype = torch.complex64, device = device)\n    for wx in tqdm(wxs):\n        for wy in wys:\n            h += huygens_fresnel_principle(wx, wy, X, Y, distance, k, wavelength) # (12)\n    h = h * pixel_pitch ** 2 / aperture_samples[0] / aperture_samples[1] # (13) \n\n    if save_flag:\n        save_results(h, output_directory, fresnel_number, length, pixel_pitch, distance, image_samples, device) # (14)\n    return True\n\n\ndef huygens_fresnel_principle(x, y, X, Y, z, k, wavelength): # (12)\n    r = torch.sqrt((X - x) ** 2 + (Y - y) ** 2 + z ** 2)\n    h = torch.exp(1j * k * r) * z / r ** 2 * (1. / (2 * odak.pi * r) + 1. 
/ (1j * wavelength))\n    return h\n\n\ndef save_results(h, output_directory, fresnel_number, length, pixel_pitch, distance, image_samples, device):\n    from matplotlib import pyplot as plt\n    odak.tools.check_directory(output_directory)\n    output_intensity = odak.learn.wave.calculate_amplitude(h) ** 2\n    odak.learn.tools.save_image(\n                                '{}/diffraction_output_intensity_fresnel_number_{:02d}.png'.format(output_directory, int(fresnel_number)),\n                                output_intensity,\n                                cmin = 0.,\n                                cmax = output_intensity.max()\n                               )\n    cross_section_1d = output_intensity[output_intensity.shape[0] // 2]\n    lengths = torch.linspace(- length[0] * 10 ** 6 / 2., length[0] * 10 ** 6 / 2., image_samples[0], device = device)\n    plt.figure()\n    plt.plot(lengths.detach().cpu().numpy(), cross_section_1d.detach().cpu().numpy())\n    plt.xlabel('length (um)')\n    plt.figtext(\n                0.14,\n                0.9, \n                r'Fresnel Number: {:02d}, Pixel pitch: {:.2f} um, Distance: {:.2f} um'.format(fresnel_number, pixel_pitch * 10 ** 6, distance * 10 ** 6),\n                fontsize = 11\n               )\n    plt.savefig('{}/diffraction_1d_output_intensity_fresnel_number_{:02d}.png'.format(output_directory, int(fresnel_number)))\n    plt.cla()\n    plt.clf()\n    plt.close()\n\n\nif __name__ == '__main__':\n    sys.exit(main())\n
  1. Importing relevant libraries
  2. This is our main routine.
  3. Length of the final image plane along X and Y axes.
  4. The Fresnel number is an arbitrary number that helps us get a sense of whether the optical configuration could be considered to be in the Fresnel (near field) or the Fraunhofer (far field) region.
  5. Propagating light with the given configuration.
  6. Wavelength of light.
  7. Square aperture length of a single pixel in the simulation. This is where light diffracts from.
  8. Number of pixels in the image plane along X and Y axes.
  9. Number of point light sources used to represent a single pixel's square aperture.
  10. Sample point locations along X and Y axes at the image plane.
  11. Sample point locations along X and Y axes at the aperture plane.
  12. For each virtual point light source defined inside the aperture, we simulate the light as if it were a diverging point light source.
  13. Normalize with the number of samples (trapezoid integration).
  14. Rest of this code is for logistics for saving images.

We start the implementation by importing the necessary libraries such as odak and torch. The first function, def main, sets the length of our image plane, where we will observe the diffraction pattern. As we set the size of our image plane, we also set an arbitrary number called the Fresnel number,

\\[ n_F = \\frac{w^2}{\\lambda z}, \\]

where \\(z\\) is the propagation distance, \\(w\\) is the side length of an aperture diffracting light, such as a pixel's square aperture -- this is often the pixel pitch -- and \\(\\lambda\\) is the wavelength of light. This number helps us get an idea of whether the given optical configuration falls under a certain regime, like Fresnel or Fraunhofer. The Fresnel number also provides a practical ease related to comparing solutions: regardless of the optical configuration, results with the same Fresnel number will follow a similar pattern, thus providing a way to verify your solutions. In the next step, we call the light propagation function, def propagate. At the beginning of this function, we set the optical configuration. For instance, we set pixel_pitch, which is the side length of a square aperture that the light will diffract from. Inside the def propagate function, we reset the distance such that it follows the input Fresnel number and wavelength. We define the locations of the samples along the X and Y axes that will represent the points to calculate on the image plane, x and y. Then, we define the locations of the samples along the X and Y axes that will represent the point light source locations inside the aperture, wxs and wys, which in this case is a square aperture that represents a single pixel, with its side length provided by pixel_pitch. The nested for loop goes over wxs and wys. Each time we choose a point from the aperture, we propagate a spherical wave from that point using def huygens_fresnel_principle. Note that we accumulate the effect of each spherical wave in a variable called h. This is the diffraction pattern in complex form from our square aperture, and we also normalize it using pixel_pitch and aperture_samples. Here, we provide visual results from this piece of code as below:

Saved 1D intensities on image plane representing diffraction patterns for various Fresnel numbers. These patterns are generated by using \"test/test_diffraction_integral.py\".

Saved 2D intensities on image plane representing diffraction patterns for various Fresnel numbers. These patterns are generated by using \"test/test_diffraction_integral.py\".
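
To relate the Fresnel number definition above to the distance used in the walkthrough, the short sketch below (with assumed values matching the defaults of def propagate) inverts \\(n_F = \\frac{w^2}{\\lambda z}\\) to find the propagation distance for a target Fresnel number:

wavelength = 532e-9 # wavelength of light in meters\npixel_pitch = 3.74e-6 # side length of the square aperture in meters\nfresnel_number = 4. # target Fresnel number\n\n# Inverting n_F = w ** 2 / (wavelength * z) for the propagation distance z,\n# mirroring the distance calculation inside def propagate above.\ndistance = pixel_pitch ** 2 / wavelength / fresnel_number\nprint('{:.2f} um'.format(distance * 1e6)) # roughly 6.57 um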

Note that beam propagation can also be learned for physical setups to avoid imperfections in a setup and to improve the image quality at an image plane:

  • Peng, Yifan, et al. \"Neural holography with camera-in-the-loop training.\" ACM Transactions on Graphics (TOG) 39.6 (2020): 1-14. 10,
  • Chakravarthula, Praneeth, et al. \"Learned hardware-in-the-loop phase retrieval for holographic near-eye displays.\" ACM Transactions on Graphics (TOG) 39.6 (2020): 1-18. 11,
  • Kavakl\u0131, Koray, Hakan Urey, and Kaan Ak\u015fit. \"Learned holographic light transport.\" Applied Optics (2021). 12.

The above descriptions establish a mathematical understanding of beam propagation. Let us examine the implementation of a beam propagation method called Bandlimited Angular Spectrum by reviewing these utility functions from odak:

odak.learn.wave.get_band_limited_angular_spectrum_kernel odak.learn.wave.band_limited_angular_spectrum odak.learn.wave.propagate_beam odak.learn.wave.wavenumber

Helper function for odak.learn.wave.band_limited_angular_spectrum.

Parameters:

  • nu \u2013
                 Resolution at X axis in pixels.\n
  • nv \u2013
                 Resolution at Y axis in pixels.\n
  • dx \u2013
                 Pixel pitch in meters.\n
  • wavelength \u2013
                 Wavelength in meters.\n
  • distance \u2013
                 Distance in meters.\n
  • device \u2013
                 Device, for more see torch.device().\n

Returns:

  • H ( complex64 ) \u2013

    Complex kernel in Fourier domain.

Source code in odak/learn/wave/classical.py
def get_band_limited_angular_spectrum_kernel(\n                                             nu,\n                                             nv,\n                                             dx = 8e-6,\n                                             wavelength = 515e-9,\n                                             distance = 0.,\n                                             device = torch.device('cpu')\n                                            ):\n    \"\"\"\n    Helper function for odak.learn.wave.band_limited_angular_spectrum.\n\n    Parameters\n    ----------\n    nu                 : int\n                         Resolution at X axis in pixels.\n    nv                 : int\n                         Resolution at Y axis in pixels.\n    dx                 : float\n                         Pixel pitch in meters.\n    wavelength         : float\n                         Wavelength in meters.\n    distance           : float\n                         Distance in meters.\n    device             : torch.device\n                         Device, for more see torch.device().\n\n\n    Returns\n    -------\n    H                  : torch.complex64\n                         Complex kernel in Fourier domain.\n    \"\"\"\n    x = dx * float(nu)\n    y = dx * float(nv)\n    fx = torch.linspace(\n                        -1 / (2 * dx) + 0.5 / (2 * x),\n                         1 / (2 * dx) - 0.5 / (2 * x),\n                         nu,\n                         dtype = torch.float32,\n                         device = device\n                        )\n    fy = torch.linspace(\n                        -1 / (2 * dx) + 0.5 / (2 * y),\n                        1 / (2 * dx) - 0.5 / (2 * y),\n                        nv,\n                        dtype = torch.float32,\n                        device = device\n                       )\n    FY, FX = torch.meshgrid(fx, fy, indexing='ij')\n    HH_exp = 2 * torch.pi * torch.sqrt(1 / wavelength ** 2 - (FX ** 2 + FY ** 2))\n    distance = torch.tensor([distance], device = device)\n    H_exp = torch.mul(HH_exp, distance)\n    fx_max = 1 / torch.sqrt((2 * distance * (1 / x))**2 + 1) / wavelength\n    fy_max = 1 / torch.sqrt((2 * distance * (1 / y))**2 + 1) / wavelength\n    H_filter = ((torch.abs(FX) < fx_max) & (torch.abs(FY) < fy_max)).clone().detach()\n    H = generate_complex_field(H_filter, H_exp)\n    return H\n

A definition to calculate bandlimited angular spectrum based beam propagation. For more Matsushima, Kyoji, and Tomoyoshi Shimobaba. \"Band-limited angular spectrum method for numerical simulation of free-space propagation in far and near fields.\" Optics express 17.22 (2009): 19662-19673.

Parameters:

  • field \u2013
               A complex field.\n           The expected size is [m x n].\n
  • k \u2013
               Wave number of a wave, see odak.wave.wavenumber for more.\n
  • distance \u2013
               Propagation distance.\n
  • dx \u2013
               Size of one single pixel in the field grid (in meters).\n
  • wavelength \u2013
               Wavelength of the electric field.\n
  • zero_padding \u2013
               Zero pad in Fourier domain.\n
  • aperture \u2013
               Fourier domain aperture (e.g., pinhole in a typical holographic display).\n           The default is one, but an aperture could be as large as input field [m x n].\n

Returns:

  • result ( complex ) \u2013

    Final complex field [m x n].

Source code in odak/learn/wave/classical.py
def band_limited_angular_spectrum(\n                                  field,\n                                  k,\n                                  distance,\n                                  dx,\n                                  wavelength,\n                                  zero_padding = False,\n                                  aperture = 1.\n                                 ):\n    \"\"\"\n    A definition to calculate bandlimited angular spectrum based beam propagation. For more \n    `Matsushima, Kyoji, and Tomoyoshi Shimobaba. \"Band-limited angular spectrum method for numerical simulation of free-space propagation in far and near fields.\" Optics express 17.22 (2009): 19662-19673`.\n\n    Parameters\n    ----------\n    field            : torch.complex\n                       A complex field.\n                       The expected size is [m x n].\n    k                : odak.wave.wavenumber\n                       Wave number of a wave, see odak.wave.wavenumber for more.\n    distance         : float\n                       Propagation distance.\n    dx               : float\n                       Size of one single pixel in the field grid (in meters).\n    wavelength       : float\n                       Wavelength of the electric field.\n    zero_padding     : bool\n                       Zero pad in Fourier domain.\n    aperture         : torch.tensor\n                       Fourier domain aperture (e.g., pinhole in a typical holographic display).\n                       The default is one, but an aperture could be as large as input field [m x n].\n\n\n    Returns\n    -------\n    result           : torch.complex\n                       Final complex field [m x n].\n    \"\"\"\n    H = get_propagation_kernel(\n                               nu = field.shape[-2], \n                               nv = field.shape[-1], \n                               dx = dx, \n                               wavelength = wavelength, \n                               distance = distance, \n                               propagation_type = 'Bandlimited Angular Spectrum',\n                               device = field.device\n                              )\n    result = custom(field, H, zero_padding = zero_padding, aperture = aperture)\n    return result\n

Definitions for various beam propagation methods, mostly in accordance with \"Computational Fourier Optics\" by David Voelz.

Parameters:

  • field \u2013
               Complex field [m x n].\n
  • k \u2013
               Wave number of a wave, see odak.wave.wavenumber for more.\n
  • distance \u2013
               Propagation distance.\n
  • dx \u2013
               Size of one single pixel in the field grid (in meters).\n
  • wavelength \u2013
               Wavelength of the electric field.\n
  • propagation_type (str, default: 'Bandlimited Angular Spectrum' ) \u2013
               Type of the propagation.\n           The options are Impulse Response Fresnel, Transfer Function Fresnel, Angular Spectrum, Bandlimited Angular Spectrum, Fraunhofer.\n
  • kernel \u2013
               Custom complex kernel.\n
  • zero_padding \u2013
               Zero padding the input field if the first item in the list set True.\n           Zero padding in the Fourier domain if the second item in the list set to True.\n           Cropping the result with half resolution if the third item in the list is set to true.\n           Note that in Fraunhofer propagation, setting the second item True or False will have no effect.\n
  • aperture \u2013
               Aperture at Fourier domain default:[2m x 2n], otherwise depends on `zero_padding`.\n           If provided as a floating point 1, there will be no aperture in Fourier domain.\n
  • scale \u2013
               Resolution factor to scale generated kernel.\n
  • samples \u2013
               When using `Impulse Response Fresnel` propagation, these sample counts along X and Y will be used to represent a rectangular aperture. First two is for a hologram pixel and second two is for an image plane pixel.\n

Returns:

  • result ( complex ) \u2013

    Final complex field [m x n].

Source code in odak/learn/wave/classical.py
def propagate_beam(\n                   field,\n                   k,\n                   distance,\n                   dx,\n                   wavelength,\n                   propagation_type='Bandlimited Angular Spectrum',\n                   kernel = None,\n                   zero_padding = [True, False, True],\n                   aperture = 1.,\n                   scale = 1,\n                   samples = [20, 20, 5, 5]\n                  ):\n    \"\"\"\n    Definitions for various beam propagation methods mostly in accordence with \"Computational Fourier Optics\" by David Vuelz.\n\n    Parameters\n    ----------\n    field            : torch.complex\n                       Complex field [m x n].\n    k                : odak.wave.wavenumber\n                       Wave number of a wave, see odak.wave.wavenumber for more.\n    distance         : float\n                       Propagation distance.\n    dx               : float\n                       Size of one single pixel in the field grid (in meters).\n    wavelength       : float\n                       Wavelength of the electric field.\n    propagation_type : str\n                       Type of the propagation.\n                       The options are Impulse Response Fresnel, Transfer Function Fresnel, Angular Spectrum, Bandlimited Angular Spectrum, Fraunhofer.\n    kernel           : torch.complex\n                       Custom complex kernel.\n    zero_padding     : list\n                       Zero padding the input field if the first item in the list set True.\n                       Zero padding in the Fourier domain if the second item in the list set to True.\n                       Cropping the result with half resolution if the third item in the list is set to true.\n                       Note that in Fraunhofer propagation, setting the second item True or False will have no effect.\n    aperture         : torch.tensor\n                       Aperture at Fourier domain default:[2m x 2n], otherwise depends on `zero_padding`.\n                       If provided as a floating point 1, there will be no aperture in Fourier domain.\n    scale            : int\n                       Resolution factor to scale generated kernel.\n    samples          : list\n                       When using `Impulse Response Fresnel` propagation, these sample counts along X and Y will be used to represent a rectangular aperture. 
First two is for a hologram pixel and second two is for an image plane pixel.\n\n    Returns\n    -------\n    result           : torch.complex\n                       Final complex field [m x n].\n    \"\"\"\n    if zero_padding[0]:\n        field = zero_pad(field)\n    if propagation_type == 'Angular Spectrum':\n        result = angular_spectrum(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture)\n    elif propagation_type == 'Bandlimited Angular Spectrum':\n        result = band_limited_angular_spectrum(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture)\n    elif propagation_type == 'Impulse Response Fresnel':\n        result = impulse_response_fresnel(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture, scale = scale, samples = samples)\n    elif propagation_type == 'Seperable Impulse Response Fresnel':\n        result = seperable_impulse_response_fresnel(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture, scale = scale, samples = samples)\n    elif propagation_type == 'Transfer Function Fresnel':\n        result = transfer_function_fresnel(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture)\n    elif propagation_type == 'custom':\n        result = custom(field, kernel, zero_padding[1], aperture = aperture)\n    elif propagation_type == 'Fraunhofer':\n        result = fraunhofer(field, k, distance, dx, wavelength)\n    elif propagation_type == 'Incoherent Angular Spectrum':\n        result = incoherent_angular_spectrum(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture)\n    else:\n        logging.warning('Propagation type not recognized')\n        assert True == False\n    if zero_padding[2]:\n        result = crop_center(result)\n    return result\n

Definition for calculating the wavenumber of a plane wave.

Parameters:

  • wavelength \u2013
           Wavelength of a wave in mm.\n

Returns:

  • k ( float ) \u2013

    Wave number for a given wavelength.

Source code in odak/learn/wave/util.py
def wavenumber(wavelength):\n    \"\"\"\n    Definition for calculating the wavenumber of a plane wave.\n\n    Parameters\n    ----------\n    wavelength   : float\n                   Wavelength of a wave in mm.\n\n    Returns\n    -------\n    k            : float\n                   Wave number for a given wavelength.\n    \"\"\"\n    k = 2 * np.pi / wavelength\n    return k\n

Let us see how we can use the given beam propagation function with an example:

test_learn_wave_propagate_beam.py
import sys\nimport os\nimport odak\nimport numpy as np\nimport torch\n\n\ndef test(output_directory = 'test_output'):\n    odak.tools.check_directory(output_directory)\n    wavelength = 532e-9 # (1)\n    pixel_pitch = 8e-6 # (2)\n    distance = 0.5e-2 # (3)\n    propagation_types = ['Angular Spectrum', 'Bandlimited Angular Spectrum', 'Transfer Function Fresnel'] # (4)\n    k = odak.learn.wave.wavenumber(wavelength) # (5)\n\n\n    amplitude = torch.zeros(500, 500)\n    amplitude[200:300, 200:300 ] = 1. # (5)\n    phase = torch.randn_like(amplitude) * 2 * odak.pi # (6)\n    hologram = odak.learn.wave.generate_complex_field(amplitude, phase) # (7)\n\n    for propagation_type in propagation_types:\n        image_plane = odak.learn.wave.propagate_beam(\n                                                     hologram,\n                                                     k,\n                                                     distance,\n                                                     pixel_pitch,\n                                                     wavelength,\n                                                     propagation_type,\n                                                     zero_padding = [True, False, True] # (8)\n                                                    ) # (9)\n\n        image_intensity = odak.learn.wave.calculate_amplitude(image_plane) ** 2 # (10)\n        hologram_intensity = amplitude ** 2\n\n        odak.learn.tools.save_image(\n                                    '{}/image_intensity_{}.png'.format(output_directory, propagation_type.replace(' ', '_')), \n                                    image_intensity, \n                                    cmin = 0., \n                                    cmax = image_intensity.max()\n                                ) # (11)\n        odak.learn.tools.save_image(\n                                    '{}/hologram_intensity_{}.png'.format(output_directory, propagation_type.replace(' ', '_')), \n                                    hologram_intensity, \n                                    cmin = 0., \n                                    cmax = 1.\n                                ) # (12)\n    assert True == True\n\n\nif __name__ == '__main__':\n    sys.exit(test()) \n
  1. Setting the wavelength of light in meters. We use 532 nm (green light) in this example.
  2. Setting the physical size of a single pixel in our simulation. We use an \\(8 \\mu m\\) pixel size (width and height are both \\(8 \\mu m\\)).
  3. Setting the distance between the two planes, the hologram and the image plane. We set it as half a centimeter here.
  4. We define the list of propagation types to test: Angular Spectrum, Bandlimited Angular Spectrum, and Transfer Function Fresnel.
  5. Here, we calculate a value named wavenumber, which we introduced while we were talking about the beam propagation functions.
  6. Here, we assume that there is a rectangular light at the center of our hologram.
  7. Here, we generate the field by combining amplitude and phase.
  8. Here, we zero pad and crop our field before and after the beam propagation to make sure that there is no aliasing in our results (see Nyquist criterion).
  9. We propagate the beam using the values and field provided.
  10. We calculate the final intensity on our image plane. Remember that human eyes can see intensity but not amplitude or phase of light. Intensity of light is a square of its amplitude.
  11. We save image plane intensity to an image file.
  12. For comparison, we also save the hologram intensity to an image file so that we can observe how our light transformed from one plane to another.

Let us also take a look at the saved images as a result of the above sample code:

Saved intensities before (left) and after (right) beam propagation (hologram and image plane intensities). This result is generated using \"test/test_learn_wave_propagate_beam.py\". Challenge: Light transport on Arbitrary Surfaces

We know that we can propagate a hologram to any image plane at any distance. This propagation is a plane-to-plane interaction. However, there may be cases where a simulation involves finding the light distribution over an arbitrary surface. Conventionally, this could be achieved by propagating the hologram to multiple different planes and picking, from each plane, the samples that lie on that arbitrary surface. We challenge our readers to code the mentioned baseline (multiple planes for arbitrary surfaces) and ask them to develop a beam propagation method that is less computationally expensive and works for arbitrary surfaces (e.g., tilted planes or arbitrary shapes). This development could either rely on classical approaches or involve learning-based methods. The resultant method could be part of the odak.learn.wave submodule as a new class, odak.learn.wave.propagate_arbitrary. In addition, a unit test, test/test_learn_propagate_arbitrary.py, has to be added to adopt this new class. To add these to odak, you can rely on the pull request feature on GitHub. You can also create a new engineering note for arbitrary surfaces in docs/notes/beam_propagation_arbitrary_surfaces.md.

"},{"location":"course/computer_generated_holography/#optimizing-holograms","title":"Optimizing holograms","text":"

Informative \u00b7 Practical

In the previous subsection, we propagated an input field (a.k.a. lightfield, hologram) to another plane called the image plane. We can store any scene or object as a field on such planes. Thus, we have learned that we can use a plane (hologram) to capture or display a slice of a lightfield for any given scene or object. After all this introduction, it is also safe to say that, regardless of hardware, holograms are the natural way to represent three-dimensional scenes, objects, and data!

Holograms come in many forms. We can broadly classify holograms as analog and digital. Analog holograms are physically tailored structures. They are typically a result of manufacturing engineered surfaces (micron or nanoscale structures). Some examples of analog holograms include diffractive optical elements 13, holographic optical elements 14, and metasurfaces 15. Here, we show an example of an analog hologram that gives us a slice of a lightfield, and we can observe the scene this way from various perspectives:

A video showing analog hologram example from Zebra Imaging -- ZScape.

Digital holograms are dynamic and are generated using programmable versions of analog holograms. Typically, the smallest building block of a digital hologram is a pixel that manipulates either the phase or the amplitude of light. In our laboratory, we build holographic displays 1612, programmable devices for displaying holograms. The components used in such a display are illustrated in the following rendering and contain a Spatial Light Modulator (SLM) that can display programmable holograms. Note that the SLM in this specific hardware can only manipulate the phase of incoming light.

A rendering showing a standard holographic display hardware.

We can display holograms that generate images filling a three-dimensional volume using the above hardware. We know that these images are three-dimensional from the fact that we can focus on different parts of them by changing the focus of our camera (closely observe the camera's location in the above figure). Let us look into a sample result to see what these three-dimensional images look like as we focus on different parts of the scene.

A series of photographs at various focuses capturing images from our computer-generated holograms.

Let us look into how we can optimize a hologram for our holographic display by visiting the below example:

test_learn_wave_stochastic_gradient_descent.py
import sys\nimport odak\nimport torch\n\n\ndef test(output_directory = 'test_output'):\n    odak.tools.check_directory(output_directory)\n    device = torch.device('cpu') # (1)\n    target = odak.learn.tools.load_image('./test/data/usaf1951.png', normalizeby = 255., torch_style = True)[1] # (4)\n    hologram, reconstruction = odak.learn.wave.stochastic_gradient_descent(\n                                                                           target,\n                                                                           wavelength = 532e-9,\n                                                                           distance = 20e-2,\n                                                                           pixel_pitch = 8e-6,\n                                                                           propagation_type = 'Bandlimited Angular Spectrum',\n                                                                           n_iteration = 50,\n                                                                           learning_rate = 0.1\n                                                                          ) # (2)\n    odak.learn.tools.save_image(\n                                '{}/phase.png'.format(output_directory), \n                                odak.learn.wave.calculate_phase(hologram) % (2 * odak.pi), \n                                cmin = 0., \n                                cmax = 2 * odak.pi\n                               ) # (3)\n    odak.learn.tools.save_image('{}/sgd_target.png'.format(output_directory), target, cmin = 0., cmax = 1.)\n    odak.learn.tools.save_image(\n                                '{}/sgd_reconstruction.png'.format(output_directory), \n                                odak.learn.wave.calculate_amplitude(reconstruction) ** 2, \n                                cmin = 0., \n                                cmax = 1.\n                               )\n    assert True == True\n\n\nif __name__ == '__main__':\n    sys.exit(test())\n
  1. Replace cpu with cuda if you have an NVIDIA GPU with enough memory, or an AMD GPU with enough memory and ROCm support (see the snippet after this list for picking the device automatically).
  2. We will provide the details of this optimization function in the next part.
  3. Saving the phase-only hologram. Note that a phase-only hologram is between zero and two pi.
  4. Loading an image from a file with 1920 by 1080 resolution and using its green channel.
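As a small convenience related to item 1 in the list above, you can also let PyTorch pick the device at runtime instead of hardcoding it; this is plain PyTorch and not specific to odak:

import torch\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Falls back to CPU when no compatible GPU is visible.\n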

The above sample optimization script uses a function called odak.learn.wave.stochastic_gradient_descent. This function sits at the center of this optimization, and we have to understand what it entails by closely observing its inputs, outputs, and source code. Let us review the function.

odak.learn.wave.stochastic_gradient_descent

Definition to generate phase and reconstruction from target image via stochastic gradient descent.

Parameters:

  • target \u2013
                        Target field amplitude [m x n].\n                    Keep the target values between zero and one.\n
  • wavelength \u2013
                        Wavelength of light in meters.\n
  • distance \u2013
                        Hologram plane distance wrt SLM plane.\n
  • pixel_pitch \u2013
                        SLM pixel pitch in meters.\n
  • propagation_type \u2013
                        Type of the propagation (see odak.learn.wave.propagate_beam()).\n
  • n_iteration \u2013
                        Number of iteration.\n
  • loss_function \u2013
                        If none it is set to be l2 loss.\n
  • learning_rate \u2013
                        Learning rate.\n

Returns:

  • hologram ( Tensor ) \u2013

    Phase only hologram as torch array

  • reconstruction_intensity ( Tensor ) \u2013

    Reconstruction as torch array

Source code in odak/learn/wave/classical.py
def stochastic_gradient_descent(target, wavelength, distance, pixel_pitch, propagation_type = 'Bandlimited Angular Spectrum', n_iteration = 100, loss_function = None, learning_rate = 0.1):\n    \"\"\"\n    Definition to generate phase and reconstruction from target image via stochastic gradient descent.\n\n    Parameters\n    ----------\n    target                    : torch.Tensor\n                                Target field amplitude [m x n].\n                                Keep the target values between zero and one.\n    wavelength                : double\n                                Set if the converted array requires gradient.\n    distance                  : double\n                                Hologram plane distance wrt SLM plane.\n    pixel_pitch               : float\n                                SLM pixel pitch in meters.\n    propagation_type          : str\n                                Type of the propagation (see odak.learn.wave.propagate_beam()).\n    n_iteration:              : int\n                                Number of iteration.\n    loss_function:            : function\n                                If none it is set to be l2 loss.\n    learning_rate             : float\n                                Learning rate.\n\n    Returns\n    -------\n    hologram                  : torch.Tensor\n                                Phase only hologram as torch array\n\n    reconstruction_intensity  : torch.Tensor\n                                Reconstruction as torch array\n\n    \"\"\"\n    phase = torch.randn_like(target, requires_grad = True)\n    k = wavenumber(wavelength)\n    optimizer = torch.optim.Adam([phase], lr = learning_rate)\n    if type(loss_function) == type(None):\n        loss_function = torch.nn.MSELoss()\n    t = tqdm(range(n_iteration), leave = False, dynamic_ncols = True)\n    for i in t:\n        optimizer.zero_grad()\n        hologram = generate_complex_field(1., phase)\n        reconstruction = propagate_beam(\n                                        hologram, \n                                        k, \n                                        distance, \n                                        pixel_pitch, \n                                        wavelength, \n                                        propagation_type, \n                                        zero_padding = [True, False, True]\n                                       )\n        reconstruction_intensity = calculate_amplitude(reconstruction) ** 2\n        loss = loss_function(reconstruction_intensity, target)\n        description = \"Loss:{:.4f}\".format(loss.item())\n        loss.backward(retain_graph = True)\n        optimizer.step()\n        t.set_description(description)\n    logging.warning(description)\n    torch.no_grad()\n    hologram = generate_complex_field(1., phase)\n    reconstruction = propagate_beam(\n                                    hologram, \n                                    k, \n                                    distance, \n                                    pixel_pitch, \n                                    wavelength, \n                                    propagation_type, \n                                    zero_padding = [True, False, True]\n                                   )\n    return hologram, reconstruction\n
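Since the source code above calls loss_function(reconstruction_intensity, target), you can pass any callable with that signature to steer the optimization toward a different objective. The weighted loss below is a made-up illustration, not a recommended setting:

import torch\nimport odak\n\n\ndef weighted_l2(reconstruction_intensity, target):\n    # Emphasize bright regions of the target slightly more than dark ones (illustrative choice).\n    weights = 1. + target\n    return torch.mean(weights * (reconstruction_intensity - target) ** 2)\n\n\ntarget = odak.learn.tools.load_image('./test/data/usaf1951.png', normalizeby = 255., torch_style = True)[1]\nhologram, reconstruction = odak.learn.wave.stochastic_gradient_descent(\n                                                                       target,\n                                                                       wavelength = 532e-9,\n                                                                       distance = 20e-2,\n                                                                       pixel_pitch = 8e-6,\n                                                                       propagation_type = 'Bandlimited Angular Spectrum',\n                                                                       n_iteration = 50,\n                                                                       learning_rate = 0.1,\n                                                                       loss_function = weighted_l2\n                                                                      )\n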

Let us also examine the optimized hologram and the image that the hologram reconstructed at the image plane.

Optimized phase-only hologram. Generated using \"test/test_learn_wave_stochastic_gradient_descent.py\".

Optimized phase-only hologram reconstructed at the image plane, generated using \"test/test_learn_wave_stochastic_gradient_descent.py\". Challenge: Non-iterative Learned Hologram Calculation

We provided an overview of optimizing holograms using iterative methods. Iterative methods are computationally expensive and unsuitable for real-time hologram generation. We challenge our readers to derive a learned hologram generation method for multiplane images (not single-plane like in our example); a single-plane starting point is sketched below. This development could either rely on classical convolutional neural networks or blend with the physical priors explained in this section. The resultant method could be part of the odak.learn.wave submodule as a new class, odak.learn.wave.learned_hologram. In addition, a unit test test/test_learn_hologram.py has to adopt this new class. To add these to odak, you can rely on the pull request feature on GitHub. You can also create a new engineering note for learned hologram generation in docs/notes/learned_hologram_generation.md.
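As a starting point for this challenge, here is a hedged, single-plane sketch of the idea: a tiny convolutional network predicts a phase map from a target image, and the differentiable odak propagation functions shown earlier act as the physical prior inside the training loss. The network architecture, its size, and the training loop are made up purely for illustration and will need to be extended to multiplane targets:

import torch\nimport odak\n\n\nclass learned_hologram(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.network = torch.nn.Sequential(\n                                           torch.nn.Conv2d(1, 16, kernel_size = 3, padding = 1),\n                                           torch.nn.ReLU(),\n                                           torch.nn.Conv2d(16, 16, kernel_size = 3, padding = 1),\n                                           torch.nn.ReLU(),\n                                           torch.nn.Conv2d(16, 1, kernel_size = 3, padding = 1)\n                                          )\n\n\n    def forward(self, target):\n        # Map a target image [1 x 1 x m x n] to a phase map scaled to the [0, 2 pi) range.\n        return torch.sigmoid(self.network(target)) * 2 * odak.pi\n\n\ndef train_step(model, optimizer, target, k, distance, pixel_pitch, wavelength):\n    optimizer.zero_grad()\n    phase = model(target.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)\n    hologram = odak.learn.wave.generate_complex_field(1., phase)\n    reconstruction = odak.learn.wave.propagate_beam(\n                                                    hologram,\n                                                    k,\n                                                    distance,\n                                                    pixel_pitch,\n                                                    wavelength,\n                                                    'Bandlimited Angular Spectrum',\n                                                    zero_padding = [True, False, True]\n                                                   )\n    loss = torch.nn.functional.mse_loss(odak.learn.wave.calculate_amplitude(reconstruction) ** 2, target)\n    loss.backward()\n    optimizer.step()\n    return loss.item()\n

Once such a model generalizes over a dataset of target images, generating a hologram becomes a single forward pass rather than an iterative optimization.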

"},{"location":"course/computer_generated_holography/#simulating-a-standard-holographic-display","title":"Simulating a standard holographic display","text":"

Informative \u00b7 Practical

We optimized holograms for a holographic display in the previous section. However, the beam propagation distance we used in our optimization example was large. If we were to run the same optimization for a shorter propagation distance, say millimeters instead of centimeters, we would not get a decent solution. This is because, in an actual holographic display, there is an aperture that helps to filter out some of the light. The previous section contained an optical layout rendering of a holographic display, where this aperture is also depicted. As depicted in that rendering, this aperture is located between a two-lens system, which is also known as a 4F imaging system.

Did you know?

A 4F imaging system can take the Fourier transform of an input field by using physics rather than computers. For more details, please review these course notes from MIT.
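Computationally, the effect of placing an aperture at the Fourier plane of a 4F system is often approximated as a low-pass filter: take the Fourier transform of the field, multiply it with an aperture mask, and transform back. The sketch below uses plain torch FFTs, and the circular aperture radius is an arbitrary illustrative choice; odak's propagator class applies its own aperture internally (see the aperture argument in the __call__ source further below), so this is only meant to convey the idea:

import torch\n\n\ndef fourier_plane_filter(field, aperture_radius_ratio = 0.25):\n    # Build a centered circular aperture in normalized frequency coordinates.\n    m, n = field.shape[-2:]\n    y = torch.linspace(-1., 1., m).view(-1, 1)\n    x = torch.linspace(-1., 1., n).view(1, -1)\n    aperture = ((x ** 2 + y ** 2) <= aperture_radius_ratio ** 2).to(field.dtype)\n    # Forward Fourier transform (first lens), mask (aperture), inverse transform (second lens).\n    field_fourier = torch.fft.fftshift(torch.fft.fft2(field))\n    filtered = field_fourier * aperture\n    return torch.fft.ifft2(torch.fft.ifftshift(filtered))\n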

Let us review the class dedicated to accurately simulating a holographic display and its functions:

odak.learn.wave.propagator.reconstruct odak.learn.wave.propagator.__call__

Internal function to reconstruct a given hologram.

Parameters:

  • hologram_phases \u2013
                         Hologram phases [ch x m x n].\n
  • amplitude \u2013
                         Amplitude profiles for each color primary [ch x m x n]\n
  • no_grad \u2013
                         If set True, uses torch.no_grad in reconstruction.\n
  • get_complex \u2013
                         If set True, reconstructor returns the complex field but not the intensities.\n

Returns:

  • reconstructions ( tensor ) \u2013

    Reconstructed frames.

Source code in odak/learn/wave/propagators.py
def reconstruct(self, hologram_phases, amplitude = None, no_grad = True, get_complex = False):\n    \"\"\"\n    Internal function to reconstruct a given hologram.\n\n\n    Parameters\n    ----------\n    hologram_phases            : torch.tensor\n                                 Hologram phases [ch x m x n].\n    amplitude                  : torch.tensor\n                                 Amplitude profiles for each color primary [ch x m x n]\n    no_grad                    : bool\n                                 If set True, uses torch.no_grad in reconstruction.\n    get_complex                : bool\n                                 If set True, reconstructor returns the complex field but not the intensities.\n\n    Returns\n    -------\n    reconstructions            : torch.tensor\n                                 Reconstructed frames.\n    \"\"\"\n    if no_grad:\n        torch.no_grad()\n    if len(hologram_phases.shape) > 3:\n        hologram_phases = hologram_phases.squeeze(0)\n    if get_complex == True:\n        reconstruction_type = torch.complex64\n    else:\n        reconstruction_type = torch.float32\n    reconstructions = torch.zeros(\n                                  self.number_of_frames,\n                                  self.number_of_depth_layers,\n                                  self.number_of_channels,\n                                  self.resolution[0] * self.resolution_factor,\n                                  self.resolution[1] * self.resolution_factor,\n                                  dtype = reconstruction_type,\n                                  device = self.device\n                                 )\n    if isinstance(amplitude, type(None)):\n        amplitude = torch.zeros(\n                                self.number_of_channels,\n                                self.resolution[0] * self.resolution_factor,\n                                self.resolution[1] * self.resolution_factor,\n                                device = self.device\n                               )\n        amplitude[:, ::self.resolution_factor, ::self.resolution_factor] = 1.\n    if self.resolution_factor != 1:\n        hologram_phases_scaled = torch.zeros_like(amplitude)\n        hologram_phases_scaled[\n                               :,\n                               ::self.resolution_factor,\n                               ::self.resolution_factor\n                              ] = hologram_phases\n    else:\n        hologram_phases_scaled = hologram_phases\n    for frame_id in range(self.number_of_frames):\n        for depth_id in range(self.number_of_depth_layers):\n            for channel_id in range(self.number_of_channels):\n                laser_power = self.get_laser_powers()[frame_id][channel_id]\n                phase = hologram_phases_scaled[frame_id]\n                hologram = generate_complex_field(\n                                                  laser_power * amplitude[channel_id],\n                                                  phase * self.phase_scale[channel_id]\n                                                 )\n                reconstruction_field = self.__call__(hologram, channel_id, depth_id)\n                if get_complex == True:\n                    result = reconstruction_field\n                else:\n                    result = calculate_amplitude(reconstruction_field) ** 2\n                reconstructions[\n                                frame_id,\n                                depth_id,\n                                channel_id\n    
                           ] = result.detach().clone()\n    return reconstructions\n

Function that represents the forward model in hologram optimization.

Parameters:

  • input_field \u2013
                  Input complex input field.\n
  • channel_id \u2013
                  Identifying the color primary to be used.\n
  • depth_id \u2013
                  Identifying the depth layer to be used.\n

Returns:

  • output_field ( tensor ) \u2013

    Propagated output complex field.

Source code in odak/learn/wave/propagators.py
def __call__(self, input_field, channel_id, depth_id):\n    \"\"\"\n    Function that represents the forward model in hologram optimization.\n\n    Parameters\n    ----------\n    input_field         : torch.tensor\n                          Input complex input field.\n    channel_id          : int\n                          Identifying the color primary to be used.\n    depth_id            : int\n                          Identifying the depth layer to be used.\n\n    Returns\n    -------\n    output_field        : torch.tensor\n                          Propagated output complex field.\n    \"\"\"\n    distance = self.distances[depth_id]\n    if not self.generated_kernels[depth_id, channel_id]:\n        if self.propagator_type == 'forward':\n            H = get_propagation_kernel(\n                                       nu = self.resolution[0] * 2,\n                                       nv = self.resolution[1] * 2,\n                                       dx = self.pixel_pitch,\n                                       wavelength = self.wavelengths[channel_id],\n                                       distance = distance,\n                                       device = self.device,\n                                       propagation_type = self.propagation_type,\n                                       samples = self.aperture_samples,\n                                       scale = self.resolution_factor\n                                      )\n        elif self.propagator_type == 'back and forth':\n            H_forward = get_propagation_kernel(\n                                               nu = self.resolution[0] * 2,\n                                               nv = self.resolution[1] * 2,\n                                               dx = self.pixel_pitch,\n                                               wavelength = self.wavelengths[channel_id],\n                                               distance = self.zero_mode_distance,\n                                               device = self.device,\n                                               propagation_type = self.propagation_type,\n                                               samples = self.aperture_samples,\n                                               scale = self.resolution_factor\n                                              )\n            distance_back = -(self.zero_mode_distance + self.image_location_offset - distance)\n            H_back = get_propagation_kernel(\n                                            nu = self.resolution[0] * 2,\n                                            nv = self.resolution[1] * 2,\n                                            dx = self.pixel_pitch,\n                                            wavelength = self.wavelengths[channel_id],\n                                            distance = distance_back,\n                                            device = self.device,\n                                            propagation_type = self.propagation_type,\n                                            samples = self.aperture_samples,\n                                            scale = self.resolution_factor\n                                           )\n            H = H_forward * H_back\n        self.kernels[depth_id, channel_id] = H\n        self.generated_kernels[depth_id, channel_id] = True\n    else:\n        H = self.kernels[depth_id, channel_id].detach().clone()\n    field_scale = input_field\n    field_scale_padded = zero_pad(field_scale)\n    output_field_padded = 
custom(field_scale_padded, H, aperture = self.aperture)\n    output_field = crop_center(output_field_padded)\n    return output_field\n

This sample unit test provides an example use case of the holographic display class.

test_learn_wave_holographic_display.py
\n
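As a rough idea of how such a propagator class could be driven, here is a heavily hedged sketch. The constructor keyword arguments below (resolution, wavelengths, pixel_pitch, number_of_frames, number_of_depth_layers, device) and the hologram phase shape are assumptions inferred from the attributes used in the source code above, and they may not match the exact odak.learn.wave.propagator API; please consult the actual unit test in the odak repository before relying on it:

import torch\nimport odak\n\n\ndevice = torch.device('cpu')\n# The keyword arguments below are assumptions -- verify them against odak.learn.wave.propagator before use.\nholographic_display = odak.learn.wave.propagator(\n                                                 resolution = [1080, 1920],\n                                                 wavelengths = [639e-9, 515e-9, 473e-9],\n                                                 pixel_pitch = 8e-6,\n                                                 number_of_frames = 3,\n                                                 number_of_depth_layers = 2,\n                                                 device = device\n                                                )\nhologram_phases = torch.rand(3, 1080, 1920, device = device) * 2 * odak.pi # Placeholder phases, [ch x m x n] as in the reconstruct docstring.\nreconstructions = holographic_display.reconstruct(hologram_phases)\nprint(reconstructions.shape) # Expected: [frames x depth layers x channels x m x n], following the reconstruct source above.\n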

Let us also examine what the reconstructed images look like at the image plane.

Reconstructed phase-only hologram at two image planes, generated using \"test/test_learn_wave_holographic_display.py\".

You may also be curious about how these holograms would look in an actual holographic display. Here, we provide a sample gallery filled with photographs captured from our holographic display:

Photographs of holograms captured using the holographic display in Computational Light Laboratory"},{"location":"course/computer_generated_holography/#conclusion","title":"Conclusion","text":"

Informative

Holography offers new frontiers as an emerging method for simulating light in various applications, including displays and cameras. We provided a basic introduction to Computer-Generated Holography and a simple understanding of holographic methods. A motivated reader could scale up from this knowledge to advanced concepts in displays, cameras, visual perception, optical computing, and many other light-based applications.

Reminder

We host a Slack group with more than 300 members. This Slack group focuses on the topics of rendering, perception, displays, and cameras. The group is open to the public, and you can become a member by following this link. Readers can get in touch with the wider community using this public group.

  1. Max Born and Emil Wolf. Principles of optics: electromagnetic theory of propagation, interference and diffraction of light. Elsevier, 2013.\u00a0\u21a9

  2. Joseph W Goodman. Introduction to Fourier optics. Roberts and Company publishers, 2005.\u00a0\u21a9

  3. Koray Kavakli, David Robert Walton, Nick Antipa, Rafa\u0142 Mantiuk, Douglas Lanman, and Kaan Ak\u015fit. Optimizing vision and visuals: lectures on cameras, displays and perception. In ACM SIGGRAPH 2022 Courses, pages 1\u201366. 2022.\u00a0\u21a9

  4. Jason D Schmidt. Numerical simulation of optical wave propagation with examples in matlab. SPIE Press, 2010.\u00a0\u21a9

  5. John C Heurtley. Scalar rayleigh\u2013sommerfeld and kirchhoff diffraction integrals: a comparison of exact evaluations for axial points. JOSA, 63(8):1003\u20131008, 1973.\u00a0\u21a9

  6. Maciej Sypek. Light propagation in the fresnel region. new numerical approach. Optics communications, 116(1-3):43\u201348, 1995.\u00a0\u21a9

  7. Kyoji Matsushima and Tomoyoshi Shimobaba. Band-limited angular spectrum method for numerical simulation of free-space propagation in far and near fields. Optics express, 17(22):19662\u201319673, 2009.\u00a0\u21a9

  8. Wenhui Zhang, Hao Zhang, and Guofan Jin. Band-extended angular spectrum method for accurate diffraction calculation in a wide propagation range. Optics letters, 45(6):1543\u20131546, 2020.\u00a0\u21a9

  9. Wenhui Zhang, Hao Zhang, and Guofan Jin. Adaptive-sampling angular spectrum method with full utilization of space-bandwidth product. Optics Letters, 45(16):4416\u20134419, 2020.\u00a0\u21a9

  10. Yifan Peng, Suyeon Choi, Nitish Padmanaban, and Gordon Wetzstein. Neural holography with camera-in-the-loop training. ACM Transactions on Graphics (TOG), 39(6):1\u201314, 2020.\u00a0\u21a9

  11. Praneeth Chakravarthula, Ethan Tseng, Tarun Srivastava, Henry Fuchs, and Felix Heide. Learned hardware-in-the-loop phase retrieval for holographic near-eye displays. ACM Transactions on Graphics (TOG), 39(6):1\u201318, 2020.\u00a0\u21a9

  12. Koray Kavakl\u0131, Hakan Urey, and Kaan Ak\u015fit. Learned holographic light transport. Applied Optics, 61(5):B50\u2013B55, 2022.\u00a0\u21a9\u21a9

  13. Gary J Swanson. Binary optics technology: the theory and design of multi-level diffractive optical elements. Technical Report, MIT Lincoln Laboratory, Lexington, MA, 1989.\u00a0\u21a9

  14. Herwig Kogelnik. Coupled wave theory for thick hologram gratings. Bell System Technical Journal, 48(9):2909\u20132947, 1969.\u00a0\u21a9

  15. Lingling Huang, Shuang Zhang, and Thomas Zentgraf. Metasurface holography: from fundamentals to applications. Nanophotonics, 7(6):1169\u20131190, 2018.\u00a0\u21a9

  16. Koray Kavakl\u0131, Yuta Itoh, Hakan Urey, and Kaan Ak\u015fit. Realistic defocus blur for multiplane computer-generated holography. In 2023 IEEE Conference Virtual Reality and 3D User Interfaces (VR), 418\u2013426. IEEE, 2023.\u00a0\u21a9

"},{"location":"course/fundamentals/","title":"Fundamentals in optimizing and learning light","text":"Narrate section"},{"location":"course/fundamentals/#fundamentals-and-standards","title":"Fundamentals and Standards","text":"

This chapter will reveal some important basic information you will use in the rest of this course. In addition, we will also introduce you to a structure where we establish some standards to decrease the chances of producing buggy or incompatible code.

"},{"location":"course/fundamentals/#required-production-environment","title":"Required Production Environment","text":"

Informative \u00b7 Practical

We have provided some information in prerequisites. This information includes programming language requirements, required libraries, text editors, build environments, and operating system requirements. For installing our library, odak, we strongly advise using the version in the source repository. You can install odak from the source repository using your favorite terminal and operating system:

pip3 install git+https://github.com/kaanaksit/odak\n

Note that your production environment, meaning your computer and the required software for this course, is important. To avoid wasting time in the next chapters and to get the most out of this lecture, please ensure that you have dedicated enough time to set everything up as it should be.

"},{"location":"course/fundamentals/#production-standards","title":"Production Standards","text":"

Informative

In this course, you will be asked to code and implement simulations related to the physics of light. Your work, meaning your production, should strictly follow certain habits to help build better tools and developments.

"},{"location":"course/fundamentals/#subversion-and-revision-control","title":"Subversion and Revision Control","text":"

Informative \u00b7 Practical

As you develop your code for your future homework and projects, you will discover that many things could go wrong. For example, the hard drive that contains the only copy of your code could be damaged, or your most trusted friend (so-called) could claim that she did most of the work and take the credit for it, although that is not the case. These are just a few potential cases that may happen to you. On the other hand, in business life, poor code control can cause companies to lose money by releasing incorrect code, or researchers to lose their reputations as their work is challenging to replicate. How do you claim in that case that you did your part? What is the proper method to avoid losing data, time, effort, and motivation? In short, what is the way to stay out of trouble?

This is where subversion, authoring, and revision control systems come into play, especially for the example cases discussed in the previous paragraph. In today's world, Git is a widespread version control system adopted by major websites such as GitHub or GitLab. We will not dive deep into how to use Git and all its features, but I will try to highlight the parts that are essential for your workflow. I encourage you to use Git for creating a repository for every one of your tasks in the future. You can either keep this repository on your local machine and constantly back it up somewhere else (suggested for people who know what they are doing) or use online services such as GitHub or GitLab. I encourage you to use the online services if you are a beginner.

For each operating system, installing Git has its own process, but on an Ubuntu operating system, it is as easy as typing the following command in your terminal:

sudo apt install git\n

Let us imagine that you want to start a repository on GitHub. Make sure to create a private repository, and please only go public with any repository once you feel it is at a state where it can be shared with others. Once you have created your repository on GitHub, you can clone the repository using the following command in a terminal:

git clone REPLACEWITHLOCATIONOFREPO\n

You can find out about the repository's location by visiting the repository's website that you have created. The location is typically revealed by clicking the code button, as depicted in the below screenshot.

A screenshot showing how you can acquire the link for cloning a repository from GitHub.

For example, in the above case, the command should be updated with the following:

git clone https://github.com/kaanaksit/odak.git\n

If you want to share your private repository with someone, you can go into the settings of your repository on its webpage and navigate to the collaborators section. This way, you can assign roles to your collaborators that best suit your scenario.

Secure your account

If you are using GitHub for your development, I highly encourage you to consider using two-factor authentication.

"},{"location":"course/fundamentals/#git-basics","title":"Git Basics","text":"

Informative \u00b7 Practical

If you want to add new files to your subversion control system, use the following in a terminal:

git add YOURFILE.jpeg\n

You may want to track the status of the files (whether they are added, deleted, etc.)

git status\n
And later, you can update the online copy (remote server or source) using the following:

git commit -am \"Explain what you add in a short comment.\"\ngit push\n

In some cases, you may want to include large binary files in your project, such as a paper, video, or any other media you want to archive within your project repository. For those cases, using plain Git may not be the best option: Git works by creating a history of files and how they change at each commit, and with large binary files, this history will likely become too bulky and oversized. Thus, cloning a repository could be slow when large binary files and Git come together. Assuming you are on an Ubuntu operating system, you can install Large File Storage (LFS) support for Git by typing this command in your terminal:

sudo apt install git-lfs\n

Once you have the LFS installed in your operating system, you can then go into your repository and enable LFS:

cd YOURREPOSITORY\ngit lfs install\n

Now is the time to let your LFS track specific files to avoid overcrowding your Git history. For example, you can track the *.pdf extension, meaning all the PDF files in your repository by typing the following command in your terminal:

git lfs track *.pdf\n

Finally, ensure the tracking information and LFS are copied to your remote/source repository. You can do that using the following commands in your terminal:

git add .gitattributes\ngit commit -am \"Enabling large file support.\"\ngit push\n

When projects expand in size, it's quite feasible for hundreds of individuals to collaborate within the same repository. This is particularly prevalent in sizable software development initiatives or open-source projects with a substantial contributor base. The branching system is frequently employed in these circumstances.

Consider that you are on a software development team and you want to introduce new features or changes to a project without affecting the main or \"master\" branch. You first need to create a new branch using the following command, which creates a new branch named BRANCHNAME but does not switch to it. This new branch has the same contents as the current branch (a copy of the current branch).

git branch BRANCHNAME\n

Then you can switch to the new branch by using the command:

git checkout BRANCHNAME\n

Or use this command to create and switch to a new branch immediately:

git checkout -b BRANCHNAME\n

After editing the new branch, you may want to merge your changes back into the master or main branch. The following command merges the branch named BRANCHNAME into the current branch; you must resolve any conflicts to complete the merge (see the short recipe after the command).

git merge BRANCHNAME\n
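If Git reports conflicts during the merge, a typical resolution flow looks like the following; CONFLICTEDFILE is a placeholder for whichever file Git lists as conflicted:

git status\n# Open the files marked as 'both modified', resolve the conflict markers (<<<<<<<, =======, >>>>>>>), then:\ngit add CONFLICTEDFILE\ngit commit\n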

We recommend an interactive, visual method for learning Git commands and branching online: learngitbranching. More information can be found in the official Git documentation: Git docs.

"},{"location":"course/fundamentals/#coding-standards","title":"Coding Standards","text":"

Informative \u00b7 Practical

I encourage our readers to follow the methods of coding highlighted here. Following the methods that I am going to explain is not only crucial for developing replicable projects, but it is also vital for allowing other people to read your code with the least amount of hassle.

Where do I find out more about Python coding standards?

Python Enhancement Proposals documentation provides a great deal of information on modern ways to code in Python.

"},{"location":"course/fundamentals/#avoid-using-long-lines","title":"Avoid using long lines.","text":"

Please avoid having too many characters in one line. Let us start with a bad example:

def light_transport(wavelength, distances, resolution, propagation_type, polarization, input_field, output_field, angles):\n      pass\n      return results\n

As you can observe, the above function requires multiple inputs to be provided. Try making the inputs more readable by breaking lines; in some cases, you can also provide the expected type for an input and a default value to guide your users:

def light_transport(\n                    wavelength,\n                    distances,\n                    resolution,\n                    propagation_type : str, \n                    polarization = 'vertical',\n                    input_field = torch.rand(1, 1, 100, 100),\n                    output_field = torch.zeros(1, 1, 100, 100),\n                    angles = [0., 0., 0.]\n                   ):\n    pass\n    return results\n
"},{"location":"course/fundamentals/#leave-spaces-between-commands-variables-and-functions","title":"Leave spaces between commands, variables, and functions","text":"

Please avoid writing code like a train of characters. Here is a terrible coding example:

def addition(x,y,z):\n    result=2*y+z+x**2*3\n    return result\n

Please leave a space after each comma, ,, and around each mathematical operator. So now, we can correct the above example as below:

def addition(x, y, z):\n    result = 2 * y + z + x ** 2 * 3\n    return result\n

Please also leave two blank lines between functions. Here is a bad example again:

def add(x, y):\n    return x + y\ndef multiply(x, y):\n    return x * y\n

Instead, it should be:

def add(x, y):\n    return x + y\n\n\ndef multiply(x, y):\n    return x * y\n
"},{"location":"course/fundamentals/#add-documentation","title":"Add documentation","text":"

For your code, please make sure to add the necessary documentation. Here is a good example of doing that:

def add(x, y):\n    \"\"\"\n    A function to add two values together.\n\n    Parameters\n    ==========\n    x         : float\n                First input value.\n    y         : float\n                Second input value.\n\n    Returns\n    =======\n    result    : float\n                Result of the addition.\n    \"\"\"\n    result = x + y\n    return result\n
"},{"location":"course/fundamentals/#use-a-code-style-checker-and-validator","title":"Use a code-style checker and validator","text":"

There are also code-style checkers and code validators that you can adopt in your workflows when coding. One of the code-style checkers and validators I use in my projects is pyflakes. On an Ubuntu operating system, you can install pyflakes easily by typing this command into your terminal:

sudo apt install python3-pyflakes\n

It could tell you about missing imports or undefined or unused variables. You can use it on any Python script very easily:

pyflakes3 sample.py\n

In addition, I use flake8 and autopep8 for standard code violations. To learn more about these, please read the code section of the contribution guide.

"},{"location":"course/fundamentals/#naming-variables","title":"Naming variables","text":"

When naming variables, use lowercase letters and make sure that the variables are named in an explanatory manner. Please also always use underscores as a replacement for spaces. For example, if you are going to create a variable for storing the reconstructed images at some image planes, you can name that variable reconstructions_image_planes (see the short example below).
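As a small illustration (the tensor shape here is arbitrary), compare a cryptic name with an explanatory one:

import torch\n\n\nri = torch.zeros(3, 1080, 1920) # Hard to guess what this tensor stores.\nreconstructions_image_planes = torch.zeros(3, 1080, 1920) # Explanatory: reconstructed intensities at image planes.\n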

"},{"location":"course/fundamentals/#use-fewer-imports","title":"Use fewer imports","text":"

When it comes to importing libraries in your code, please make sure to use a minimal number of libraries. Using fewer libraries can help you keep your code robust and working across newer generations of those libraries. Please stick to the libraries suggested in this course when coding for this course. If you need access to some other library, please do let us know!

"},{"location":"course/fundamentals/#fixing-bugs","title":"Fixing bugs","text":"

Often, you will encounter bugs in your code. To fix your code in such cases, I would like you to consider using a method called Rubber duck debugging or Rubber ducking. The basic idea is to explain your code to a third person, or to yourself, line by line. Explaining the code line by line could help you see what is wrong with it. I am sure there are many recipes for solving bugs in code; I tried introducing you to one that works for me.

"},{"location":"course/fundamentals/#have-a-requirementstxt","title":"Have a requirements.txt","text":"

Please also make sure to have a requirements.txt in the root directory of your repository. For example, in this course your requirements.txt would look like this:

odak>=0.2.4\ntorch \n

This way, a future user of your code could install the required libraries by following a simple command in a terminal:

pip3 install -r requirements.txt\n
"},{"location":"course/fundamentals/#always-use-the-same-function-for-saving-and-loading","title":"Always use the same function for saving and loading","text":"

Most issues in every software project come from repetition. Imagine that you want to save and load images inside your code after some processing. If you rely on manually coding a save and load routine in every corner of the same code, it is likely that when you change one of these saving or loading routines, you must modify the others. In other words, do not rediscover what you already know. Instead, turn it into a Lego brick you can use whenever needed. For saving and loading images, please rely on functions in odak to avoid any issues. For example, if I want to load a sample image called letter.jpeg, I can rely on this example:

import odak\nimage = odak.learn.tools.load_image(\n                                    'letter.jpeg',\n                                    torch_style = True, # (1)\n                                    normalizeby = 255. # (2)\n                                   )\n
  1. If you set this flag to True, the image will be loaded as [ch x m x n], where ch represents the number of color channels (e.g., typically three). In case of False, it will be loaded as [m x n x ch].
  2. If you provide a floating-point number here, the image to be loaded will be divided by that number. For example, if you have an 8-bit image (0-255) and you provide normalizeby = 2.0, the maximum value that you can expect is 255 / 2. = 127.5.

Odak also provides a standard method for saving your torch tensors as image files:

odak.learn.tools.save_image(\n                            'copy.png',\n                            image,\n                            cmin = 0., # (1)\n                            cmax = 1., # (2)\n                            color_depth = 8 # (3)\n                           )\n
  1. Minimum expected value for torch tensor image.
  2. Maximum expected value for torch tensor image.
  3. Pixel depth of the image to be saved, default is 8-bit.

In some code development, you may want to try the same code with different settings. In those cases, I create a separate settings folder in the root directory of my projects and add JSON files that I can load for testing different cases. To explain the case better, let us assume we will change the number of light sources in some simulation. Let us first assume that we create a settings file as settings/experiment_000.txt in the root directory and fill it with the following content:

{\n  \"light source\" : {\n                    \"count\" : 5,\n                    \"type\"  : \"LED\"\n                   }\n}\n

In the rest of my code, I can read, modify and save JSON files using odak functions:

import odak\nsettings = odak.tools.load_dictionary('./settings/experiment_000.txt')\nsettings['light source']['count'] = 10\nodak.tools.save_dictionary(settings, './settings/experiment_000.txt')\n

This way, you do not have to memorize the variables you used for every experiment you conducted with the same piece of code. You can have a dedicated settings file for each experiment.

"},{"location":"course/fundamentals/#create-unit-tests","title":"Create unit tests","text":"

Suppose your project is a library containing multiple valuable functions for developing other projects. In that case, I encourage you to create unit tests for your library so that whenever you update it, you can see if your updates break anything in that library. For this purpose, consider creating a test directory in the root folder of your repository. In that directory, you can create separate Python scripts for testing the various functions of your library. Say there is a function called add in your project MOSTAWESOMECODEEVER; your test script test/test_add.py should then look like this:

import sys\nimport MOSTAWESOMECODEEVER\n\n\ndef test():\n    ground_truth = 3 + 5\n    result = MOSTAWESOMECODEEVER.add(3, 5)\n    assert result == ground_truth # Fails with an AssertionError if add() misbehaves.\n\n\nif __name__ == '__main__':\n    sys.exit(test())\n

You may accumulate various unit tests in your test directory. To test them all before pushing them to your repository, you can rely on pytest. You can install pytest using the following command in your terminal:

pip3 install pytest\n

Once installed, you can navigate to your repository's root directory and call pytest to test things out:

cd MOSTAWESOMECODEEVER\npytest\n

If anything is wrong with your unit tests, which validate your functions, pytest will provide a detailed explanation.

"},{"location":"course/fundamentals/#set-a-licence","title":"Set a licence","text":"

If you want to distribute your code online, consider adding a license to avoid difficulties related to sharing with others. In other words, you can add a LICENSE.txt in the root directory of your repository. To determine which license works best for you, consider visiting this guideline. When choosing a license for your project, consider whether you are comfortable with people building a product out of your work or a derivative of it.

Lab work: Prepare a project repository

Please prepare a sample repository on GitHub using the information provided in the above sections. Here are some sample files that may inspire you and help you structure your project in good order:

main.py LICENSE.txt requirements.txt THANKS.txt CODE_OF_CONDUCT.md
import odak\nimport torch\nimport sys\n\n\ndef main():\n    print('your codebase')\n\n\nif __name__ == '__main__':\n    sys.exit(main())\n
LICENSE.txt
Mozilla Public License Version 2.0\n==================================\n\n1. Definitions\n--------------\n\n1.1. \"Contributor\"\n    means each individual or legal entity that creates, contributes to\n    the creation of, or owns Covered Software.\n\n1.2. \"Contributor Version\"\n    means the combination of the Contributions of others (if any) used\n    by a Contributor and that particular Contributor's Contribution.\n\n1.3. \"Contribution\"\n    means Covered Software of a particular Contributor.\n\n1.4. \"Covered Software\"\n    means Source Code Form to which the initial Contributor has attached\n    the notice in Exhibit A, the Executable Form of such Source Code\n    Form, and Modifications of such Source Code Form, in each case\n    including portions thereof.\n\n1.5. \"Incompatible With Secondary Licenses\"\n    means\n\n    (a) that the initial Contributor has attached the notice described\n        in Exhibit B to the Covered Software; or\n\n    (b) that the Covered Software was made available under the terms of\n        version 1.1 or earlier of the License, but not also under the\n        terms of a Secondary License.\n\n1.6. \"Executable Form\"\n    means any form of the work other than Source Code Form.\n\n1.7. \"Larger Work\"\n    means a work that combines Covered Software with other material, in \n    a separate file or files, that is not Covered Software.\n\n1.8. \"License\"\n    means this document.\n\n1.9. \"Licensable\"\n    means having the right to grant, to the maximum extent possible,\n    whether at the time of the initial grant or subsequently, any and\n    all of the rights conveyed by this License.\n\n1.10. \"Modifications\"\n    means any of the following:\n\n    (a) any file in Source Code Form that results from an addition to,\n        deletion from, or modification of the contents of Covered\n        Software; or\n\n    (b) any new file in Source Code Form that contains any Covered\n        Software.\n\n1.11. \"Patent Claims\" of a Contributor\n    means any patent claim(s), including without limitation, method,\n    process, and apparatus claims, in any patent Licensable by such\n    Contributor that would be infringed, but for the grant of the\n    License, by the making, using, selling, offering for sale, having\n    made, import, or transfer of either its Contributions or its\n    Contributor Version.\n\n1.12. \"Secondary License\"\n    means either the GNU General Public License, Version 2.0, the GNU\n    Lesser General Public License, Version 2.1, the GNU Affero General\n    Public License, Version 3.0, or any later versions of those\n    licenses.\n\n1.13. \"Source Code Form\"\n    means the form of the work preferred for making modifications.\n\n1.14. \"You\" (or \"Your\")\n    means an individual or a legal entity exercising rights under this\n    License. For legal entities, \"You\" includes any entity that\n    controls, is controlled by, or is under common control with You. For\n    purposes of this definition, \"control\" means (a) the power, direct\n    or indirect, to cause the direction or management of such entity,\n    whether by contract or otherwise, or (b) ownership of more than\n    fifty percent (50%) of the outstanding shares or beneficial\n    ownership of such entity.\n\n2. License Grants and Conditions\n--------------------------------\n\n2.1. 
Grants\n\nEach Contributor hereby grants You a world-wide, royalty-free,\nnon-exclusive license:\n\n(a) under intellectual property rights (other than patent or trademark)\n    Licensable by such Contributor to use, reproduce, make available,\n    modify, display, perform, distribute, and otherwise exploit its\n    Contributions, either on an unmodified basis, with Modifications, or\n    as part of a Larger Work; and\n\n(b) under Patent Claims of such Contributor to make, use, sell, offer\n    for sale, have made, import, and otherwise transfer either its\n    Contributions or its Contributor Version.\n\n2.2. Effective Date\n\nThe licenses granted in Section 2.1 with respect to any Contribution\nbecome effective for each Contribution on the date the Contributor first\ndistributes such Contribution.\n\n2.3. Limitations on Grant Scope\n\nThe licenses granted in this Section 2 are the only rights granted under\nthis License. No additional rights or licenses will be implied from the\ndistribution or licensing of Covered Software under this License.\nNotwithstanding Section 2.1(b) above, no patent license is granted by a\nContributor:\n\n(a) for any code that a Contributor has removed from Covered Software;\n    or\n\n(b) for infringements caused by: (i) Your and any other third party's\n    modifications of Covered Software, or (ii) the combination of its\n    Contributions with other software (except as part of its Contributor\n    Version); or\n\n(c) under Patent Claims infringed by Covered Software in the absence of\n    its Contributions.\n\nThis License does not grant any rights in the trademarks, service marks,\nor logos of any Contributor (except as may be necessary to comply with\nthe notice requirements in Section 3.4).\n\n2.4. Subsequent Licenses\n\nNo Contributor makes additional grants as a result of Your choice to\ndistribute the Covered Software under a subsequent version of this\nLicense (see Section 10.2) or under the terms of a Secondary License (if\npermitted under the terms of Section 3.3).\n\n2.5. Representation\n\nEach Contributor represents that the Contributor believes its\nContributions are its original creation(s) or it has sufficient rights\nto grant the rights to its Contributions conveyed by this License.\n\n2.6. Fair Use\n\nThis License is not intended to limit any rights You have under\napplicable copyright doctrines of fair use, fair dealing, or other\nequivalents.\n\n2.7. Conditions\n\nSections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted\nin Section 2.1.\n\n3. Responsibilities\n-------------------\n\n3.1. Distribution of Source Form\n\nAll distribution of Covered Software in Source Code Form, including any\nModifications that You create or to which You contribute, must be under\nthe terms of this License. You must inform recipients that the Source\nCode Form of the Covered Software is governed by the terms of this\nLicense, and how they can obtain a copy of this License. You may not\nattempt to alter or restrict the recipients' rights in the Source Code\nForm.\n\n3.2. 
Distribution of Executable Form\n\nIf You distribute Covered Software in Executable Form then:\n\n(a) such Covered Software must also be made available in Source Code\n    Form, as described in Section 3.1, and You must inform recipients of\n    the Executable Form how they can obtain a copy of such Source Code\n    Form by reasonable means in a timely manner, at a charge no more\n    than the cost of distribution to the recipient; and\n\n(b) You may distribute such Executable Form under the terms of this\n    License, or sublicense it under different terms, provided that the\n    license for the Executable Form does not attempt to limit or alter\n    the recipients' rights in the Source Code Form under this License.\n\n3.3. Distribution of a Larger Work\n\nYou may create and distribute a Larger Work under terms of Your choice,\nprovided that You also comply with the requirements of this License for\nthe Covered Software. If the Larger Work is a combination of Covered\nSoftware with a work governed by one or more Secondary Licenses, and the\nCovered Software is not Incompatible With Secondary Licenses, this\nLicense permits You to additionally distribute such Covered Software\nunder the terms of such Secondary License(s), so that the recipient of\nthe Larger Work may, at their option, further distribute the Covered\nSoftware under the terms of either this License or such Secondary\nLicense(s).\n\n3.4. Notices\n\nYou may not remove or alter the substance of any license notices\n(including copyright notices, patent notices, disclaimers of warranty,\nor limitations of liability) contained within the Source Code Form of\nthe Covered Software, except that You may alter any license notices to\nthe extent required to remedy known factual inaccuracies.\n\n3.5. Application of Additional Terms\n\nYou may choose to offer, and to charge a fee for, warranty, support,\nindemnity or liability obligations to one or more recipients of Covered\nSoftware. However, You may do so only on Your own behalf, and not on\nbehalf of any Contributor. You must make it absolutely clear that any\nsuch warranty, support, indemnity, or liability obligation is offered by\nYou alone, and You hereby agree to indemnify every Contributor for any\nliability incurred by such Contributor as a result of warranty, support,\nindemnity or liability terms You offer. You may include additional\ndisclaimers of warranty and limitations of liability specific to any\njurisdiction.\n\n4. Inability to Comply Due to Statute or Regulation\n---------------------------------------------------\n\nIf it is impossible for You to comply with any of the terms of this\nLicense with respect to some or all of the Covered Software due to\nstatute, judicial order, or regulation then You must: (a) comply with\nthe terms of this License to the maximum extent possible; and (b)\ndescribe the limitations and the code they affect. Such description must\nbe placed in a text file included with all distributions of the Covered\nSoftware under this License. Except to the extent prohibited by statute\nor regulation, such description must be sufficiently detailed for a\nrecipient of ordinary skill to be able to understand it.\n\n5. Termination\n--------------\n\n5.1. The rights granted under this License will terminate automatically\nif You fail to comply with any of its terms. 
However, if You become\ncompliant, then the rights granted under this License from a particular\nContributor are reinstated (a) provisionally, unless and until such\nContributor explicitly and finally terminates Your grants, and (b) on an\nongoing basis, if such Contributor fails to notify You of the\nnon-compliance by some reasonable means prior to 60 days after You have\ncome back into compliance. Moreover, Your grants from a particular\nContributor are reinstated on an ongoing basis if such Contributor\nnotifies You of the non-compliance by some reasonable means, this is the\nfirst time You have received notice of non-compliance with this License\nfrom such Contributor, and You become compliant prior to 30 days after\nYour receipt of the notice.\n\n5.2. If You initiate litigation against any entity by asserting a patent\ninfringement claim (excluding declaratory judgment actions,\ncounter-claims, and cross-claims) alleging that a Contributor Version\ndirectly or indirectly infringes any patent, then the rights granted to\nYou by any and all Contributors for the Covered Software under Section\n2.1 of this License shall terminate.\n\n5.3. In the event of termination under Sections 5.1 or 5.2 above, all\nend user license agreements (excluding distributors and resellers) which\nhave been validly granted by You or Your distributors under this License\nprior to termination shall survive termination.\n\n************************************************************************\n*                                                                      *\n*  6. Disclaimer of Warranty                                           *\n*  -------------------------                                           *\n*                                                                      *\n*  Covered Software is provided under this License on an \"as is\"       *\n*  basis, without warranty of any kind, either expressed, implied, or  *\n*  statutory, including, without limitation, warranties that the       *\n*  Covered Software is free of defects, merchantable, fit for a        *\n*  particular purpose or non-infringing. The entire risk as to the     *\n*  quality and performance of the Covered Software is with You.        *\n*  Should any Covered Software prove defective in any respect, You     *\n*  (not any Contributor) assume the cost of any necessary servicing,   *\n*  repair, or correction. This disclaimer of warranty constitutes an   *\n*  essential part of this License. No use of any Covered Software is   *\n*  authorized under this License except under this disclaimer.         *\n*                                                                      *\n************************************************************************\n\n************************************************************************\n*                                                                      *\n*  7. 
Limitation of Liability                                          *\n*  --------------------------                                          *\n*                                                                      *\n*  Under no circumstances and under no legal theory, whether tort      *\n*  (including negligence), contract, or otherwise, shall any           *\n*  Contributor, or anyone who distributes Covered Software as          *\n*  permitted above, be liable to You for any direct, indirect,         *\n*  special, incidental, or consequential damages of any character      *\n*  including, without limitation, damages for lost profits, loss of    *\n*  goodwill, work stoppage, computer failure or malfunction, or any    *\n*  and all other commercial damages or losses, even if such party      *\n*  shall have been informed of the possibility of such damages. This   *\n*  limitation of liability shall not apply to liability for death or   *\n*  personal injury resulting from such party's negligence to the       *\n*  extent applicable law prohibits such limitation. Some               *\n*  jurisdictions do not allow the exclusion or limitation of           *\n*  incidental or consequential damages, so this exclusion and          *\n*  limitation may not apply to You.                                    *\n*                                                                      *\n************************************************************************\n\n8. Litigation\n-------------\n\nAny litigation relating to this License may be brought only in the\ncourts of a jurisdiction where the defendant maintains its principal\nplace of business and such litigation shall be governed by laws of that\njurisdiction, without reference to its conflict-of-law provisions.\nNothing in this Section shall prevent a party's ability to bring\ncross-claims or counter-claims.\n\n9. Miscellaneous\n----------------\n\nThis License represents the complete agreement concerning the subject\nmatter hereof. If any provision of this License is held to be\nunenforceable, such provision shall be reformed only to the extent\nnecessary to make it enforceable. Any law or regulation which provides\nthat the language of a contract shall be construed against the drafter\nshall not be used to construe this License against a Contributor.\n\n10. Versions of the License\n---------------------------\n\n10.1. New Versions\n\nMozilla Foundation is the license steward. Except as provided in Section\n10.3, no one other than the license steward has the right to modify or\npublish new versions of this License. Each version will be given a\ndistinguishing version number.\n\n10.2. Effect of New Versions\n\nYou may distribute the Covered Software under the terms of the version\nof the License under which You originally received the Covered Software,\nor under the terms of any subsequent version published by the license\nsteward.\n\n10.3. Modified Versions\n\nIf you create software not governed by this License, and you want to\ncreate a new license for such software, you may create and use a\nmodified version of this License if you rename the license and remove\nany references to the name of the license steward (except to note that\nsuch modified license differs from this License).\n\n10.4. 
Distributing Source Code Form that is Incompatible With Secondary\nLicenses\n\nIf You choose to distribute Source Code Form that is Incompatible With\nSecondary Licenses under the terms of this version of the License, the\nnotice described in Exhibit B of this License must be attached.\n\nExhibit A - Source Code Form License Notice\n-------------------------------------------\n\n  This Source Code Form is subject to the terms of the Mozilla Public\n  License, v. 2.0. If a copy of the MPL was not distributed with this\n  file, You can obtain one at http://mozilla.org/MPL/2.0/.\n\nIf it is not possible or desirable to put the notice in a particular\nfile, then You may include the notice in a location (such as a LICENSE\nfile in a relevant directory) where a recipient would be likely to look\nfor such a notice.\n\nYou may add additional accurate notices of copyright ownership.\n\nExhibit B - \"Incompatible With Secondary Licenses\" Notice\n---------------------------------------------------------\n\n  This Source Code Form is \"Incompatible With Secondary Licenses\", as\n  defined by the Mozilla Public License, v. 2.0.\n
requirements.txt
opencv-python>=4.10.0.84\nnumpy>=1.26.4\ntorch>=2.3.0\nplyfile>=1.0.3\ntqdm>=4.66.4\n
THANKS.txt
Ahmet Hamdi G\u00fczel\nAhmet Serdar Karadeniz\nDavid Robert Walton\nDavid Santiago Morales Norato\nHenry Kam\nDo\u011fa Y\u0131lmaz\nJeanne Beyazian\nJialun Wu\nJosef Spjut\nKoray Kavakl\u0131\nLiang Shi\nMustafa Do\u011fa Do\u011fan\nPraneeth Chakravarthula\nRunze Zhu\nWeijie Xie\nYujie Wang\nYuta Itoh\nZiyang Chen\nYicheng Zhan\n
CODE_OF_CONDUCT.md
# Contributor Covenant Code of Conduct\n\n## Our Pledge\n\nWe as members, contributors, and leaders pledge to make participation in our\ncommunity a harassment-free experience for everyone, regardless of age, body\nsize, visible or invisible disability, ethnicity, sex characteristics, gender\nidentity and expression, level of experience, education, socio-economic status,\nnationality, personal appearance, race, religion, or sexual identity\nand orientation.\n\nWe pledge to act and interact in ways that contribute to an open, welcoming,\ndiverse, inclusive, and healthy community.\n\n## Our Standards\n\nExamples of behavior that contributes to a positive environment for our\ncommunity include:\n\n* Demonstrating empathy and kindness toward other people\n* Being respectful of differing opinions, viewpoints, and experiences\n* Giving and gracefully accepting constructive feedback\n* Accepting responsibility and apologizing to those affected by our mistakes,\n  and learning from the experience\n* Focusing on what is best not just for us as individuals, but for the\n  overall community\n\nExamples of unacceptable behavior include:\n\n* The use of sexualized language or imagery, and sexual attention or\n  advances of any kind\n* Trolling, insulting or derogatory comments, and personal or political attacks\n* Public or private harassment\n* Publishing others' private information, such as a physical or email\n  address, without their explicit permission\n* Other conduct which could reasonably be considered inappropriate in a\n  professional setting\n\n## Enforcement Responsibilities\n\nCommunity leaders are responsible for clarifying and enforcing our standards of\nacceptable behavior and will take appropriate and fair corrective action in\nresponse to any behavior that they deem inappropriate, threatening, offensive,\nor harmful.\n\nCommunity leaders have the right and responsibility to remove, edit, or reject\ncomments, commits, code, wiki edits, issues, and other contributions that are\nnot aligned to this Code of Conduct, and will communicate reasons for moderation\ndecisions when appropriate.\n\n## Scope\n\nThis Code of Conduct applies within all community spaces, and also applies when\nan individual is officially representing the community in public spaces.\nExamples of representing our community include using an official e-mail address,\nposting via an official social media account, or acting as an appointed\nrepresentative at an online or offline event.\n\n## Enforcement\n\nInstances of abusive, harassing, or otherwise unacceptable behavior may be\nreported to the community leaders responsible for enforcement at\n.\nAll complaints will be reviewed and investigated promptly and fairly.\n\nAll community leaders are obligated to respect the privacy and security of the\nreporter of any incident.\n\n## Enforcement Guidelines\n\nCommunity leaders will follow these Community Impact Guidelines in determining\nthe consequences for any action they deem in violation of this Code of Conduct:\n\n### 1. Correction\n\n**Community Impact**: Use of inappropriate language or other behavior deemed\nunprofessional or unwelcome in the community.\n\n**Consequence**: A private, written warning from community leaders, providing\nclarity around the nature of the violation and an explanation of why the\nbehavior was inappropriate. A public apology may be requested.\n\n### 2. 
Warning\n\n**Community Impact**: A violation through a single incident or series\nof actions.\n\n**Consequence**: A warning with consequences for continued behavior. No\ninteraction with the people involved, including unsolicited interaction with\nthose enforcing the Code of Conduct, for a specified period of time. This\nincludes avoiding interactions in community spaces as well as external channels\nlike social media. Violating these terms may lead to a temporary or\npermanent ban.\n\n### 3. Temporary Ban\n\n**Community Impact**: A serious violation of community standards, including\nsustained inappropriate behavior.\n\n**Consequence**: A temporary ban from any sort of interaction or public\ncommunication with the community for a specified period of time. No public or\nprivate interaction with the people involved, including unsolicited interaction\nwith those enforcing the Code of Conduct, is allowed during this period.\nViolating these terms may lead to a permanent ban.\n\n### 4. Permanent Ban\n\n**Community Impact**: Demonstrating a pattern of violation of community\nstandards, including sustained inappropriate behavior,  harassment of an\nindividual, or aggression toward or disparagement of classes of individuals.\n\n**Consequence**: A permanent ban from any sort of public interaction within\nthe community.\n\n## Attribution\n\nThis Code of Conduct is adapted from the [Contributor Covenant][homepage],\nversion 2.0, available at\nhttps://www.contributor-covenant.org/version/2/0/code_of_conduct.html.\n\nCommunity Impact Guidelines were inspired by [Mozilla's code of conduct\nenforcement ladder](https://github.com/mozilla/diversity).\n\n[homepage]: https://www.contributor-covenant.org\n\nFor answers to common questions about this code of conduct, see the FAQ at\nhttps://www.contributor-covenant.org/faq. Translations are available at\nhttps://www.contributor-covenant.org/translations.\n
"},{"location":"course/fundamentals/#background-review","title":"Background Review","text":"

Informative \u00b7 Media

Here, I will review some basic mathematical concepts using equations, images, or code. Please note that you must understand these concepts to follow this course without difficulty.

"},{"location":"course/fundamentals/#convolution-operation","title":"Convolution Operation","text":"

Convolution is a mathematical operation used as a building block for describing systems. It has proven to be highly effective in machine learning and deep learning. The convolution operation is often denoted with a * symbol. Assume there is a matrix, A, which we want to convolve with some other matrix, also known as the kernel, K.

A sketch showing a matrix and a kernel to be convolved.

One can define such a matrix and a kernel using Torch in Python:

import torch\n\n\na = torch.tensor(\n                 [\n                  [1, 5, 9, 2, 3],\n                  [4, 8, 2, 3, 6],\n                  [7, 2, 0, 1, 3],\n                  [9, 6, 4, 2, 5],\n                  [2, 3, 5, 7, 4]\n                 ]\n                )\nk = torch.tensor(\n                 [\n                  [-1, 2, -3], \n                  [ 3, 5,  7], \n                  [-4, 9, -2]\n                 ]\n                )\n

To convolve these two matrices without losing information, we first have to go through a mathematical operation called zero padding.

A sketch showing zeropadding operating on a matrix.

To zeropad the matrix A, you can rely on Odak:

import odak\n\na_zeropad = odak.learn.tools.zero_pad(a, size = [7, 7])\n

Note that we pass size as [7, 7] here; the logic behind this is simple. Our original matrix is five by five, and adding one row and column of zeros along each side gives seven by seven as the new requested size. Also note that our kernel is three by three. There could be cases with a larger kernel size. In those cases, you want to zeropad by half the kernel size on each side (e.g., the new size becomes the original size plus twice half the kernel size, a.shape[0] + 2 * (k.shape[0] // 2)). Now we choose the first element in the original matrix A, multiply it with the kernel, and add the result to a matrix R. But note that we add the result of this multiplication by centring it at the original location of the first element.
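
As a quick sanity check (a small sketch rather than part of the course material), we can confirm that zeropadding a five by five matrix for a three by three kernel yields a seven by seven matrix, assuming odak.learn.tools.zero_pad pads symmetrically to the requested size as used above:

import torch\nimport odak\n\n\na = torch.rand(5, 5) # a sample matrix\nk = torch.rand(3, 3) # a sample kernel\nnew_size = [a.shape[0] + 2 * (k.shape[0] // 2), a.shape[1] + 2 * (k.shape[1] // 2)] # original size plus half the kernel size on each side\na_zeropad = odak.learn.tools.zero_pad(a, size = new_size)\nprint(a_zeropad.shape) # expected to print torch.Size([7, 7])\n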

A sketch showing the first step of a convolution operation.

We have to repeat this operation for each element in our original matrix and accumulate the results.

A sketch showing the second step of a convolution operation.

Note that there are other ways to describe and implement the convolution operation. Thus far, this definition provides a simplistic description of convolution.

Lab work: Implement convolution operation using Numpy

There are three possible ways to implement the convolution operation on a computer. The first one involves loops visiting each point in the given data. The second involves formulating the convolution operation as a matrix multiplication, and the final one involves implementing convolution as a multiplication in the Fourier domain. Implement all three methods using Jupyter Notebooks and visually prove that they are all functioning correctly with various kernels (e.g., convolving an image with a kernel). The source files listed below may inspire your implementation in various ways; a minimal Numpy sketch of these methods also follows these listings. Note that the code below is based on Torch, not Numpy.

odak.learn.tools.convolve2d odak.learn.tools.generate_2d_gaussian animation_convolution.py

Definition to convolve a field with a kernel by multiplying in frequency space.

Parameters:

  • field \u2013
          Input field with MxN shape.\n
  • kernel \u2013
          Input kernel with MxN shape.\n

Returns:

  • new_field ( tensor ) \u2013

    Convolved field.

Source code in odak/learn/tools/matrix.py
def convolve2d(field, kernel):\n    \"\"\"\n    Definition to convolve a field with a kernel by multiplying in frequency space.\n\n    Parameters\n    ----------\n    field       : torch.tensor\n                  Input field with MxN shape.\n    kernel      : torch.tensor\n                  Input kernel with MxN shape.\n\n    Returns\n    ----------\n    new_field   : torch.tensor\n                  Convolved field.\n    \"\"\"\n    fr = torch.fft.fft2(field)\n    fr2 = torch.fft.fft2(torch.flip(torch.flip(kernel, [1, 0]), [0, 1]))\n    m, n = fr.shape\n    new_field = torch.real(torch.fft.ifft2(fr*fr2))\n    new_field = torch.roll(new_field, shifts=(int(n/2+1), 0), dims=(1, 0))\n    new_field = torch.roll(new_field, shifts=(int(m/2+1), 0), dims=(0, 1))\n    return new_field\n

Generate 2D Gaussian kernel. Inspired from https://stackoverflow.com/questions/29731726/how-to-calculate-a-gaussian-kernel-matrix-efficiently-in-numpy

Parameters:

  • kernel_length (list, default: [21, 21] ) \u2013
            Length of the Gaussian kernel along X and Y axes.\n
  • nsigma \u2013
            Sigma of the Gaussian kernel along X and Y axes.\n
  • mu \u2013
            Mu of the Gaussian kernel along X and Y axes.\n
  • normalize \u2013
            If set True, normalize the output.\n

Returns:

  • kernel_2d ( tensor ) \u2013

    Generated Gaussian kernel.

Source code in odak/learn/tools/matrix.py
def generate_2d_gaussian(kernel_length = [21, 21], nsigma = [3, 3], mu = [0, 0], normalize = False):\n    \"\"\"\n    Generate 2D Gaussian kernel. Inspired from https://stackoverflow.com/questions/29731726/how-to-calculate-a-gaussian-kernel-matrix-efficiently-in-numpy\n\n    Parameters\n    ----------\n    kernel_length : list\n                    Length of the Gaussian kernel along X and Y axes.\n    nsigma        : list\n                    Sigma of the Gaussian kernel along X and Y axes.\n    mu            : list\n                    Mu of the Gaussian kernel along X and Y axes.\n    normalize     : bool\n                    If set True, normalize the output.\n\n    Returns\n    ----------\n    kernel_2d     : torch.tensor\n                    Generated Gaussian kernel.\n    \"\"\"\n    x = torch.linspace(-kernel_length[0]/2., kernel_length[0]/2., kernel_length[0])\n    y = torch.linspace(-kernel_length[1]/2., kernel_length[1]/2., kernel_length[1])\n    X, Y = torch.meshgrid(x, y, indexing='ij')\n    if nsigma[0] == 0:\n        nsigma[0] = 1e-5\n    if nsigma[1] == 0:\n        nsigma[1] = 1e-5\n    kernel_2d = 1. / (2. * torch.pi * nsigma[0] * nsigma[1]) * torch.exp(-((X - mu[0])**2. / (2. * nsigma[0]**2.) + (Y - mu[1])**2. / (2. * nsigma[1]**2.)))\n    if normalize:\n        kernel_2d = kernel_2d / kernel_2d.max()\n    return kernel_2d\n
animation_convolution.py
import odak\nimport torch\nimport sys\n\n\ndef main():\n    filename_image = '../media/10591010993_80c7cb37a6_c.jpg'\n    image = odak.learn.tools.load_image(filename_image, normalizeby = 255., torch_style = True)[0:3].unsqueeze(0)\n    kernel = odak.learn.tools.generate_2d_gaussian(kernel_length = [12, 12], nsigma = [21, 21])\n    kernel = kernel / kernel.max()\n    result = torch.zeros_like(image)\n    result = odak.learn.tools.zero_pad(result, size = [image.shape[-2] + kernel.shape[0], image.shape[-1] + kernel.shape[1]])\n    step = 0\n    for i in range(image.shape[-2]):\n        for j in range(image.shape[-1]):\n            for ch in range(image.shape[-3]):\n                element = image[:, ch, i, j]\n                add = kernel * element\n                result[:, ch, i : i + kernel.shape[0], j : j + kernel.shape[1]] += add\n            if (i * image.shape[-1] + j) % 1e4 == 0:\n                filename = 'step_{:04d}.png'.format(step)\n                odak.learn.tools.save_image( filename, result, cmin = 0., cmax = 100.)\n                step += 1\n    cmd = ['convert', '-delay', '1', '-loop', '0', '*.png', '../media/convolution_animation.gif']\n    odak.tools.shell_command(cmd)\n    cmd = ['rm', '*.png']\n    odak.tools.shell_command(cmd)\n\n\nif __name__ == '__main__':\n    sys.exit(main())\n
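
For the lab work above, the following is a minimal Numpy sketch, not part of Odak and with illustrative function names, that implements all three approaches for odd-sized kernels with zero-padded, same-sized outputs and cross-checks them against each other:

import numpy as np\n\n\ndef convolve2d_loop(field, kernel):\n    # Direct definition: pad, flip the kernel, then accumulate kernel-weighted sums per output pixel.\n    kh, kw = kernel.shape\n    padded = np.pad(field, ((kh // 2, kh // 2), (kw // 2, kw // 2)))\n    flipped = kernel[::-1, ::-1]\n    result = np.zeros_like(field, dtype = float)\n    for i in range(field.shape[0]):\n        for j in range(field.shape[1]):\n            result[i, j] = np.sum(padded[i : i + kh, j : j + kw] * flipped)\n    return result\n\n\ndef convolve2d_matmul(field, kernel):\n    # Matrix formulation: unroll padded patches into rows (im2col) and evaluate all pixels with one matrix-vector product.\n    kh, kw = kernel.shape\n    padded = np.pad(field, ((kh // 2, kh // 2), (kw // 2, kw // 2)))\n    flipped = kernel[::-1, ::-1].reshape(-1)\n    patches = np.lib.stride_tricks.sliding_window_view(padded, (kh, kw)).reshape(-1, kh * kw)\n    return (patches @ flipped).reshape(field.shape)\n\n\ndef convolve2d_fourier(field, kernel):\n    # Fourier formulation: multiply zero-padded spectra, then crop back to the input size.\n    fh, fw = field.shape\n    kh, kw = kernel.shape\n    oh, ow = fh + kh - 1, fw + kw - 1\n    spectrum = np.fft.fft2(field, s = (oh, ow)) * np.fft.fft2(kernel, s = (oh, ow))\n    full = np.real(np.fft.ifft2(spectrum))\n    return full[kh // 2 : kh // 2 + fh, kw // 2 : kw // 2 + fw]\n\n\nif __name__ == '__main__':\n    a = np.random.rand(64, 64)\n    k = np.ones((3, 3)) / 9.\n    reference = convolve2d_loop(a, k)\n    print(np.allclose(convolve2d_matmul(a, k), reference))\n    print(np.allclose(convolve2d_fourier(a, k), reference))\n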

In summary, the convolution operation is heavily used in describing optical systems, computer vision-related algorithms, and state-of-the-art machine learning techniques. Thus, understanding this mathematical operation is extremely important not only for this course but also for undergraduate and graduate-level courses. As an example, let's see step by step how a sample image provided below is convolved:

An animation showing the steps of convolution operation.

and the original image is as below:

Original image before the convolution operation (Generated by Stable Diffusion).

Note that the source image shown above is generated with a generative model. As a side note, I strongly suggest you familiarize yourself with several models for generating test images, audio, or any other type of media. This way, you can remove your dependency on others for generating test content.

Lab work: Convolve an image with a Gaussian kernel

Using Odak and Torch, blur an image using a Gaussian kernel. Also try compiling an animation like the one shown above using Matplotlib. Use the solution below only as a last resort; try writing your own code first. The code below is tested under the Ubuntu operating system.

animation_convolution.py animation_convolution.py
import odak\nimport torch\nimport sys\n\n\ndef main():\n    filename_image = '../media/10591010993_80c7cb37a6_c.jpg'\n    image = odak.learn.tools.load_image(filename_image, normalizeby = 255., torch_style = True)[0:3].unsqueeze(0)\n    kernel = odak.learn.tools.generate_2d_gaussian(kernel_length = [12, 12], nsigma = [21, 21])\n    kernel = kernel / kernel.max()\n    result = torch.zeros_like(image)\n    result = odak.learn.tools.zero_pad(result, size = [image.shape[-2] + kernel.shape[0], image.shape[-1] + kernel.shape[1]])\n    step = 0\n    for i in range(image.shape[-2]):\n        for j in range(image.shape[-1]):\n            for ch in range(image.shape[-3]):\n                element = image[:, ch, i, j]\n                add = kernel * element\n                result[:, ch, i : i + kernel.shape[0], j : j + kernel.shape[1]] += add\n            if (i * image.shape[-1] + j) % 1e4 == 0:\n                filename = 'step_{:04d}.png'.format(step)\n                odak.learn.tools.save_image( filename, result, cmin = 0., cmax = 100.)\n                step += 1\n    cmd = ['convert', '-delay', '1', '-loop', '0', '*.png', '../media/convolution_animation.gif']\n    odak.tools.shell_command(cmd)\n    cmd = ['rm', '*.png']\n    odak.tools.shell_command(cmd)\n\n\nif __name__ == '__main__':\n    sys.exit(main())\n
"},{"location":"course/fundamentals/#gradient-descent-optimizers","title":"Gradient Descent Optimizers","text":"

Throughout this course, we will have to optimize variables to generate a solution for our problems. Thus, we need a scalable method to optimize various variables in future problems and tasks. We will not review optimizers in this section but provide a working solution. You can learn more about optimizers through other courses offered within our curriculum or through suggested readings. State-of-the-art Gradient Descent (GD) optimizers could play a key role here. Significantly, Stochastic Gradient Descent (SGD) optimizers can help resolve our problems in the future with a reasonable memory footprint. This is because GD updates its weights by visiting every sample in a dataset, whereas SGD can update using only randomly chosen data from that dataset. Thus, SGD requires less memory for each update.
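
To make this difference concrete, here is a minimal sketch in plain Torch (independent of Odak) that fits a line with stochastic mini-batch updates, where every step only touches a small random subset of the dataset rather than all of it:

import torch\n\n\nx = torch.linspace(0., 1., 1000)\ny = 2. * x + 1. + 1e-2 * torch.randn_like(x) # noisy samples of y = 2x + 1\nm = torch.tensor([0.], requires_grad = True)\nn = torch.tensor([0.], requires_grad = True)\noptimizer = torch.optim.SGD([m, n], lr = 1e-1)\nbatch_size = 32 # each update only sees this many randomly chosen samples\nfor step in range(2000):\n    optimizer.zero_grad()\n    indices = torch.randint(0, x.shape[0], (batch_size,))\n    loss = torch.mean((m * x[indices] + n - y[indices]) ** 2)\n    loss.backward()\n    optimizer.step()\nprint(m.item(), n.item()) # should land near 2 and 1\n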

Where can I read more about the state-of-the-art Stochastic Gradient Descent optimizer?

To learn more, please read Paszke, Adam, Sam Gross, Soumith Chintala, Gregory Chanan, Edward Yang, Zachary DeVito, Zeming Lin, Alban Desmaison, Luca Antiga, and Adam Lerer. \"Automatic differentiation in pytorch.\" (2017). 1

Would you like to code your Gradient Descent based optimizer ground up?

In case you are interested in coding your Gradient Descent-based optimizer from the ground up, consider watching this tutorial online, where I code the optimizer using only Numpy. If you want to learn more about odak's built-in functions on the matter, visit the unit test script below:

test_fit_gradient_descent_1d.py test_fit_gradient_descent_1d.py
import numpy as np\nimport sys\nimport odak\n\n\ndef gradient_function(x, y, function, parameters):\n    solution = function(x, parameters)\n    gradient = np.array([\n                         -2 * x**2 * (y - solution),\n                         -2 * x * (y- solution),\n                         -2 * (y - solution)\n                        ])\n    return gradient\n\n\ndef function(x, parameters):\n    y = parameters[0] * x**2 + parameters[1] * x + parameters[2]\n    return y\n\n\ndef l2_loss(a, b):\n    loss = np.sum((a - b)**2)\n    return loss\n\n\ndef test():\n    x = np.linspace(0, 1., 20) \n    y = function(x, parameters=[2., 1., 10.])\n\n    learning_rate = 5e-1\n    iteration_number = 2000\n    initial_parameters = np.array([10., 10., 0.])\n    estimated_parameters = odak.fit.gradient_descent_1d(\n                                                        input_data=x,\n                                                        ground_truth_data=y,\n                                                        function=function,\n                                                        loss_function=l2_loss,\n                                                        gradient_function=gradient_function,\n                                                        parameters=initial_parameters,\n                                                        learning_rate=learning_rate,\n                                                        iteration_number=iteration_number\n                                                       )\n    assert True == True\n\n\nif __name__ == '__main__':\n   sys.exit(test())\n

Torch is a blessing for people who optimize or train their algorithms. Torch also comes with a set of state-of-the-art optimizers. One of these optimizers is called the Adam optimizer, torch.optim.Adam. Let's study the example below to make sense of how this optimizer can help us optimize various variables.

import torch\nimport odak  \nimport sys # (1)\n\n\ndef forward(x, m, n): # (2)\n    y = m * x + n\n    return y\n\n\ndef main():\n    m = torch.tensor([100.], requires_grad = True)\n    n = torch.tensor([0.], requires_grad = True) # (3)\n    x_vals = torch.tensor([1., 2., 3., 100.])\n    y_vals = torch.tensor([5., 6., 7., 101.]) # (4)\n    optimizer = torch.optim.Adam([m, n], lr = 5e1) # (5)\n    loss_function = torch.nn.MSELoss() # (6)\n    for step in range(1000):\n        optimizer.zero_grad() # (7)\n        y_estimate = forward(x_vals, m, n) # (8)\n        loss = loss_function(y_estimate, y_vals) # (9)\n        loss.backward(retain_graph = True)\n        optimizer.step() # (10)\n        print('Step: {}, Loss: {}'.format(step, loss.item()))\n    print(m, n)\n\n\nif __name__ == '__main__':\n    sys.exit(main())\n
  1. Required libraries are imported.
  2. Let's assume that we are aiming to fit a line to some data (y = mx + n).
  3. As we are aiming to fit a line, we have to find a proper m and n for our line (y = mx + n). Pay attention to the fact that we have to make these variables differentiable by setting requires_grad = True.
  4. Here is a sample dataset of X and Y values.
  5. We define an Adam optimizer and ask our optimizer to optimize m and n.
  6. We need some metric to identify if our optimizer is optimizing correctly. Here, we choose the mean squared error (an L2 loss) as our metric.
  7. We clear the gradients before each iteration.
  8. We make our estimation for Y values using the most current m and n values suggested by the optimizer.
  9. We compare our estimation with original Y values to help our optimizer update m and n values.
  10. Loss and optimizer help us move in the right direction for updating m and n values.
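
As a sanity check for the example above (a small sketch, not part of the course code), the same line fit also has a closed-form least-squares solution; comparing it against the optimized m and n is a quick way to judge whether the optimization converged:

import torch\n\n\nx_vals = torch.tensor([1., 2., 3., 100.])\ny_vals = torch.tensor([5., 6., 7., 101.])\nA = torch.stack([x_vals, torch.ones_like(x_vals)], dim = 1) # columns correspond to m and n\nsolution = torch.linalg.lstsq(A, y_vals.unsqueeze(1)).solution\nprint(solution.flatten()) # m and n minimizing the mean squared error\n
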
"},{"location":"course/fundamentals/#conclusion","title":"Conclusion","text":"

We covered a lot of ground in terms of coding standards, how to organize a project repository, and how basic things work in odak and Torch. Please ensure you understand the essential information in this section. Please note that we will use this information in the following sections and stages of this course.

Consider revisiting this chapter

Remember that you can always revisit this chapter as you progress with the course and as you need it. This chapter is vital for establishing a means to complete your assignments and could help form a suitable basis for collaborating with my research group or other experts in the field in the future.

Did you know that Computer Science misses basic tool education?

The classes that Computer Science programs offer around the globe commonly miss basic tool education. Students often spend a large amount of time learning tools while they are also learning an advanced topic. This section of our course gave you a quick overview, but you may want to go beyond it and learn more about other basic aspects of Computer Science, such as using shell tools, editors, metaprogramming, or security. The missing semester of your CS education offers an online resource for you to follow up and learn more. The content of the mentioned course is mostly developed by instructors from the Massachusetts Institute of Technology.

Reminder

We host a Slack group with more than 300 members. This Slack group focuses on the topics of rendering, perception, displays, and cameras. The group is open to the public, and you can become a member by following this link. Readers can get in touch with the wider community using this public group.

  1. Adam Paszke, Sam Gross, Soumith Chintala, Gregory Chanan, Edward Yang, Zachary DeVito, Zeming Lin, Alban Desmaison, Luca Antiga, and Adam Lerer. Automatic differentiation in pytorch. NIPS 2017 Workshop Autodiff, 2017.\u00a0\u21a9

"},{"location":"course/geometric_optics/","title":"Modeling light with rays","text":"Narrate section"},{"location":"course/geometric_optics/#modeling-light-with-rays","title":"Modeling light with rays","text":"

Modeling light plays a crucial role in describing events based on light and helps in designing mechanisms based on light (e.g., realistic graphics in a video game, a display, or a camera). This chapter introduces the most basic description of light using geometric rays, also known as raytracing. Raytracing has a long history, from ancient times to current Computer Graphics. Here, we will not cover the history of raytracing. Instead, we will focus on how we implement simulations to build \"things\" with raytracing in the future. As we provide algorithmic examples to support our descriptions, readers should be able to simulate light on their computers using the provided descriptions.

Are there other good resources on modeling light with rays?

When I first started coding Odak, the first paper I read was on raytracing. Thus, I recommend that paper for any starter:

  • Spencer, G. H., and M. V. R. K. Murty. \"General ray-tracing procedure.\" JOSA 52, no. 6 (1962): 672-678. 1

Beyond this paper, there are several resources that I can recommend for curious readers:

  • Shirley, Peter. \"Ray tracing in one weekend.\" Amazon Digital Services LLC 1 (2018): 4. 2
  • Morgan McGuire (2021). The Graphics Codex. Casual Effects. 3
"},{"location":"course/geometric_optics/#ray-description","title":"Ray description","text":"

Informative \u00b7 Practical

We have to define what \"a ray\" is. A ray has a starting point in Euclidean space (\\(x_0, y_0, z_0 \\in \\mathbb{R}\\)). We also have to define direction cosines to provide the directions of rays. Direction cosines are the cosines of the three angles between a ray and the X, Y, and Z axes (\\(\\theta_x, \\theta_y, \\theta_z \\in \\mathbb{R}\\)). To calculate direction cosines, we must choose a point on that ray, \\(x_1, y_1,\\) and \\(z_1\\), and calculate its distance to the starting point \\(x_0, y_0\\) and \\(z_0\\):

\\[ x_{distance} = x_1 - x_0, \\\\ y_{distance} = y_1 - y_0, \\\\ z_{distance} = z_1 - z_0. \\]

Then, we can also calculate the Euclidean distance between the starting point and the chosen point:

\\[ s = \\sqrt{x_{distance}^2 + y_{distance}^2 + z_{distance}^2}. \\]

Thus, we describe each direction cosine as:

\\[ cos(\\theta_x) = \\frac{x_{distance}}{s}, \\\\ cos(\\theta_y) = \\frac{y_{distance}}{s}, \\\\ cos(\\theta_z) = \\frac{z_{distance}}{s}. \\]
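
As a quick numeric sketch of these equations in plain Torch (independent of the Odak functions reviewed below), a ray starting at the origin and passing through the point (1., 1., 1.) has direction cosines of \\(1/\\sqrt{3}\\) along each axis:

import torch\n\n\nx0y0z0 = torch.tensor([0., 0., 0.]) # starting point of the ray\nx1y1z1 = torch.tensor([1., 1., 1.]) # a point chosen on the ray\ndistances = x1y1z1 - x0y0z0\ns = torch.sqrt(torch.sum(distances ** 2)) # Euclidean distance between the two points\ncosines = distances / s # direction cosines, here each is 1 / sqrt(3), roughly 0.577\nprint(s, cosines)\n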

Now that we know how to define a ray with a starting point, \\(x_0, y_0, z_0\\), and direction cosines, \\(cos(\\theta_x), cos(\\theta_y), cos(\\theta_z)\\), let us carefully analyze the parameters, returns, and source code of the two following functions in odak dedicated to creating a ray or multiple rays.

odak.learn.raytracing.create_ray odak.learn.raytracing.create_ray_from_two_points

Definition to create a ray.

Parameters:

  • xyz \u2013
           List that contains X,Y and Z start locations of a ray.\n       Size could be [1 x 3], [3], [m x 3].\n
  • abg \u2013
           List that contains angles in degrees with respect to the X,Y and Z axes.\n       Size could be [1 x 3], [3], [m x 3].\n
  • direction \u2013
           If set to True, cosines of `abg` is not calculated.\n

Returns:

  • ray ( tensor ) \u2013

    Array that contains starting points and cosines of a created ray. Size will be either [1 x 3] or [m x 3].

Source code in odak/learn/raytracing/ray.py
def create_ray(xyz, abg, direction = False):\n    \"\"\"\n    Definition to create a ray.\n\n    Parameters\n    ----------\n    xyz          : torch.tensor\n                   List that contains X,Y and Z start locations of a ray.\n                   Size could be [1 x 3], [3], [m x 3].\n    abg          : torch.tensor\n                   List that contains angles in degrees with respect to the X,Y and Z axes.\n                   Size could be [1 x 3], [3], [m x 3].\n    direction    : bool\n                   If set to True, cosines of `abg` is not calculated.\n\n    Returns\n    ----------\n    ray          : torch.tensor\n                   Array that contains starting points and cosines of a created ray.\n                   Size will be either [1 x 3] or [m x 3].\n    \"\"\"\n    points = xyz\n    angles = abg\n    if len(xyz) == 1:\n        points = xyz.unsqueeze(0)\n    if len(abg) == 1:\n        angles = abg.unsqueeze(0)\n    ray = torch.zeros(points.shape[0], 2, 3, device = points.device)\n    ray[:, 0] = points\n    if direction:\n        ray[:, 1] = abg\n    else:\n        ray[:, 1] = torch.cos(torch.deg2rad(abg))\n    return ray\n

Definition to create a ray from two given points. Note that both inputs must match in shape.

Parameters:

  • x0y0z0 \u2013
           List that contains X,Y and Z start locations of a ray.\n       Size could be [1 x 3], [3], [m x 3].\n
  • x1y1z1 \u2013
           List that contains X,Y and Z ending locations of a ray or batch of rays.\n       Size could be [1 x 3], [3], [m x 3].\n

Returns:

  • ray ( tensor ) \u2013

    Array that contains starting points and cosines of a created ray(s).

Source code in odak/learn/raytracing/ray.py
def create_ray_from_two_points(x0y0z0, x1y1z1):\n    \"\"\"\n    Definition to create a ray from two given points. Note that both inputs must match in shape.\n\n    Parameters\n    ----------\n    x0y0z0       : torch.tensor\n                   List that contains X,Y and Z start locations of a ray.\n                   Size could be [1 x 3], [3], [m x 3].\n    x1y1z1       : torch.tensor\n                   List that contains X,Y and Z ending locations of a ray or batch of rays.\n                   Size could be [1 x 3], [3], [m x 3].\n\n    Returns\n    ----------\n    ray          : torch.tensor\n                   Array that contains starting points and cosines of a created ray(s).\n    \"\"\"\n    if len(x0y0z0.shape) == 1:\n        x0y0z0 = x0y0z0.unsqueeze(0)\n    if len(x1y1z1.shape) == 1:\n        x1y1z1 = x1y1z1.unsqueeze(0)\n    xdiff = x1y1z1[:, 0] - x0y0z0[:, 0]\n    ydiff = x1y1z1[:, 1] - x0y0z0[:, 1]\n    zdiff = x1y1z1[:, 2] - x0y0z0[:, 2]\n    s = (xdiff ** 2 + ydiff ** 2 + zdiff ** 2) ** 0.5\n    s[s == 0] = float('nan')\n    cosines = torch.zeros_like(x0y0z0 * x1y1z1)\n    cosines[:, 0] = xdiff / s\n    cosines[:, 1] = ydiff / s\n    cosines[:, 2] = zdiff / s\n    ray = torch.zeros(xdiff.shape[0], 2, 3, device = x0y0z0.device)\n    ray[:, 0] = x0y0z0\n    ray[:, 1] = cosines\n    return ray\n

In the future, we must find out where a ray lands after propagating a certain distance for various purposes, which we will describe in this chapter. For that purpose, let us also create a utility function that propagates a ray to some distance, \\(d\\), using \\(x_0, y_0, z_0\\) and \\(cos(\\theta_x), cos(\\theta_y), cos(\\theta_z)\\):

\\[ x_{new} = x_0 + cos(\\theta_x) d,\\\\ y_{new} = y_0 + cos(\\theta_y) d,\\\\ z_{new} = z_0 + cos(\\theta_z) d. \\]
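
As a tiny numeric sketch of these propagation equations (again in plain Torch), a ray starting at the origin with direction cosines of \\(1/\\sqrt{3}\\) along each axis lands on the point (1., 1., 1.) after propagating a distance of \\(\\sqrt{3}\\):

import torch\n\n\nx0y0z0 = torch.tensor([0., 0., 0.])\ncosines = torch.tensor([1., 1., 1.]) / torch.sqrt(torch.tensor(3.))\nd = torch.sqrt(torch.tensor(3.)) # propagation distance\nnew_point = x0y0z0 + cosines * d\nprint(new_point) # prints tensor([1., 1., 1.]) up to floating point precision\n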

Let us also check the function provided below to understand its source code, parameters, and returns. This function will serve as a utility function to propagate a ray or a batch of rays in our future simulations.

odak.learn.raytracing.propagate_ray

Definition to propagate a ray at a certain given distance.

Parameters:

  • ray \u2013
         A ray with a size of [2 x 3], [1 x 2 x 3] or a batch of rays with [m x 2 x 3].\n
  • distance \u2013
         Distance with a size of [1], [1, m] or distances with a size of [m], [1, m].\n

Returns:

  • new_ray ( tensor ) \u2013

    Propagated ray with a size of [1 x 2 x 3] or batch of rays with [m x 2 x 3].

Source code in odak/learn/raytracing/ray.py
def propagate_ray(ray, distance):\n    \"\"\"\n    Definition to propagate a ray at a certain given distance.\n\n    Parameters\n    ----------\n    ray        : torch.tensor\n                 A ray with a size of [2 x 3], [1 x 2 x 3] or a batch of rays with [m x 2 x 3].\n    distance   : torch.tensor\n                 Distance with a size of [1], [1, m] or distances with a size of [m], [1, m].\n\n    Returns\n    ----------\n    new_ray    : torch.tensor\n                 Propagated ray with a size of [1 x 2 x 3] or batch of rays with [m x 2 x 3].\n    \"\"\"\n    if len(ray.shape) == 2:\n        ray = ray.unsqueeze(0)\n    if len(distance.shape) == 2:\n        distance = distance.squeeze(-1)\n    new_ray = torch.zeros_like(ray)\n    new_ray[:, 0, 0] = distance * ray[:, 1, 0] + ray[:, 0, 0]\n    new_ray[:, 0, 1] = distance * ray[:, 1, 1] + ray[:, 0, 1]\n    new_ray[:, 0, 2] = distance * ray[:, 1, 2] + ray[:, 0, 2]\n    return new_ray\n

It is now time for us to put what we have learned so far into an actual code. We can create many rays using the two functions, odak.learn.raytracing.create_ray_from_two_points and odak.learn.raytracing.create_ray. However, to do so, we need to have many points in both cases. For that purpose, let's carefully review this utility function provided below. This utility function can generate grid samples from a plane with some tilt, and we can also define the center of our samples to position points anywhere in Euclidean space.

odak.learn.tools.grid_sample

Definition to generate samples over a surface.

Parameters:

  • no \u2013
          Number of samples.\n
  • size \u2013
          Physical size of the surface.\n
  • center \u2013
          Center location of the surface.\n
  • angles \u2013
          Tilt of the surface.\n

Returns:

  • samples ( tensor ) \u2013

    Samples generated.

  • rotx ( tensor ) \u2013

    Rotation matrix at X axis.

  • roty ( tensor ) \u2013

    Rotation matrix at Y axis.

  • rotz ( tensor ) \u2013

    Rotation matrix at Z axis.

Source code in odak/learn/tools/sample.py
def grid_sample(\n                no = [10, 10],\n                size = [100., 100.], \n                center = [0., 0., 0.], \n                angles = [0., 0., 0.]):\n    \"\"\"\n    Definition to generate samples over a surface.\n\n    Parameters\n    ----------\n    no          : list\n                  Number of samples.\n    size        : list\n                  Physical size of the surface.\n    center      : list\n                  Center location of the surface.\n    angles      : list\n                  Tilt of the surface.\n\n    Returns\n    -------\n    samples     : torch.tensor\n                  Samples generated.\n    rotx        : torch.tensor\n                  Rotation matrix at X axis.\n    roty        : torch.tensor\n                  Rotation matrix at Y axis.\n    rotz        : torch.tensor\n                  Rotation matrix at Z axis.\n    \"\"\"\n    center = torch.tensor(center)\n    angles = torch.tensor(angles)\n    size = torch.tensor(size)\n    samples = torch.zeros((no[0], no[1], 3))\n    x = torch.linspace(-size[0] / 2., size[0] / 2., no[0])\n    y = torch.linspace(-size[1] / 2., size[1] / 2., no[1])\n    X, Y = torch.meshgrid(x, y, indexing='ij')\n    samples[:, :, 0] = X.detach().clone()\n    samples[:, :, 1] = Y.detach().clone()\n    samples = samples.reshape((samples.shape[0] * samples.shape[1], samples.shape[2]))\n    samples, rotx, roty, rotz = rotate_points(samples, angles = angles, offset = center)\n    return samples, rotx, roty, rotz\n

The below script provides a sample use case for the functions provided above. I also leave comments near some lines explaining the code in steps.

test_learn_ray_create_ray_from_two_points.py
import sys\nimport odak\nimport torch # (1)\n\n\ndef test(directory = 'test_output'):\n    odak.tools.check_directory(directory)\n    starting_point = torch.tensor([[5., 5., 0.]]) # (2)\n    end_points, _, _, _ = odak.learn.tools.grid_sample(\n                                                       no = [2, 2], \n                                                       size = [20., 20.], \n                                                       center = [0., 0., 10.]\n                                                      ) # (3)\n    rays_from_points = odak.learn.raytracing.create_ray_from_two_points(\n                                                                        starting_point,\n                                                                        end_points\n                                                                       ) # (4)\n\n\n    starting_points, _, _, _ = odak.learn.tools.grid_sample(\n                                                            no = [3, 3], \n                                                            size = [100., 100.], \n                                                            center = [0., 0., 10.],\n                                                           )\n    angles = torch.randn_like(starting_points) * 180. # (5)\n    rays_from_angles = odak.learn.raytracing.create_ray(\n                                                        starting_points,\n                                                        angles\n                                                       ) # (6)\n\n\n    distances = torch.ones(rays_from_points.shape[0]) * 12.5\n    propagated_rays = odak.learn.raytracing.propagate_ray(\n                                                          rays_from_points,\n                                                          distances\n                                                         ) # (7)\n\n\n\n\n    visualize = False # (8)\n    if visualize:\n        ray_diagram = odak.visualize.plotly.rayshow(line_width = 3., marker_size = 3.)\n        ray_diagram.add_point(starting_point, color = 'red')\n        ray_diagram.add_point(end_points[0], color = 'blue')\n        ray_diagram.add_line(starting_point, end_points[0], color = 'green')\n        x_axis = starting_point.clone()\n        x_axis[0, 0] = end_points[0, 0]\n        ray_diagram.add_point(x_axis, color = 'black')\n        ray_diagram.add_line(starting_point, x_axis, color = 'black', dash = 'dash')\n        y_axis = starting_point.clone()\n        y_axis[0, 1] = end_points[0, 1]\n        ray_diagram.add_point(y_axis, color = 'black')\n        ray_diagram.add_line(starting_point, y_axis, color = 'black', dash = 'dash')\n        z_axis = starting_point.clone()\n        z_axis[0, 2] = end_points[0, 2]\n        ray_diagram.add_point(z_axis, color = 'black')\n        ray_diagram.add_line(starting_point, z_axis, color = 'black', dash = 'dash')\n        html = ray_diagram.save_offline()\n        markdown_file = open('{}/ray.txt'.format(directory), 'w')\n        markdown_file.write(html)\n        markdown_file.close()\n    assert True == True\n\n\nif __name__ == '__main__':\n    sys.exit(test())\n
  1. Required libraries are imported.
  2. Defining a starting point with its X, Y and Z locations in order. The size of the starting point could be [3] or [1, 3].
  3. Defining some end points on a plane in grid fashion.
  4. odak.learn.raytracing.create_ray_from_two_points is verified with an example! Let's move on to odak.learn.raytracing.create_ray.
  5. Creating starting points with odak.learn.tools.grid_sample and defining some angles as the direction using torch.randn. Note that the angles are in degrees.
  6. odak.learn.raytracing.create_ray is verified with an example!
  7. odak.learn.raytracing.propagate_ray is verified with an example!
  8. Set it to True to enable visualization.

The above code also has parts that are disabled (see the visualize variable). We disabled these lines intentionally to avoid running them on every run. Let me talk about these disabled lines as well. Odak offers a tidy approach to simple visualizations through packages called Plotly and kaleido. To make these lines work by setting visualize = True, you must first install plotly and kaleido in your work environment. This installation is as simple as pip3 install plotly kaleido on a Linux system. Once you install these packages and enable these lines, the code will produce a visualization similar to the one below. Note that this is an interactive visualization where you can use your mouse to rotate, shift, and zoom. In this visualization, we visualize a single ray (green line) starting from our defined starting point (red dot) and ending at one of the end_points (blue dot). We also highlight three axes with black lines to provide a reference frame. Although odak.visualize.plotly offers us methods to visualize rays quickly for debugging, it is highly suggested to stick to a low number of rays when using it (e.g., not exceeding 100 rays in total). The proper way to draw many rays lies in modern path-tracing renderers such as Blender.

How can I learn more about more sophisticated renderers like Blender?

Blender is a widely used open-source renderer that comes with sophisticated features. Its user interface could be challenging for newcomers. The SIGGRAPH Research Career Development Committee published a neat entry-level blog post titled Rendering a paper figure with Blender, written by Silvia Sell\u00e1n.

In addition to Blender, there are various renderers you may be happy to know about if you are curious about Computer Graphics. Mitsuba 3 is another sophisticated rendering system based on a SIGGRAPH paper titled Dr.Jit: A Just-In-Time Compiler for Differentiable Rendering 4 from Wenzel Jakob.

If you know of any others, please share them with the class so that they can also learn about other renderers.

Challenge: Blender meets Odak

In light of the given information, we challenge readers to create a new submodule for Odak. Note that Odak has odak.visualize.blender submodule. However, at the time of this writing, this submodule works as a server that sends commands to a program that has to be manually triggered inside Blender. Odak seeks an upgrade to this submodule, where users can draw rays, meshes, or parametric surfaces easily in Blender with commands from Odak. This newly upgraded submodule should require no manual processes. To add these to odak, you can rely on the pull request feature on GitHub. You can also create a new engineering note for your new submodule in docs/notes/odak_meets_blender.md.

"},{"location":"course/geometric_optics/#intersecting-rays-with-surfaces","title":"Intersecting rays with surfaces","text":"

Informative \u00b7 Practical

The rays we have described so far help us explore light and matter interactions. Often in simulations, these rays interact with surfaces. In a simulation environment for optical design, equations often describe surfaces continuously. These surface equations typically contain a number of parameters for defining surfaces. For example, let us consider a sphere, which follows the standard equation,

\\[ r^2 = (x - x_0)^2 + (y - y_0)^2 + (z - z_0)^2, \\]

Where \\(r\\) represents the radius of that sphere, \\(x_0, y_0, z_0\\) defines the center location of that sphere, and \\(x, y, z\\) are points on the surface of the sphere. When testing if a point is on a sphere, we insert the point to be tested as \\(x, y, z\\) into the above equation. In other words, to find a ray and sphere intersection, we must identify a propagation distance that lands our ray on a point on that sphere, and we can use the above sphere equation to identify that intersection point. As long as the surface equation is well defined, the same strategy can be used for any surface. In addition, if needed for future purposes (e.g., reflecting or refracting light off the surface of that sphere), we can also calculate the surface normal at the intersection point by defining a ray that starts from the center of the sphere and propagates towards the intersection point. Let us examine how we can identify intersection points for a set of given rays and a sphere by studying the function below.

odak.learn.raytracing.intersect_w_sphere

Definition to find the intersection between ray(s) and sphere(s).

Parameters:

  • ray \u2013
                  Input ray(s).\n              Expected size is [1 x 2 x 3] or [m x 2 x 3].\n
  • sphere \u2013
                  Input sphere.\n              Expected size is [1 x 4].\n
  • learning_rate \u2013
                  Learning rate used in the optimizer for finding the propagation distances of the rays.\n
  • number_of_steps \u2013
                  Number of steps used in the optimizer.\n
  • error_threshold \u2013
                  The error threshold that will help deciding intersection or no intersection.\n

Returns:

  • intersecting_ray ( tensor ) \u2013

    Ray(s) that intersecting with the given sphere. Expected size is [n x 2 x 3], where n could be any real number.

  • intersecting_normal ( tensor ) \u2013

    Normal(s) for the ray(s) intersecting with the given sphere Expected size is [n x 2 x 3], where n could be any real number.

Source code in odak/learn/raytracing/boundary.py
def intersect_w_sphere(ray, sphere, learning_rate = 2e-1, number_of_steps = 5000, error_threshold = 1e-2):\n    \"\"\"\n    Definition to find the intersection between ray(s) and sphere(s).\n\n    Parameters\n    ----------\n    ray                 : torch.tensor\n                          Input ray(s).\n                          Expected size is [1 x 2 x 3] or [m x 2 x 3].\n    sphere              : torch.tensor\n                          Input sphere.\n                          Expected size is [1 x 4].\n    learning_rate       : float\n                          Learning rate used in the optimizer for finding the propagation distances of the rays.\n    number_of_steps     : int\n                          Number of steps used in the optimizer.\n    error_threshold     : float\n                          The error threshold that will help deciding intersection or no intersection.\n\n    Returns\n    -------\n    intersecting_ray    : torch.tensor\n                          Ray(s) that intersecting with the given sphere.\n                          Expected size is [n x 2 x 3], where n could be any real number.\n    intersecting_normal : torch.tensor\n                          Normal(s) for the ray(s) intersecting with the given sphere\n                          Expected size is [n x 2 x 3], where n could be any real number.\n\n    \"\"\"\n    if len(ray.shape) == 2:\n        ray = ray.unsqueeze(0)\n    if len(sphere.shape) == 1:\n        sphere = sphere.unsqueeze(0)\n    distance = torch.zeros(ray.shape[0], device = ray.device, requires_grad = True)\n    loss_l2 = torch.nn.MSELoss(reduction = 'sum')\n    optimizer = torch.optim.AdamW([distance], lr = learning_rate)    \n    t = tqdm(range(number_of_steps), leave = False, dynamic_ncols = True)\n    for step in t:\n        optimizer.zero_grad()\n        propagated_ray = propagate_ray(ray, distance)\n        test = torch.abs((propagated_ray[:, 0, 0] - sphere[:, 0]) ** 2 + (propagated_ray[:, 0, 1] - sphere[:, 1]) ** 2 + (propagated_ray[:, 0, 2] - sphere[:, 2]) ** 2 - sphere[:, 3] ** 2)\n        loss = loss_l2(\n                       test,\n                       torch.zeros_like(test)\n                      )\n        loss.backward(retain_graph = True)\n        optimizer.step()\n        t.set_description('Sphere intersection loss: {}'.format(loss.item()))\n    check = test < error_threshold\n    intersecting_ray = propagate_ray(ray[check == True], distance[check == True])\n    intersecting_normal = create_ray_from_two_points(\n                                                     sphere[:, 0:3],\n                                                     intersecting_ray[:, 0]\n                                                    )\n    return intersecting_ray, intersecting_normal, distance, check\n

The odak.learn.raytracing.intersect_w_sphere function uses an optimizer to identify intersection points for each ray. Instead, a function could have accomplished the task with a closed-form solution without iterating over the intersection test, which could have been much faster than the current function. If you are curious about how to fix the highlighted issue, you may want to see the challenge provided below.
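
For the curious, below is one possible closed-form sketch (not an Odak function; the name is illustrative). Substituting the propagated point into the sphere equation yields a quadratic in the propagation distance, and its discriminant tells us whether a ray hits the sphere at all:

import torch\n\n\ndef intersect_w_sphere_closed_form(ray, sphere):\n    # ray: [m x 2 x 3] with origins and unit direction cosines, sphere: [1 x 4] as [x, y, z, radius].\n    origins = ray[:, 0]\n    directions = ray[:, 1]\n    centers = sphere[:, 0:3]\n    radii = sphere[:, 3]\n    oc = origins - centers\n    b = 2. * torch.sum(directions * oc, dim = 1)\n    c = torch.sum(oc ** 2, dim = 1) - radii ** 2\n    discriminant = b ** 2 - 4. * c # the quadratic's leading coefficient is one for unit direction cosines\n    hit = discriminant >= 0.\n    distance = (-b - torch.sqrt(torch.clamp(discriminant, min = 0.))) / 2. # distance to the nearer intersection\n    return distance, hit\n\n\nray = torch.tensor([[[0., 0., 0.], [0., 0., 1.]]]) # a single ray along +Z from the origin\nsphere = torch.tensor([[0., 0., 5., 1.]]) # a unit radius sphere centered at (0., 0., 5.)\nprint(intersect_w_sphere_closed_form(ray, sphere)) # a distance of 4 units and a positive hit\n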

Let us examine how we can use the provided sphere intersection function with an example provided at the end of this subsection.

test_learn_ray_intersect_w_a_sphere.py
import sys\nimport odak\nimport torch\n\ndef test(output_directory = 'test_output'):\n    odak.tools.check_directory(output_directory)\n    starting_points, _, _, _ = odak.learn.tools.grid_sample(\n                                                            no = [5, 5],\n                                                            size = [3., 3.],\n                                                            center = [0., 0., 0.]\n                                                           )\n    end_points, _, _, _ = odak.learn.tools.grid_sample(\n                                                       no = [5, 5],\n                                                       size = [0.1, 0.1],\n                                                       center = [0., 0., 5.]\n                                                      )\n    rays = odak.learn.raytracing.create_ray_from_two_points(\n                                                            starting_points,\n                                                            end_points\n                                                           )\n    center = torch.tensor([[0., 0., 5.]])\n    radius = torch.tensor([[3.]])\n    sphere = odak.learn.raytracing.define_sphere(\n                                                 center = center,\n                                                 radius = radius\n                                                ) # (1)\n    intersecting_rays, intersecting_normals, _, check = odak.learn.raytracing.intersect_w_sphere(rays, sphere)\n\n\n    visualize = False # (2)\n    if visualize:\n        ray_diagram = odak.visualize.plotly.rayshow(line_width = 3., marker_size = 3.)\n        ray_diagram.add_point(rays[:, 0], color = 'blue')\n        ray_diagram.add_line(rays[:, 0][check == True], intersecting_rays[:, 0], color = 'blue')\n        ray_diagram.add_sphere(sphere, color = 'orange')\n        ray_diagram.add_point(intersecting_normals[:, 0], color = 'green')\n        html = ray_diagram.save_offline()\n        markdown_file = open('{}/ray.txt'.format(output_directory), 'w')\n        markdown_file.write(html)\n        markdown_file.close()\n    assert True == True\n\n\nif __name__ == '__main__':\n    sys.exit(test())\n
  1. Here we provide an example use case for odak.learn.raytracing.intersect_w_sphere by providing a sphere and a batch of sample rays.
  2. Set it to True to enable visualization.

Screenshot showing a sphere and ray intersections generated by the \"test_learn_ray_intersect_w_a_sphere.py\" script.

This section shows us how to operate with known geometric shapes, specifically spheres. However, not every shape can be defined using parametric modeling (e.g., nonlinearities such as discontinuities on a surface). We will look into another method in the next section, an approach used by folks working in Computer Graphics.

Challenge: Raytracing arbitrary surfaces

In light of the given information, we challenge readers to create a new function inside odak.learn.raytracing submodule that replaces the current intersect_w_sphere function. In addition, the current unit test test/test_learn_ray_intersect_w_a_sphere.py has to adopt this new function. odak.learn.raytracing submodule also needs new functions for supporting arbitrary surfaces (parametric). New unit tests are needed to improve the submodule accordingly. To add these to odak, you can rely on the pull request feature on GitHub. You can also create a new engineering note for arbitrary surfaces in docs/notes/raytracing_arbitrary_surfaces.md.

"},{"location":"course/geometric_optics/#intersecting-rays-with-meshes","title":"Intersecting rays with meshes","text":"

Informative \u00b7 Practical

Parametric surfaces provide ease in defining shapes and geometries in various fields, including Optics and Computer Graphics. However, not every object in a given scene can easily be described using parametric surfaces. In many cases, including modern Computer Graphics, triangles form the smallest primitive of an object or a shape. These triangles altogether form meshes that define objects and shapes. For this purpose, we will review the source code, parameters, and returns of three utility functions here. We will first review odak.learn.raytracing.intersect_w_surface to understand how one can calculate the intersection of a ray with a given plane. Later, we review the odak.learn.raytracing.is_it_on_triangle function, which checks if an intersection point on a given surface is inside a triangle on that surface. Finally, we will review the odak.learn.raytracing.intersect_w_triangle function. This last function combines both reviewed functions into a single function to identify the intersection between rays and a triangle.

odak.learn.raytracing.intersect_w_surface odak.learn.raytracing.is_it_on_triangle odak.learn.raytracing.intersect_w_triangle

Definition to find the intersection point between a surface and a ray. For more, see: http://geomalgorithms.com/a06-_intersect-2.html

Parameters:

  • ray \u2013
           A vector/ray.\n
  • points \u2013
           Set of points in X,Y and Z to define a planar surface.\n

Returns:

  • normal ( tensor ) \u2013

    Surface normal at the point of intersection.

  • distance ( float ) \u2013

    Distance between the starting point of a ray and its intersection with a planar surface.

Source code in odak/learn/raytracing/boundary.py
def intersect_w_surface(ray, points):\n    \"\"\"\n    Definition to find intersection point inbetween a surface and a ray. For more see: http://geomalgorithms.com/a06-_intersect-2.html\n\n    Parameters\n    ----------\n    ray          : torch.tensor\n                   A vector/ray.\n    points       : torch.tensor\n                   Set of points in X,Y and Z to define a planar surface.\n\n    Returns\n    ----------\n    normal       : torch.tensor\n                   Surface normal at the point of intersection.\n    distance     : float\n                   Distance in between starting point of a ray with it's intersection with a planar surface.\n    \"\"\"\n    normal = get_triangle_normal(points)\n    if len(ray.shape) == 2:\n        ray = ray.unsqueeze(0)\n    if len(points.shape) == 2:\n        points = points.unsqueeze(0)\n    if len(normal.shape) == 2:\n        normal = normal.unsqueeze(0)\n    f = normal[:, 0] - ray[:, 0]\n    distance = (torch.mm(normal[:, 1], f.T) / torch.mm(normal[:, 1], ray[:, 1].T)).T\n    new_normal = torch.zeros_like(ray)\n    new_normal[:, 0] = ray[:, 0] + distance * ray[:, 1]\n    new_normal[:, 1] = normal[:, 1]\n    new_normal = torch.nan_to_num(\n                                  new_normal,\n                                  nan = float('nan'),\n                                  posinf = float('nan'),\n                                  neginf = float('nan')\n                                 )\n    distance = torch.nan_to_num(\n                                distance,\n                                nan = float('nan'),\n                                posinf = float('nan'),\n                                neginf = float('nan')\n                               )\n    return new_normal, distance\n
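
To make the behaviour of odak.learn.raytracing.intersect_w_surface more tangible on its own, here is a minimal, hedged usage sketch: it intersects a small batch of rays with the plane z = 10, described by three arbitrary points on that plane. The grid sizes and plane points below are illustrative choices, not values required by the function.

import odak\nimport torch\n\nstarting_points, _, _, _ = odak.learn.tools.grid_sample(\n                                                        no = [3, 3],\n                                                        size = [2., 2.],\n                                                        center = [0., 0., 0.]\n                                                       )\nend_points, _, _, _ = odak.learn.tools.grid_sample(\n                                                   no = [3, 3],\n                                                   size = [2., 2.],\n                                                   center = [0., 0., 10.]\n                                                  )\nrays = odak.learn.raytracing.create_ray_from_two_points(starting_points, end_points)\nplane_points = torch.tensor([[\n                              [-5., -5., 10.],\n                              [ 5., -5., 10.],\n                              [ 0.,  5., 10.]\n                            ]]) # Three points describing the plane z = 10.\nnormals, distances = odak.learn.raytracing.intersect_w_surface(rays, plane_points)\nprint(normals[:, 0]) # Intersection points on the plane.\nprint(distances)     # Distance travelled by each ray until the plane.\n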

Definition to check if a given point is inside a triangle. If the given point is inside a defined triangle, this definition returns True. For more details, visit: https://blackpawn.com/texts/pointinpoly/.

Parameters:

  • point_to_check \u2013
              Point(s) to check.\n          Expected size is [3], [1 x 3] or [m x 3].\n
  • triangle \u2013
              Triangle described with three points.\n          Expected size is [3 x 3], [1 x 3 x 3] or [m x 3 x 3].\n

Returns:

  • result ( tensor ) \u2013

    Is it on a triangle? Returns True if the point is inside the given triangle and False otherwise. Expected size is [1] or [m] depending on the input.

Source code in odak/learn/raytracing/primitives.py
def is_it_on_triangle(point_to_check, triangle):\n    \"\"\"\n    Definition to check if a given point is inside a triangle. \n    If the given point is inside a defined triangle, this definition returns True.\n    For more details, visit: [https://blackpawn.com/texts/pointinpoly/](https://blackpawn.com/texts/pointinpoly/).\n\n    Parameters\n    ----------\n    point_to_check  : torch.tensor\n                      Point(s) to check.\n                      Expected size is [3], [1 x 3] or [m x 3].\n    triangle        : torch.tensor\n                      Triangle described with three points.\n                      Expected size is [3 x 3], [1 x 3 x 3] or [m x 3 x3].\n\n    Returns\n    -------\n    result          : torch.tensor\n                      Is it on a triangle? Returns NaN if condition not satisfied.\n                      Expected size is [1] or [m] depending on the input.\n    \"\"\"\n    if len(point_to_check.shape) == 1:\n        point_to_check = point_to_check.unsqueeze(0)\n    if len(triangle.shape) == 2:\n        triangle = triangle.unsqueeze(0)\n    v0 = triangle[:, 2] - triangle[:, 0]\n    v1 = triangle[:, 1] - triangle[:, 0]\n    v2 = point_to_check - triangle[:, 0]\n    if len(v0.shape) == 1:\n        v0 = v0.unsqueeze(0)\n    if len(v1.shape) == 1:\n        v1 = v1.unsqueeze(0)\n    if len(v2.shape) == 1:\n        v2 = v2.unsqueeze(0)\n    dot00 = torch.mm(v0, v0.T)\n    dot01 = torch.mm(v0, v1.T)\n    dot02 = torch.mm(v0, v2.T) \n    dot11 = torch.mm(v1, v1.T)\n    dot12 = torch.mm(v1, v2.T)\n    invDenom = 1. / (dot00 * dot11 - dot01 * dot01)\n    u = (dot11 * dot02 - dot01 * dot12) * invDenom\n    v = (dot00 * dot12 - dot01 * dot02) * invDenom\n    result = (u >= 0.) & (v >= 0.) & ((u + v) < 1)\n    return result\n

Definition to find intersection point of a ray with a triangle.

Parameters:

  • ray \u2013
                  A ray [1 x 2 x 3] or a batch of ray [m x 2 x 3].\n
  • triangle \u2013
                  Set of points in X,Y and Z to define a single triangle [1 x 3 x 3].\n

Returns:

  • normal ( tensor ) \u2013

    Surface normal at the point of intersection with the surface of triangle. This could also involve surface normals that are not on the triangle. Expected size is [1 x 2 x 3] or [m x 2 x 3] depending on the input.

  • distance ( float ) \u2013

    Distance between the starting point of a ray and the intersection point with a given triangle. Expected size is [1 x 1] or [m x 1] depending on the input.

  • intersecting_ray ( tensor ) \u2013

    Rays that intersect with the triangle plane and on the triangle. Expected size is [1 x 2 x 3] or [m x 2 x 3] depending on the input.

  • intersecting_normal ( tensor ) \u2013

    Normals that intersect with the triangle plane and on the triangle. Expected size is [1 x 2 x 3] or [m x 2 x 3] depending on the input.

  • check ( tensor ) \u2013

    A boolean tensor providing True or False for each ray used as input, testing whether that ray lands on the given triangle. Expected size is [1] or [m].

Source code in odak/learn/raytracing/boundary.py
def intersect_w_triangle(ray, triangle):\n    \"\"\"\n    Definition to find intersection point of a ray with a triangle. \n\n    Parameters\n    ----------\n    ray                 : torch.tensor\n                          A ray [1 x 2 x 3] or a batch of ray [m x 2 x 3].\n    triangle            : torch.tensor\n                          Set of points in X,Y and Z to define a single triangle [1 x 3 x 3].\n\n    Returns\n    ----------\n    normal              : torch.tensor\n                          Surface normal at the point of intersection with the surface of triangle.\n                          This could also involve surface normals that are not on the triangle.\n                          Expected size is [1 x 2 x 3] or [m x 2 x 3] depending on the input.\n    distance            : float\n                          Distance in between a starting point of a ray and the intersection point with a given triangle.\n                          Expected size is [1 x 1] or [m x 1] depending on the input.\n    intersecting_ray    : torch.tensor\n                          Rays that intersect with the triangle plane and on the triangle.\n                          Expected size is [1 x 2 x 3] or [m x 2 x 3] depending on the input.\n    intersecting_normal : torch.tensor\n                          Normals that intersect with the triangle plane and on the triangle.\n                          Expected size is [1 x 2 x 3] or [m x 2 x 3] depending on the input.\n    check               : torch.tensor\n                          A list that provides a bool as True or False for each ray used as input.\n                          A test to see is a ray could be on the given triangle.\n                          Expected size is [1] or [m].\n    \"\"\"\n    if len(triangle.shape) == 2:\n       triangle = triangle.unsqueeze(0)\n    if len(ray.shape) == 2:\n       ray = ray.unsqueeze(0)\n    normal, distance = intersect_w_surface(ray, triangle)\n    check = is_it_on_triangle(normal[:, 0], triangle)\n    intersecting_ray = ray.unsqueeze(0)\n    intersecting_ray = intersecting_ray.repeat(triangle.shape[0], 1, 1, 1)\n    intersecting_ray = intersecting_ray[check == True]\n    intersecting_normal = normal.unsqueeze(0)\n    intersecting_normal = intersecting_normal.repeat(triangle.shape[0], 1, 1, 1)\n    intersecting_normal = intersecting_normal[check ==  True]\n    return normal, distance, intersecting_ray, intersecting_normal, check\n

Using the provided utility functions above, let us build an example below that helps us find intersections between a triangle and a batch of rays.

test_learn_ray_intersect_w_a_triangle.py
import sys\nimport odak\nimport torch\n\n\ndef test(output_directory = 'test_output'):\n    odak.tools.check_directory(output_directory)\n    starting_points, _, _, _ = odak.learn.tools.grid_sample(\n                                                            no = [5, 5],\n                                                            size = [10., 10.],\n                                                            center = [0., 0., 0.]\n                                                           )\n    end_points, _, _, _ = odak.learn.tools.grid_sample(\n                                                       no = [5, 5],\n                                                       size = [6., 6.],\n                                                       center = [0., 0., 10.]\n                                                      )\n    rays = odak.learn.raytracing.create_ray_from_two_points(\n                                                            starting_points,\n                                                            end_points\n                                                           )\n    triangle = torch.tensor([[\n                              [-5., -5., 10.],\n                              [ 5., -5., 10.],\n                              [ 0.,  5., 10.]\n                            ]])\n    normals, distance, _, _, check = odak.learn.raytracing.intersect_w_triangle(\n                                                                                rays,\n                                                                                triangle\n                                                                               ) # (2)\n\n\n\n    visualize = False # (1)\n    if visualize:\n        ray_diagram = odak.visualize.plotly.rayshow(line_width = 3., marker_size = 3.) # (1)\n        ray_diagram.add_triangle(triangle, color = 'orange')\n        ray_diagram.add_point(rays[:, 0], color = 'blue')\n        ray_diagram.add_line(rays[:, 0], normals[:, 0], color = 'blue')\n        colors = []\n        for color_id in range(check.shape[1]):\n            if check[0, color_id] == True:\n                colors.append('green')\n            elif check[0, color_id] == False:\n                colors.append('red')\n        ray_diagram.add_point(normals[:, 0], color = colors)\n        html = ray_diagram.save_offline()\n        markdown_file = open('{}/ray.txt'.format(output_directory), 'w')\n        markdown_file.write(html)\n        markdown_file.close()\n    assert True == True\n\n\nif __name__ == '__main__':\n    sys.exit(test())\n
  1. Set visualize = True to run the visualization.
  2. Returning intersection normals as new rays, distances from the starting points of the input rays, and a check that returns True if an intersection point is inside the triangle.
Why should we be interested in ray and triangle intersections?

Modern Computer Graphics uses various representations for defining three-dimensional objects and scenes. These representations include:

  • Point Clouds: a series of XYZ coordinates sampled from the surface of a three-dimensional object,
  • Meshes: a soup of triangles that represents the surface of a three-dimensional object,
  • Signed Distance Functions: a function informing about the distance between an XYZ point and the surface of a three-dimensional object,
  • Neural Radiance Fields: a machine learning approach to learning ray patterns from various perspectives.

Historically, meshes have mainly been used to represent three-dimensional objects. Thus, intersecting rays and triangles is important for most of Computer Graphics.

Challenge: Many triangles!

The example provided above deals with a single triangle and a batch of rays. However, objects represented with triangles are typically described with many triangles, not one. Note that odak.learn.raytracing.intersect_w_triangle deals with one triangle at a time, which may lead to slow execution times as the function has to visit each triangle separately. Given this information, we challenge readers to create a new function inside the odak.learn.raytracing submodule named intersect_w_mesh (see the sketch after this paragraph for a possible starting point). This new function has to be able to work with multiple triangles (meshes) and has to be aware of \"occlusions\" (e.g., a triangle blocking another triangle). In addition, a new unit test, test/test_learn_ray_intersect_w_mesh.py, has to adopt this new function. To add these to odak, you can rely on the pull request feature on GitHub. You can also create a new engineering note for meshes in docs/notes/raytracing_meshes.md.
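
Below is a naive, hedged sketch of how one might loop over triangles and keep only the closest valid hit per ray to account for occlusions. The function name intersect_w_mesh and its return signature are our assumptions, not odak's existing API, and a practical implementation would likely vectorize this loop.

import odak\nimport torch\n\n\ndef intersect_w_mesh(rays, triangles):\n    # Naive sketch: rays as [m x 2 x 3], triangles as [n x 3 x 3].\n    m = rays.shape[0]\n    best_distance = torch.full((m,), float('inf'))\n    best_normal = torch.zeros(m, 2, 3)\n    hit = torch.zeros(m, dtype = torch.bool)\n    for triangle in triangles:\n        normal, distance, _, _, check = odak.learn.raytracing.intersect_w_triangle(\n                                                                                   rays,\n                                                                                   triangle.unsqueeze(0)\n                                                                                  )\n        distance = distance.reshape(m)\n        check = check.reshape(m)\n        closer = check & (distance > 0.) & (distance < best_distance) # Keep only the nearest hit per ray.\n        best_distance[closer] = distance[closer]\n        best_normal[closer] = normal[closer]\n        hit = hit | closer\n    return best_normal, best_distance, hit\n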

"},{"location":"course/geometric_optics/#refracting-and-reflecting-rays","title":"Refracting and reflecting rays","text":"

Informative \u00b7 Practical

In the previous subsections, we reviewed ray intersections with various surface representations, including parametric (e.g., spheres) and non-parametric (e.g., meshes) ones. Please remember that raytracing is the simplest model of light. Thus, raytracing often does not account for any wave- or quantum-related nature of light. As we know, light refracts, reflects, or diffracts when it interfaces with a surface or, in other words, a changing medium (e.g., light traveling from air to glass). In that case, our next step should be identifying a methodology that helps us model these events using rays. We compiled two utility functions that can help us model a refraction or a reflection. These functions are named odak.learn.raytracing.refract 1 and odak.learn.raytracing.reflect 1. The first one, odak.learn.raytracing.refract, follows Snell's law of refraction, while odak.learn.raytracing.reflect follows the perfect reflection case. We will not go into the details of this theory, as the simple form in which we discuss it could now be considered common knowledge. However, for curious readers, the work by Bell et al. 5 provides a generalized solution for the laws of refraction and reflection. Let us carefully examine these two utility functions to understand their internal workings.

odak.learn.raytracing.refract odak.learn.raytracing.reflect

Definition to refract an incoming ray. Used method described in G.H. Spencer and M.V.R.K. Murty, \"General Ray-Tracing Procedure\", 1961.

Parameters:

  • vector \u2013
             Incoming ray.\n         Expected size is [2, 3], [1, 2, 3] or [m, 2, 3].\n
  • normvector \u2013
             Normal vector.\n         Expected size is [2, 3], [1, 2, 3] or [m, 2, 3].\n
  • n1 \u2013
             Refractive index of the incoming medium.\n
  • n2 \u2013
             Refractive index of the outgoing medium.\n
  • error \u2013
             Desired error.\n

Returns:

  • output ( tensor ) \u2013

    Refracted ray. Expected size is [1, 2, 3]

Source code in odak/learn/raytracing/boundary.py
def refract(vector, normvector, n1, n2, error = 0.01):\n    \"\"\"\n    Definition to refract an incoming ray.\n    Used method described in G.H. Spencer and M.V.R.K. Murty, \"General Ray-Tracing Procedure\", 1961.\n\n\n    Parameters\n    ----------\n    vector         : torch.tensor\n                     Incoming ray.\n                     Expected size is [2, 3], [1, 2, 3] or [m, 2, 3].\n    normvector     : torch.tensor\n                     Normal vector.\n                     Expected size is [2, 3], [1, 2, 3] or [m, 2, 3]].\n    n1             : float\n                     Refractive index of the incoming medium.\n    n2             : float\n                     Refractive index of the outgoing medium.\n    error          : float \n                     Desired error.\n\n    Returns\n    -------\n    output         : torch.tensor\n                     Refracted ray.\n                     Expected size is [1, 2, 3]\n    \"\"\"\n    if len(vector.shape) == 2:\n        vector = vector.unsqueeze(0)\n    if len(normvector.shape) == 2:\n        normvector = normvector.unsqueeze(0)\n    mu    = n1 / n2\n    div   = normvector[:, 1, 0] ** 2  + normvector[:, 1, 1] ** 2 + normvector[:, 1, 2] ** 2\n    a     = mu * (vector[:, 1, 0] * normvector[:, 1, 0] + vector[:, 1, 1] * normvector[:, 1, 1] + vector[:, 1, 2] * normvector[:, 1, 2]) / div\n    b     = (mu ** 2 - 1) / div\n    to    = - b * 0.5 / a\n    num   = 0\n    eps   = torch.ones(vector.shape[0], device = vector.device) * error * 2\n    while len(eps[eps > error]) > 0:\n       num   += 1\n       oldto  = to\n       v      = to ** 2 + 2 * a * to + b\n       deltav = 2 * (to + a)\n       to     = to - v / deltav\n       eps    = abs(oldto - to)\n    output = torch.zeros_like(vector)\n    output[:, 0, 0] = normvector[:, 0, 0]\n    output[:, 0, 1] = normvector[:, 0, 1]\n    output[:, 0, 2] = normvector[:, 0, 2]\n    output[:, 1, 0] = mu * vector[:, 1, 0] + to * normvector[:, 1, 0]\n    output[:, 1, 1] = mu * vector[:, 1, 1] + to * normvector[:, 1, 1]\n    output[:, 1, 2] = mu * vector[:, 1, 2] + to * normvector[:, 1, 2]\n    return output\n

Definition to reflect an incoming ray from a surface defined by a surface normal. Used method described in G.H. Spencer and M.V.R.K. Murty, \"General Ray-Tracing Procedure\", 1961.

Parameters:

  • input_ray \u2013
           A ray or rays.\n       Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].\n
  • normal \u2013
           A surface normal(s).\n       Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].\n

Returns:

  • output_ray ( tensor ) \u2013

    Array that contains starting points and cosines of a reflected ray. Expected size is [1 x 2 x 3] or [m x 2 x 3].

Source code in odak/learn/raytracing/boundary.py
def reflect(input_ray, normal):\n    \"\"\" \n    Definition to reflect an incoming ray from a surface defined by a surface normal. \n    Used method described in G.H. Spencer and M.V.R.K. Murty, \"General Ray-Tracing Procedure\", 1961.\n\n\n    Parameters\n    ----------\n    input_ray    : torch.tensor\n                   A ray or rays.\n                   Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].\n    normal       : torch.tensor\n                   A surface normal(s).\n                   Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].\n\n    Returns\n    ----------\n    output_ray   : torch.tensor\n                   Array that contains starting points and cosines of a reflected ray.\n                   Expected size is [1 x 2 x 3] or [m x 2 x 3].\n    \"\"\"\n    if len(input_ray.shape) == 2:\n        input_ray = input_ray.unsqueeze(0)\n    if len(normal.shape) == 2:\n        normal = normal.unsqueeze(0)\n    mu = 1\n    div = normal[:, 1, 0]**2 + normal[:, 1, 1]**2 + normal[:, 1, 2]**2 + 1e-8\n    a = mu * (input_ray[:, 1, 0] * normal[:, 1, 0] + input_ray[:, 1, 1] * normal[:, 1, 1] + input_ray[:, 1, 2] * normal[:, 1, 2]) / div\n    a = a.unsqueeze(1)\n    n = int(torch.amax(torch.tensor([normal.shape[0], input_ray.shape[0]])))\n    output_ray = torch.zeros((n, 2, 3)).to(input_ray.device)\n    output_ray[:, 0] = normal[:, 0]\n    output_ray[:, 1] = input_ray[:, 1] - 2 * a * normal[:, 1]\n    return output_ray\n

Please note that we provide two refractive indices as inputs to odak.learn.raytracing.refract. These inputs represent the refractive indices of two mediums (e.g., air and glass). However, the refractive index of a medium depends on light's wavelength (color). In the following example, where we showcase a sample use case of these utility functions, we will assume that light has a single wavelength. But bear in mind that when you need to raytrace with many wavelengths (multi-color RGB or hyperspectral), one must raytrace for each wavelength (color) separately (see the short dispersion sketch after the example below). Thus, the computational complexity of raytracing increases dramatically as we aim for growing realism in our simulations (e.g., describing scenes per color, raytracing for each color). Let us dive deep into how we use these functions in an actual example by observing the example below.

test_learn_ray_refract_reflect.py
import sys\nimport odak\nimport torch\n\ndef test(output_directory = 'test_output'):\n    odak.tools.check_directory(output_directory)\n    starting_points, _, _, _ = odak.learn.tools.grid_sample(\n                                                            no = [5, 5],\n                                                            size = [15., 15.],\n                                                            center = [0., 0., 0.]\n                                                           )\n    end_points, _, _, _ = odak.learn.tools.grid_sample(\n                                                       no = [5, 5],\n                                                       size = [6., 6.],\n                                                       center = [0., 0., 10.]\n                                                      )\n    rays = odak.learn.raytracing.create_ray_from_two_points(\n                                                            starting_points,\n                                                            end_points\n                                                           )\n    triangle = torch.tensor([[\n                              [-5., -5., 10.],\n                              [ 5., -5., 10.],\n                              [ 0.,  5., 10.]\n                            ]])\n    normals, distance, intersecting_rays, intersecting_normals, check = odak.learn.raytracing.intersect_w_triangle(\n                                                                                    rays,\n                                                                                    triangle\n                                                                                   ) \n    n_air = 1.0 # (1)\n    n_glass = 1.51 # (2)\n    refracted_rays = odak.learn.raytracing.refract(intersecting_rays, intersecting_normals, n_air, n_glass) # (3)\n    reflected_rays = odak.learn.raytracing.reflect(intersecting_rays, intersecting_normals) # (4)\n    refract_distance = 11.\n    reflect_distance = 7.2\n    propagated_refracted_rays = odak.learn.raytracing.propagate_ray(\n                                                                    refracted_rays, \n                                                                    torch.ones(refracted_rays.shape[0]) * refract_distance\n                                                                   )\n    propagated_reflected_rays = odak.learn.raytracing.propagate_ray(\n                                                                    reflected_rays,\n                                                                    torch.ones(reflected_rays.shape[0]) * reflect_distance\n                                                                   )\n\n\n\n    visualize = False\n    if visualize:\n        ray_diagram = odak.visualize.plotly.rayshow(\n                                                    columns = 2,\n                                                    line_width = 3.,\n                                                    marker_size = 3.,\n                                                    subplot_titles = ['Refraction example', 'Reflection example']\n                                                   ) # (1)\n        ray_diagram.add_triangle(triangle, column = 1, color = 'orange')\n        ray_diagram.add_triangle(triangle, column = 2, color = 'orange')\n        ray_diagram.add_point(rays[:, 0], column = 1, color = 'blue')\n        ray_diagram.add_point(rays[:, 0], column = 2, color = 'blue')\n        ray_diagram.add_line(rays[:, 0], normals[:, 0], 
column = 1, color = 'blue')\n        ray_diagram.add_line(rays[:, 0], normals[:, 0], column = 2, color = 'blue')\n        ray_diagram.add_line(refracted_rays[:, 0], propagated_refracted_rays[:, 0], column = 1, color = 'blue')\n        ray_diagram.add_line(reflected_rays[:, 0], propagated_reflected_rays[:, 0], column = 2, color = 'blue')\n        colors = []\n        for color_id in range(check.shape[1]):\n            if check[0, color_id] == True:\n                colors.append('green')\n            elif check[0, color_id] == False:\n                colors.append('red')\n        ray_diagram.add_point(normals[:, 0], column = 1, color = colors)\n        ray_diagram.add_point(normals[:, 0], column = 2, color = colors)\n        html = ray_diagram.save_offline()\n        markdown_file = open('{}/ray.txt'.format(output_directory), 'w')\n        markdown_file.write(html)\n        markdown_file.close()\n    assert True == True\n\n\nif __name__ == '__main__':\n    sys.exit(test())\n
  1. Refractive index of air (arbitrary and regardless of wavelength), the medium before the ray and triangle intersection.
  2. Refractive index of glass (arbitrary and regardless of wavelength), the medium after the ray and triangle intersection.
  3. Refraction process.
  4. Reflection process.
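
Since odak.learn.raytracing.refract accepts a single pair of refractive indices, one straightforward way to approximate dispersion is to repeat the refraction once per wavelength with a wavelength-dependent index. The sketch below does exactly that for a single oblique ray; the three refractive index values are made-up, illustrative numbers rather than measured glass data.

import odak\nimport torch\n\nstart = torch.tensor([[0., 0., 0.]])\nend = torch.tensor([[10., 0., 10.]])\nray = odak.learn.raytracing.create_ray_from_two_points(start, end) # A ray at 45 degrees.\nsurface_normal = torch.tensor([[[10., 0., 10.], [0., 0., -1.]]]) # Normal at the hit point on z = 10.\nn_air = 1.0\nrefractive_indices = {'red': 1.50, 'green': 1.51, 'blue': 1.53} # Illustrative values only.\nfor color, n_glass in refractive_indices.items():\n    refracted_ray = odak.learn.raytracing.refract(ray, surface_normal, n_air, n_glass)\n    print(color, refracted_ray[:, 1]) # Direction cosines differ slightly per wavelength.\n
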
Challenge: Diffracted and interfering rays

This subsection covered simulating refraction and reflection events. However, diffraction or interference 6 is not introduced in this raytracing model. This is because diffraction and interference would require another layer of complication. In other words, rays would have to carry an extra dimension beyond their starting points and direction cosines; they would also have to carry a quantity known as the phase of light. This fact makes a typical ray have dimensions of [1 x 3 x 3] instead of [1 x 2 x 3], where only direction cosines and starting points are defined. Given this information, we challenge readers to create a new submodule, odak.learn.raytracing.diffraction, extending rays to diffraction and interference. In addition, a new set of unit tests should be derived to adopt this new submodule. To add these to odak, you can rely on the pull request feature on GitHub. You can also create a new engineering note for diffraction and interference in docs/notes/raytracing_diffraction_interference.md.

"},{"location":"course/geometric_optics/#optimization-with-rays","title":"Optimization with rays","text":"

Informative \u00b7 Practical

We learned about refraction, reflection, rays, and surface intersections in the previous subsections. We didn't mention it then, but these functions are differentiable 7. In other words, a modern machine learning library can keep a graph of a variable passing through each one of these functions (see the chain rule). This differentiability feature is vital because it makes our raytracing-based light simulations compatible with modern machine learning frameworks such as Torch. In this subsection, we will use an off-the-shelf optimizer from Torch to optimize variables in our raytracing simulations. In the first example, we will see that the optimizer helps us define the proper tilt angles for a triangle-shaped mirror so that it redirects light from a point light source towards a given target. Our first example resembles a straightforward case for optimization by containing only a batch of rays and a single triangle. The problem highlighted in the first example has a closed-form solution, and using an optimizer is obviously overkill. We want our readers to understand that the first example is a warm-up scenario, helping them understand how to interact with rays and triangles in the context of an optimization problem. In our second example, we will deal with a more sophisticated case, where a batch of rays arriving from a point light source bounces off a surface with multiple triangles (in other words, a mesh) and arrives at some point on our final target plane. This time, we will ask our optimizer to optimize the shape of our triangles so that most of the light bouncing off this optimized surface ends up at a location close to a target we define in our simulation. This way, we show our readers that a more sophisticated shape could be optimized using our framework, Odak. In real life, the second example could be a lens or mirror shape to be optimized. More specifically, as an application example, it could be a mirror or a lens that focuses light from the Sun onto a solar cell to increase the efficiency of a solar power system, or it could be a lens helping you focus at a specific depth given your eye prescription. Let us start with our first example and examine how we can tilt a surface using an optimizer; in the second example, let us see how an optimizer helps us define and optimize the shape of a given mesh.

test_learn_ray_optimization.py
import sys\nimport odak\nimport torch\nfrom tqdm import tqdm\n\n\ndef test(output_directory = 'test_output'):\n    odak.tools.check_directory(output_directory)\n    final_surface = torch.tensor([[\n                                   [-5., -5., 0.],\n                                   [ 5., -5., 0.],\n                                   [ 0.,  5., 0.]\n                                 ]])\n    final_target = torch.tensor([[3., 3., 0.]])\n    triangle = torch.tensor([\n                             [-5., -5., 10.],\n                             [ 5., -5., 10.],\n                             [ 0.,  5., 10.]\n                            ])\n    starting_points, _, _, _ = odak.learn.tools.grid_sample(\n                                                            no = [5, 5],\n                                                            size = [1., 1.],\n                                                            center = [0., 0., 0.]\n                                                           )\n    end_point = odak.learn.raytracing.center_of_triangle(triangle)\n    rays = odak.learn.raytracing.create_ray_from_two_points(\n                                                            starting_points,\n                                                            end_point\n                                                           )\n    angles = torch.zeros(1, 3, requires_grad = True)\n    learning_rate = 2e-1\n    optimizer = torch.optim.Adam([angles], lr = learning_rate)\n    loss_function = torch.nn.MSELoss()\n    number_of_steps = 100\n    t = tqdm(range(number_of_steps), leave = False, dynamic_ncols = True)\n    for step in t:\n        optimizer.zero_grad()\n        rotated_triangle, _, _, _ = odak.learn.tools.rotate_points(\n                                                                   triangle, \n                                                                   angles = angles, \n                                                                   origin = end_point\n                                                                  )\n        _, _, intersecting_rays, intersecting_normals, check = odak.learn.raytracing.intersect_w_triangle(\n                                                                                                          rays,\n                                                                                                          rotated_triangle\n                                                                                                         )\n        reflected_rays = odak.learn.raytracing.reflect(intersecting_rays, intersecting_normals)\n        final_normals, _ = odak.learn.raytracing.intersect_w_surface(reflected_rays, final_surface)\n        if step == 0:\n            start_rays = rays.detach().clone()\n            start_rotated_triangle = rotated_triangle.detach().clone()\n            start_intersecting_rays = intersecting_rays.detach().clone()\n            start_intersecting_normals = intersecting_normals.detach().clone()\n            start_final_normals = final_normals.detach().clone()\n        final_points = final_normals[:, 0]\n        target = final_target.repeat(final_points.shape[0], 1)\n        loss = loss_function(final_points, target)\n        loss.backward(retain_graph = True)\n        optimizer.step()\n        t.set_description('Loss: {}'.format(loss.item()))\n    print('Loss: {}, angles: {}'.format(loss.item(), angles))\n\n\n    visualize = False\n    if visualize:\n        ray_diagram = odak.visualize.plotly.rayshow(\n                   
                                 columns = 2,\n                                                    line_width = 3.,\n                                                    marker_size = 3.,\n                                                    subplot_titles = [\n                                                                       'Surace before optimization', \n                                                                       'Surface after optimization',\n                                                                       'Hits at the target plane before optimization',\n                                                                       'Hits at the target plane after optimization',\n                                                                     ]\n                                                   ) \n        ray_diagram.add_triangle(start_rotated_triangle, column = 1, color = 'orange')\n        ray_diagram.add_triangle(rotated_triangle, column = 2, color = 'orange')\n        ray_diagram.add_point(start_rays[:, 0], column = 1, color = 'blue')\n        ray_diagram.add_point(rays[:, 0], column = 2, color = 'blue')\n        ray_diagram.add_line(start_intersecting_rays[:, 0], start_intersecting_normals[:, 0], column = 1, color = 'blue')\n        ray_diagram.add_line(intersecting_rays[:, 0], intersecting_normals[:, 0], column = 2, color = 'blue')\n        ray_diagram.add_line(start_intersecting_normals[:, 0], start_final_normals[:, 0], column = 1, color = 'blue')\n        ray_diagram.add_line(start_intersecting_normals[:, 0], final_normals[:, 0], column = 2, color = 'blue')\n        ray_diagram.add_point(final_target, column = 1, color = 'red')\n        ray_diagram.add_point(final_target, column = 2, color = 'green')\n        html = ray_diagram.save_offline()\n        markdown_file = open('{}/ray.txt'.format(output_directory), 'w')\n        markdown_file.write(html)\n        markdown_file.close()\n    assert True == True\n\n\nif __name__ == '__main__':\n    sys.exit(test())\n

Let us also look into the more sophisticated second example, where a triangular mesh is optimized to meet a specific demand, redirecting rays to a particular target.

test_learn_ray_mesh.py
import sys\nimport odak\nimport torch\nfrom tqdm import tqdm\n\n\ndef test(output_directory = 'test_output'):\n    odak.tools.check_directory(output_directory)\n    device = torch.device('cpu')\n    final_target = torch.tensor([-2., -2., 10.], device = device)\n    final_surface = odak.learn.raytracing.define_plane(point = final_target)\n    mesh = odak.learn.raytracing.planar_mesh(\n                                             size = torch.tensor([1.1, 1.1]), \n                                             number_of_meshes = torch.tensor([9, 9]), \n                                             device = device\n                                            )\n    start_points, _, _, _ = odak.learn.tools.grid_sample(\n                                                         no = [11, 11],\n                                                         size = [1., 1.],\n                                                         center = [2., 2., 10.]\n                                                        )\n    end_points, _, _, _ = odak.learn.tools.grid_sample(\n                                                       no = [11, 11],\n                                                       size = [1., 1.],\n                                                       center = [0., 0., 0.]\n                                                      )\n    start_points = start_points.to(device)\n    end_points = end_points.to(device)\n    loss_function = torch.nn.MSELoss(reduction = 'sum')\n    learning_rate = 2e-3\n    optimizer = torch.optim.AdamW([mesh.heights], lr = learning_rate)\n    rays = odak.learn.raytracing.create_ray_from_two_points(start_points, end_points)\n    number_of_steps = 100\n    t = tqdm(range(number_of_steps), leave = False, dynamic_ncols = True)\n    for step in t:\n        optimizer.zero_grad()\n        triangles = mesh.get_triangles()\n        reflected_rays, reflected_normals = mesh.mirror(rays)\n        final_normals, _ = odak.learn.raytracing.intersect_w_surface(reflected_rays, final_surface)\n        final_points = final_normals[:, 0]\n        target = final_target.repeat(final_points.shape[0], 1)\n        if step == 0:\n            start_triangles = triangles.detach().clone()\n            start_reflected_rays = reflected_rays.detach().clone()\n            start_final_normals = final_normals.detach().clone()\n        loss = loss_function(final_points, target)\n        loss.backward(retain_graph = True)\n        optimizer.step() \n        description = 'Loss: {}'.format(loss.item())\n        t.set_description(description)\n    print(description)\n\n\n    visualize = False\n    if visualize:\n        ray_diagram = odak.visualize.plotly.rayshow(\n                                                    rows = 1,\n                                                    columns = 2,\n                                                    line_width = 3.,\n                                                    marker_size = 1.,\n                                                    subplot_titles = ['Before optimization', 'After optimization']\n                                                   ) \n        for triangle_id in range(triangles.shape[0]):\n            ray_diagram.add_triangle(\n                                     start_triangles[triangle_id], \n                                     row = 1, \n                                     column = 1, \n                                     color = 'orange'\n                                    )\n            ray_diagram.add_triangle(triangles[triangle_id], 
row = 1, column = 2, color = 'orange')\n        html = ray_diagram.save_offline()\n        markdown_file = open('{}/ray.txt'.format(output_directory), 'w')\n        markdown_file.write(html)\n        markdown_file.close()\n    assert True == True\n\n\nif __name__ == '__main__':\n    sys.exit(test())\n
Challenge: Differentiable detector

In our examples, where we try bouncing light towards a fixed target, our target is defined as a single point along the XYZ axes. However, in many cases in Optics and Computer Graphics, we may want to design surfaces to resemble a specific distribution of intensities over a plane (e.g., a detector or a camera sensor). For example, the work by Schwartzburg et al. 8 designs optical surfaces such that when light refracts through them, the distribution of intensities forms an image at some target plane. To be able to replicate such works with Odak, odak needs a detector that is differentiable. This detector could be added as a class in the odak.learn.raytracing submodule, and a new unit test could be added as test/test_learn_detector.py. To add these to odak, you can rely on the pull request feature on GitHub.
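
To give a flavour of what such a differentiable detector could look like, here is a minimal, hedged sketch written with plain torch. It splats the intersection points onto a pixel grid with soft Gaussian footprints so that gradients can flow from an image-space loss back to the point locations. The function name soft_detector, the resolution, the physical size, and sigma are all illustrative assumptions, not odak's API.

import torch\n\n\ndef soft_detector(points, resolution = [100, 100], size = [10., 10.], sigma = 0.1):\n    # Splats the XY locations of intersection points onto a pixel grid with Gaussian weights.\n    xs = torch.linspace(-size[0] / 2., size[0] / 2., resolution[0])\n    ys = torch.linspace(-size[1] / 2., size[1] / 2., resolution[1])\n    grid_x, grid_y = torch.meshgrid(xs, ys, indexing = 'ij')\n    dx = points[:, 0].reshape(-1, 1, 1) - grid_x.unsqueeze(0)\n    dy = points[:, 1].reshape(-1, 1, 1) - grid_y.unsqueeze(0)\n    weights = torch.exp(-(dx ** 2 + dy ** 2) / (2. * sigma ** 2))\n    return weights.sum(dim = 0) # An intensity image that is differentiable with respect to points.\n\n\npoints = torch.randn(25, 3, requires_grad = True) # Stand-in for ray-plane intersection points.\nimage = soft_detector(points)\nloss = image.var() # Any image-space loss works here.\nloss.backward() # Gradients now reach the intersection point locations.\n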

"},{"location":"course/geometric_optics/#rendering-scenes","title":"Rendering scenes","text":"

Informative \u00b7 Practical

This section shows how one can use raytracing for rendering purposes in Computer Graphics. Note that the provided example is simple, aiming to introduce a newcomer to how raytracing could be used for rendering. The example uses a single perspective camera and relies on a concept called splatting, where rays originate from a camera towards a scene. The scene is composed of randomly colored triangles, and each time a ray hits one of these randomly colored triangles, the corresponding pixel of our perspective camera is painted with the color of that triangle. Let us review our simple example by reading the code and observing its outcome.

test_learn_ray_render.py
import sys\nimport odak\nimport torch\nfrom tqdm import tqdm\n\ndef test(output_directory = 'test_output'):\n    odak.tools.check_directory(output_directory)\n    final_surface_point = torch.tensor([0., 0., 10.])\n    final_surface = odak.learn.raytracing.define_plane(point = final_surface_point)\n    no = [500, 500]\n    start_points, _, _, _ = odak.learn.tools.grid_sample(\n                                                         no = no,\n                                                         size = [10., 10.],\n                                                         center = [0., 0., -10.]\n                                                        )\n    end_point = torch.tensor([0., 0., 0.])\n    rays = odak.learn.raytracing.create_ray_from_two_points(start_points, end_point)\n    mesh = odak.learn.raytracing.planar_mesh(\n                                             size = torch.tensor([10., 10.]),\n                                             number_of_meshes = torch.tensor([40, 40]),\n                                             angles = torch.tensor([  0., -70., 0.]),\n                                             offset = torch.tensor([ -2.,   0., 5.]),\n                                            )\n    triangles = mesh.get_triangles()\n    play_button = torch.tensor([[\n                                 [  1.,  0.5, 3.],\n                                 [  0.,  0.5, 3.],\n                                 [ 0.5, -0.5, 3.],\n                                ]])\n    triangles = torch.cat((play_button, triangles), dim = 0)\n    background_color = torch.rand(3)\n    triangles_color = torch.rand(triangles.shape[0], 3)\n    image = torch.zeros(rays.shape[0], 3) \n    for triangle_id, triangle in enumerate(triangles):\n        _, _, _, _, check = odak.learn.raytracing.intersect_w_triangle(rays, triangle)\n        check = check.squeeze(0).unsqueeze(-1).repeat(1, 3)\n        color = triangles_color[triangle_id].unsqueeze(0).repeat(check.shape[0], 1)\n        image[check == True] = color[check == True] * check[check == True]\n    image[image == [0., 0., 0]] = background_color\n    image = image.view(no[0], no[1], 3)\n    odak.learn.tools.save_image('{}/image.png'.format(output_directory), image, cmin = 0., cmax = 1.)\n    assert True == True\n\n\nif __name__ == '__main__':\n    sys.exit(test())\n

Rendered result for the renderer script of \"/test/test_learn_ray_render.py\".

A modern raytracer used in gaming is far more sophisticated than the example we provide here. There are aspects such as material properties, tracing rays from their sources to a camera, or allowing rays to interface with multiple materials. Covering these aspects in a crash course like the one we provide here would take much work. Instead, we suggest our readers follow the resources provided in other classes, the references provided at the end, or any other materials available online.

"},{"location":"course/geometric_optics/#conclusion","title":"Conclusion","text":"

Informative

We can simulate light on a computer using various methods. We explained \"raytracing\" as one of these methods. Often, raytracing deals with light intensities, omitting many other aspects of light, like the phase or polarization of light. In addition, sending the right number of rays from a light source into a scene is always a struggle in raytracing, as it remains an outstanding sampling problem. Raytracing has created many success stories in gaming (e.g., NVIDIA RTX or AMD Radeon Rays) and optical component design (e.g., Zemax or Ansys Speos).

Overall, we covered a basic introduction to modeling light as rays and using rays to optimize against a given target. Note that our examples resemble simple cases. This section aims to provide readers with a suitable basis to get started with raytracing light in simulations. A dedicated and motivated reader could scale up from this knowledge to advanced concepts in displays, cameras, visual perception, optical computing, and many other light-based applications.

Reminder

We host a Slack group with more than 300 members. This Slack group focuses on the topics of rendering, perception, displays, and cameras. The group is open to the public, and you can become a member by following this link. Readers can get in touch with the wider community using this public group.

  1. GH Spencer and MVRK Murty. General ray-tracing procedure. JOSA, 52(6):672\u2013678, 1962.\u00a0\u21a9\u21a9\u21a9

  2. Peter Shirley. Ray tracing in one weekend. Amazon Digital Services LLC, 1:4, 2018.\u00a0\u21a9

  3. Morgan McGuire. The graphics codex. 2018.\u00a0\u21a9

  4. Wenzel Jakob, S\u00e9bastien Speierer, Nicolas Roussel, and Delio Vicini. Dr. jit: a just-in-time compiler for differentiable rendering. ACM Transactions on Graphics (TOG), 41(4):1\u201319, 2022.\u00a0\u21a9

  5. Robert J Bell, Kendall R Armstrong, C Stephen Nichols, and Roger W Bradley. Generalized laws of refraction and reflection. JOSA, 59(2):187\u2013189, 1969.\u00a0\u21a9

  6. Max Born and Emil Wolf. Principles of optics: electromagnetic theory of propagation, interference and diffraction of light. Elsevier, 2013.\u00a0\u21a9

  7. Adam Paszke, Sam Gross, Soumith Chintala, Gregory Chanan, Edward Yang, Zachary DeVito, Zeming Lin, Alban Desmaison, Luca Antiga, and Adam Lerer. Automatic differentiation in pytorch. NIPS 2017 Workshop Autodiff, 2017.\u00a0\u21a9

  8. Yuliy Schwartzburg, Romain Testuz, Andrea Tagliasacchi, and Mark Pauly. High-contrast computational caustic design. ACM Transactions on Graphics (TOG), 33(4):1\u201311, 2014.\u00a0\u21a9

"},{"location":"course/visual_perception/","title":"Visual Perception","text":""},{"location":"course/visual_perception/#color-perception","title":"Color Perception","text":"

Informative \u00b7 Practical

We can establish an understanding of color perception by studying its physical and perceptual meaning. This way, we can gather more information on its relation to technologies and devices, including displays, cameras, sensors, communication devices, computers, and computer graphics.

Color, a perceptual phenomenon, can be explained in both physical and visual perception terms. In the physical sense, color is a quantity representing the response to the wavelength of light. The human visual system can perceive colors within a certain range of the electromagnetic spectrum, from around 400 nanometers to 700 nanometers. For greater detail on the electromagnetic spectrum and the concept of wavelength, we recommend revisiting the Light, Computation, and Computational Light section of our course. For the human visual system, color is a perceptual phenomenon created by our brain when specific wavelengths of light are emitted, reflected, or transmitted by objects. The perception of color originates from the absorption of light by photoreceptors in the eye. These photoreceptor cells convert the light into electrical signals to be interpreted by the brain1. Here, you can see a close-up photograph of these photoreceptor cells found in the eye.

Micrograph of retinal photoreceptor cells, with rods and cones highlighted in green (top row). Image courtesy of NIH, licensed under CC PDM 1.0. View source.

The photoreceptors, where color perception originates, are called rods and cones2. Here, we provide a sketch showing where these rods and cones are located inside the eye. By closely observing this sketch, you can also understand the basic average geometry of a human eye and its parts helping to redirect light from an actual scene towards retinal cells.

Anatomy of an Eye (Designed with BioRender.com).

Rods, which are relatively more common in the periphery, help people see in low-light (scotopic) conditions. The current understanding is that rods can only interpret the scene in a greyscale manner. Cones, which are denser in the fovea, are pivotal in color perception in brighter (photopic) environments. We highlight the distribution of these photoreceptor cells, rods and cones, with changing eccentricities in the eye. Here, the word eccentricity refers to the angle with respect to our gaze direction. For instance, if a person is not directly gazing at a location or an object in a given scene, that location or object has some angle with respect to the gaze of that person. Thus, there would be some eccentricity, in other words some angle, between the gaze of that person and that location or object in the scene.

Retinal Photoreceptor Distribution, adapted from the work by Goldstein et al [3].

In the above sketch, we introduced various regions on the retina, including the fovea, parafovea, perifovea, and peripheral vision. Note that these regions are defined by angles, in other words, eccentricities. Please also note that there is a region on our retina where no rods or cones are available. This region can be found in every human eye and is known as the blind spot on the retina. Visual acuity and contrast sensitivity decrease progressively across these identified regions, with the most detail in the fovea, diminishing toward the periphery.

Spectral Sensitivities of LMS cones

The cones are categorized into three types based on their sensitivity to specific wavelengths of light, corresponding to long (L), medium (M), and short (S) wavelength cones. These three types of cones3 allow us to better understand the trichromatic theory4, which suggests that human color perception stems from combining the stimulations of the LMS cones. Scientists have graphically represented how sensitive each type of cone is to different wavelengths of light, which is known as the spectral sensitivity function5. In practical applications such as display technologies and computational imaging, the LMS cone response can be approximated with the following formula:

\\[ LMS = \\sum_{i=1}^{3} \\text{RGB}_i \\cdot \\text{Spectrum}_i \\cdot \\text{Sensitivity}_i \\]

Where:

  • \\(RGB_i\\): The i-th color channel (Red, Green, or Blue) of the image.
  • \\(Spectrum_i\\): The spectral distribution of the corresponding primary
  • \\(Sensitivity_i\\): The sensitivity of the L, M, and S cones for each wavelength.

This formula gives us more insight into how we perceive colors from different digital and physical inputs.
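
To connect the formula above to code, here is a minimal, hedged sketch of the same summation in torch. The spectra and sensitivities below are random placeholders purely for illustrating the tensor shapes; odak's display_color_hvs class, used in the example later in this section, provides the proper implementation.

import torch\n\nnumber_of_wavelengths = 301 # 301 wavelength samples, as in the unit test later in this section.\nimage_rgb = torch.rand(3, 256, 256)             # RGB image, channels first.\nspectrum = torch.rand(3, number_of_wavelengths)     # Placeholder emission spectrum per primary.\nsensitivity = torch.rand(3, number_of_wavelengths)  # Placeholder L, M, S sensitivities per wavelength.\n# Integrate spectrum_i * sensitivity_l over wavelength for each primary i and cone l,\n# then weight the image's primaries accordingly and sum over the primaries.\nprimary_to_lms = torch.einsum('iw,lw->il', spectrum, sensitivity)\nimage_lms = torch.einsum('il,ihw->lhw', primary_to_lms, image_rgb)\nprint(image_lms.shape) # torch.Size([3, 256, 256])\n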

Looking for more reading to expand your understanding of the human visual system?

We recommend these papers, which we find insightful:

  • B. P. Schmidt, M. Neitz, and J. Neitz, \"Neurobiological hypothesis of color appearance and hue perception,\" J. Opt. Soc. Am. A 31(4), A195\u2013207 (2014)
  • Biomimetic Eye Modeling & Deep Neuromuscular Oculomotor Control

The story of color perception only deepens with the concept of color opponency6. This theory reveals that our perception of color is not just a matter of additive combinations of primary colors but also involves a dynamic interplay of opposing colors: red versus green, blue versus yellow. This phenomenon is rooted in the neural pathways of the eye and brain, where certain cells are excited or inhibited by specific wavelengths, enhancing our ability to distinguish between subtle shades and contrasts. Below is a mathematical formulation for the color opponency model proposed by Schmidt et al.3

\\[\\begin{bmatrix} I_{(M+S)-L} \\\\ I_{(L+S)-M} \\\\ I_{(L+M+S)} \\end{bmatrix} = \\begin{bmatrix} (I_M + I_S) - I_L \\\\ (I_L + I_S) - I_M \\\\ I_L + I_M + I_S \\end{bmatrix}\\]

In this equation, \\(I_L\\), \\(I_M\\), and \\(I_S\\) represent the intensities received by the long, medium, and short cone cells, respectively. Opponent signals are represented by the differences between combinations of cone responses.
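
As a quick, hedged illustration of the equation above, the snippet below derives the three opponent channels from placeholder LMS intensities using plain torch; the tensor shapes are arbitrary and only serve to show the arithmetic.

import torch\n\nimage_lms = torch.rand(3, 256, 256) # Placeholder L, M and S channels.\nl, m, s = image_lms[0], image_lms[1], image_lms[2]\nopponent = torch.stack([\n                        (m + s) - l,   # (M + S) - L opponent channel.\n                        (l + s) - m,   # (L + S) - M opponent channel.\n                        l + m + s      # Combined L + M + S channel.\n                       ])\nprint(opponent.shape) # torch.Size([3, 256, 256])\n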

We can exercise our understanding of trichromat sensation with LMS cones and the concept of color opponency by visiting the functions available in our toolkit, odak. The utility function we will review is odak.learn.perception.display_color_hvs.primaries_to_lms() from odak.learn.perception. Let us use the test below to demonstrate how we can obtain the LMS sensation from the color primaries of an image.

test_learn_perception_display_color_hvs.py
import odak # (1)\nimport torch\nimport sys\nfrom odak.learn.perception.color_conversion import display_color_hvs\n\n\ndef test(\n         device = torch.device('cpu'),\n         output_directory = 'test_output'\n        ):\n    odak.tools.check_directory(output_directory)\n    torch.manual_seed(0)\n\n    image_rgb = odak.learn.tools.load_image(\n                                            'test/data/fruit_lady.png',\n                                            normalizeby = 255.,\n                                            torch_style = True\n                                           ).unsqueeze(0).to(device) # (2)\n\n    the_number_of_primaries = 3\n    multi_spectrum = torch.zeros(\n                                 the_number_of_primaries,\n                                 301\n                                ) # (3)\n    multi_spectrum[0, 200:250] = 1.\n    multi_spectrum[1, 130:145] = 1.\n    multi_spectrum[2, 0:50] = 1.\n\n    display_color = display_color_hvs(\n                                      read_spectrum ='tensor',\n                                      primaries_spectrum=multi_spectrum,\n                                      device = device\n                                     ) # (4)\n\n    image_lms_second_stage = display_color.primaries_to_lms(image_rgb) # (5)\n    image_lms_third_stage = display_color.second_to_third_stage(image_lms_second_stage) # (6)\n\n\n    odak.learn.tools.save_image(\n                                '{}/image_rgb.png'.format(output_directory),\n                                image_rgb,\n                                cmin = 0.,\n                                cmax = image_rgb.max()\n                               )\n\n\n    odak.learn.tools.save_image(\n                                '{}/image_lms_second_stage.png'.format(output_directory),\n                                image_lms_second_stage,\n                                cmin = 0.,\n                                cmax = image_lms_second_stage.max()\n                               )\n\n    odak.learn.tools.save_image(\n                                '{}/image_lms_third_stage.png'.format(output_directory),\n                                image_lms_third_stage,\n                                cmin = 0.,\n                                cmax = image_lms_third_stage.max()\n                               )\n\n\n    image_rgb_noisy = image_rgb * 0.6 + torch.rand_like(image_rgb) * 0.4 # (7)\n    loss_lms = display_color(image_rgb, image_rgb_noisy) # (8)\n    print('The third stage LMS sensation difference between two input images is {:.10f}.'.format(loss_lms))\n    assert True == True\n\nif __name__ == \"__main__\":\n    sys.exit(test())\n
  1. Adding odak to our imports.
  2. Loading an existing RGB image.
  3. Defining the spectrum of our primaries of our imaginary display. These values are defined for each primary from 400 nm to 701 nm (301 elements).
  4. Obtain LMS cone sensations for our primaries of our imaginary display.
  5. Calculating the LMS sensation of our input RGB image at the second stage of color perception using our imaginary display.
  6. Calculating the LMS sensation of our input RGB image at the third stage of color perception using our imaginary display.
  7. We are intentionally adding some noise to the input RGB image here.
  8. We calculate the perceptual loss/difference between the two input images (original RGB vs. noisy RGB). This is a visualization of a randomly generated image and its LMS cone sensation.

Our code above saves three different images. The very first saved image is the ground truth RGB image as depicted below.

Original ground truth image.

We process this ground truth image by accounting for the human visual system's cones and the display backlight spectrum. This way, we can calculate how our ground truth image is sensed by the LMS cones. The LMS sensation, in other words, the ground truth image in LMS color space, is provided below. Note that each color here represents a different cone; for instance, the green color channel of the image below represents the medium cones, and the blue channel represents the short cones. Keep in mind that LMS sensation is also known as trichromat sensation in the literature.

Image in LMS cones trichromat space.

Earlier, we discussed the color opponency theory. Following this theory, our code utilizes the trichromat values to derive the image representation below.

Image representation of color opponency.

Lab work: Observing the effect of display spectrum

We introduce our unit test, test_learn_perception_display_color_hvs.py, to provide an example of how to convert an RGB image to trichromat values as sensed by the retinal cone cells. Note that during this exercise, we define a variable named multi_spectrum to represent the spectrum of each of our color primaries. These values are stored in a vector for each primary and provide the intensity at each wavelength from 400 nm to 700 nm. The trichromat values that we derive from our original ground truth RGB image are highly correlated with these spectrum values. To observe this correlation, we encourage you to find the spectra of actual display types (e.g., OLEDs, LEDs, LCDs) and map multi_spectrum to their spectra to observe the difference in color perception across various display technologies (see the sketch below for one way to do this). In addition, we believe that this will also give you a practical sandbox to examine the correlation between wavelengths and trichromat values.
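As a starting point for this lab work, here is a minimal sketch of how you might load a measured display spectrum and pass it to display_color_hvs. The file name oled_spectrum.csv and its layout (301 rows sampled from 400 nm to 700 nm, one column per primary) are hypothetical placeholders for whatever measurement you obtain; only the tensor shape expected by primaries_spectrum follows the unit test above.

import torch
import numpy
from odak.learn.perception.color_conversion import display_color_hvs

# Hypothetical measurement file: 301 rows (400 nm to 700 nm in 1 nm steps)
# and three columns, one per display primary (R, G, B).
measured = numpy.loadtxt('oled_spectrum.csv', delimiter = ',')      # [301, 3]
multi_spectrum = torch.tensor(measured.T, dtype = torch.float32)    # [3, 301]
multi_spectrum = multi_spectrum / multi_spectrum.max()              # normalize to [0, 1]

display_color = display_color_hvs(
                                  read_spectrum = 'tensor',
                                  primaries_spectrum = multi_spectrum,
                                  device = torch.device('cpu')
                                 )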

"},{"location":"course/visual_perception/#closing-remarks","title":"Closing remarks","text":"

As we dive deeper into light and color perception, it becomes evident that the task of replicating the natural spectrum of colors in technology is still an evolving journey. This exploration into the nature of color sets the stage for a deeper examination of how our biological systems perceive color and how technology strives to emulate that perception.

Consider revisiting this chapter

Remember that you can always revisit this chapter as you progress through the course and as you need it. This chapter is vital for establishing a means to complete your assignments, and it could form a suitable basis for collaborating with my research group or with other experts in the field in the future.

Reminder

We host a Slack group with more than 300 members. This Slack group focuses on the topics of rendering, perception, displays, and cameras. The group is open to the public, and you can become a member by following this link. Readers can get in touch with the wider community using this public group.

  1. Jeremy Freeman and Eero P Simoncelli. Metamers of the ventral stream. Nature Neuroscience, 14:1195\u20131201, 2011. doi:10.1038/nn.2889.

  2. Trevor D Lamb. Why rods and cones? Eye, 30:179\u2013185, 2015. doi:10.1038/eye.2015.236.

  3. Brian P Schmidt, Maureen Neitz, and Jay Neitz. Neurobiological hypothesis of color appearance and hue perception. Journal of the Optical Society of America A, 31(4):A195\u2013A207, 2014. doi:10.1364/JOSAA.31.00A195.

  4. H. V. Walters. Some experiments on the trichromatic theory of vision. Proceedings of the Royal Society of London. Series B - Biological Sciences, 131:27\u201350, 1942. doi:10.1098/rspb.1942.0016.

  5. Andrew Stockman and Lindsay T Sharpe. The spectral sensitivities of the middle- and long-wavelength-sensitive cones derived from measurements in observers of known genotype. Vision Research, 40:1711\u20131737, 2000. doi:10.1016/S0042-6989(00)00021-3.

  6. Steven K Shevell and Paul R Martin. Color opponency: tutorial. Journal of the Optical Society of America A, 34(8):1099\u20131110, 2017. doi:10.1364/JOSAA.34.001099.

"},{"location":"notes/holographic_light_transport/","title":"Holographic light transport","text":"

Odak contains essential ingredients for research and development targeting Computer-Generated Holography. We refer beginners in this matter to Goodman's Introduction to Fourier Optics (ISBN-13: 978-0974707723) and Principles of Optics: Electromagnetic Theory of Propagation, Interference and Diffraction of Light by Max Born and Emil Wolf (ISBN 0-08-026482-4). This engineering note provides a crash course on how light travels from a phase-only hologram to an image plane.

Holographic image reconstruction. A collimated beam with a homogeneous amplitude distribution (A = 1) illuminates a phase-only hologram \(u_0(x,y)\). Light from this hologram diffracts and arrives at an image plane \(u(x,y)\) at a distance of \(z\). Diffracted beams from each hologram pixel interfere at the image plane and, finally, reconstruct a target image.

As depicted in the figure above, when such holograms are illuminated with collimated coherent light (e.g., a laser), they can reconstruct an intended optical field at target depth levels. How light travels from a hologram to a parallel image plane is commonly described using Rayleigh-Sommerfeld diffraction integrals (for more, consult Heurtley, J. C. (1973). Scalar Rayleigh\u2013Sommerfeld and Kirchhoff diffraction integrals: a comparison of exact evaluations for axial points. JOSA, 63(8), 1003-1008.). The first solution of the Rayleigh-Sommerfeld integral, also known as the Huygens-Fresnel principle, is expressed as follows:

\\(u(x,y)=\\frac{1}{j\\lambda} \\int\\!\\!\\!\\!\\int u_0(x,y)\\frac{e^{jkr}}{r}cos(\\theta)dxdy,\\)

where the field at a target image plane, \(u(x,y)\), is calculated by integrating over every point of the hologram's field, \(u_0(x,y)\). Note that, for the above equation, \(r\) represents the optical path between a selected point on the hologram and a selected point on the image plane, \(\theta\) represents the angle between these two points, \(k\) represents the wavenumber (\(\frac{2\pi}{\lambda}\)), and \(\lambda\) represents the wavelength of light. In this light transport model, the optical fields, \(u_0(x,y)\) and \(u(x,y)\), are represented with complex values,

\\(u_0(x,y)=A(x,y)e^{j\\phi(x,y)},\\)

where \(A\) represents the spatial distribution of amplitude and \(\phi\) represents the spatial distribution of phase across the hologram plane. The described holographic light transport model is often simplified into a single convolution with a fixed, spatially invariant complex kernel, \(h(x,y)\) (Sypek, Maciej. \"Light propagation in the Fresnel region. New numerical approach.\" Optics Communications 116.1-3 (1995): 43-48.).

\\(u(x,y)=u_0(x,y) * h(x,y) =\\mathcal{F}^{-1}(\\mathcal{F}(u_0(x,y)) \\mathcal{F}(h(x,y)))\\)

There are multiple variants of this simplified approach:

  • Matsushima, Kyoji, and Tomoyoshi Shimobaba. \"Band-limited angular spectrum method for numerical simulation of free-space propagation in far and near fields.\" Optics express 17.22 (2009): 19662-19673.,
  • Zhang, Wenhui, Hao Zhang, and Guofan Jin. \"Band-extended angular spectrum method for accurate diffraction calculation in a wide propagation range.\" Optics letters 45.6 (2020): 1543-1546.,
  • Zhang, Wenhui, Hao Zhang, and Guofan Jin. \"Adaptive-sampling angular spectrum method with full utilization of space-bandwidth product.\" Optics Letters 45.16 (2020): 4416-4419.

In many cases, people choose to use the most common form of \(h\), described as

\\(h(x,y)=\\frac{e^{jkz}}{j\\lambda z} e^{\\frac{jk}{2z} (x^2+y^2)},\\)

where \(z\) represents the distance between the hologram plane and the target image plane. Note that beam propagation can also be learned for physical setups to avoid imperfections in a setup and to improve the image quality at the image plane:

  • Peng, Yifan, et al. \"Neural holography with camera-in-the-loop training.\" ACM Transactions on Graphics (TOG) 39.6 (2020): 1-14.,
  • Chakravarthula, Praneeth, et al. \"Learned hardware-in-the-loop phase retrieval for holographic near-eye displays.\" ACM Transactions on Graphics (TOG) 39.6 (2020): 1-18.,
  • Kavakl\u0131, Koray, Hakan Urey, and Kaan Ak\u015fit. \"Learned holographic light transport.\" Applied Optics (2021).
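To make the kernel \(h(x,y)\) given above concrete, the sketch below builds the quadratic-phase kernel on a sampled grid and applies the single FFT-based convolution from the earlier equation. The grid size, pixel pitch, and wavelength are illustrative choices, and odak ships ready-made propagation routines in odak.learn.wave, so treat this purely as a reading aid for the equations rather than a production implementation.

import torch

wavelength = 532e-9                    # wavelength of light [m]
dx = 8e-6                              # hologram pixel pitch [m]
z = 0.15                               # propagation distance [m]
k = 2 * torch.pi / wavelength          # wavenumber, 2 pi / lambda

nx, ny = 1080, 1920
x = (torch.arange(nx) - nx // 2) * dx
y = (torch.arange(ny) - ny // 2) * dx
X, Y = torch.meshgrid(x, y, indexing = 'ij')

# h(x, y) = e^{jkz} / (j lambda z) * e^{jk / (2z) (x^2 + y^2)}
h = torch.exp(torch.tensor(1j * k * z)) / (1j * wavelength * z) \
    * torch.exp(1j * k / (2 * z) * (X ** 2 + Y ** 2))

# u = F^{-1}( F(u_0) F(h) ); here u_0 is a flat phase-only hologram (A = 1).
u0 = torch.exp(1j * torch.zeros(nx, ny))
u = torch.fft.ifft2(torch.fft.fft2(u0) * torch.fft.fft2(h))
intensity = u.abs() ** 2               # what a camera at the image plane would capture

# Note: practical implementations also zero-pad the fields and center the kernel
# (fftshift); the propagation functions in odak.learn.wave take care of such details.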
"},{"location":"notes/holographic_light_transport/#see-also","title":"See also","text":"

For more engineering notes, follow:

  • Computer-Generated Holography
"},{"location":"notes/optimizing_holograms_using_odak/","title":"Optimizing holograms using Odak","text":"

This engineering note will give you an idea about how to optimize phase-only holograms using Odak. We refer beginners in this matter to Goodman's Introduction to Fourier Optics (ISBN-13: 978-0974707723) and Principles of Optics: Electromagnetic Theory of Propagation, Interference and Diffraction of Light by Max Born and Emil Wolf (ISBN 0-08-026482-4). Note that the creators of this documentation are from the Computational Displays domain. However, the provided submodules can potentially aid other lines of research as well, such as Computational Imaging or Computational Microscopy.

The optimization referred to in this document is the one that generates a phase-only hologram that can reconstruct a target image. There are multiple ways in the literature to optimize a phase-only hologram for a single plane, including:

Gerchberg-Saxton and Yang-Gu algorithms: - Yang, G. Z., Dong, B. Z., Gu, B. Y., Zhuang, J. Y., & Ersoy, O. K. (1994). Gerchberg\u2013Saxton and Yang\u2013Gu algorithms for phase retrieval in a nonunitary transform system: a comparison. Applied Optics, 33(2), 209-218.

Stochastic Gradient Descent based optimization: - Chen, Y., Chi, Y., Fan, J., & Ma, C. (2019). Gradient descent with random initialization: Fast global convergence for nonconvex phase retrieval. Mathematical Programming, 176(1), 5-37.

Odak provides functions to optimize phase-only holograms using the Gerchberg-Saxton algorithm or the Stochastic Gradient Descent based approach. The relevant functions here are odak.learn.wave.stochastic_gradient_descent and odak.learn.wave.gerchberg_saxton. We will review both of these definitions in this document. But first, let's get prepared.

"},{"location":"notes/optimizing_holograms_using_odak/#preparation","title":"Preparation","text":"

We first start with the imports; here is all you need:

from odak.learn.wave import stochastic_gradient_descent, calculate_amplitude, calculate_phase\nimport torch\n

We will also need a variable that defines the wavelength of light that we work with:

wavelength = 0.000000532\n

Pixel pitch and resolution of the phase-only hologram or a phase-only spatial light modulator that we are simulating:

dx = 0.0000064\nresolution = [1080, 1920]\n

Define the distance that the light will travel from the optimized hologram.

distance = 0.15\n

We have to set a target image. You can either load a sample image here or paint a white rectangle on a black background, as in this example.

target = torch.zeros(resolution[0],resolution[1])\ntarget[500:600,400:450] = 1.\n

Surely, we also have to set the number of iterations and the learning rate for our optimization. If you want GPU support, you also have to set cuda to True. The propagation type has to be defined as well. In this example, we will use the transfer function Fresnel approach. For more on propagation types, curious readers can consult Computational Fourier Optics: A MATLAB Tutorial by David Voelz (ISBN-13: 978-0819482044).

iteration_number = 100\nlearning_rate = 0.1\ncuda = True\npropagation_type = 'TR Fresnel'\n

This step concludes our preparations. Let's dive into optimizing our phase-only holograms. Depending on your choice, you can optimize using either the Gerchberg-Saxton approach or the Stochastic Gradient Descent approach. This document will only show the Stochastic Gradient Descent approach, as it is the state of the art. However, switching to the Gerchberg-Saxton approach is as simple as importing:

from odak.learn.wave import gerchberg_saxton\n

and almost as easy as replacing stochastic_gradient_descent with gerchberg_saxton in the hologram optimization routine described below. For greater detail, consult the documentation of odak.learn.wave.

"},{"location":"notes/optimizing_holograms_using_odak/#stochastic-gradient-descent-approach","title":"Stochastic Gradient Descent approach","text":"

We have prepared a function for you, so that you do not have to write a differentiable hologram optimizer from scratch.

hologram, reconstructed = stochastic_gradient_descent(\n        target,\n        wavelength,\n        distance,\n        dx,\n        resolution,\n        'TR Fresnel',\n        iteration_number,\n        learning_rate=learning_rate,\n        cuda=cuda\n    )\n
Iteration: 99 loss:0.0003\n

Congratulations! You have just optimized a phase-only hologram that reconstructs your target image at the target depth.

Surely, you want to see what kind of image is being reconstructed with this newly optimized hologram. You can save the outcome to an image file easily. Odak provides tools to save and load images. First, you have to import:

from odak.learn.tools import save_image,load_image\n

As you may recall, the target image we created earlier is normalized between zero and one. The same is true for our result, reconstructed. Therefore, we have to save it by taking that into account. Note that reconstructed is the complex field generated by our optimized hologram variable, so we need to save the reconstructed intensity; humans and cameras capture intensity, not a complex field with phase and amplitude.

reconstructed_intensity = calculate_amplitude(reconstructed)**2\nsave_image('reconstructed_image.png',reconstructed_intensity,cmin=0.,cmax=1.)\n
True\n

To save our hologram as an image so that we can load it onto a spatial light modulator, we have to normalize it between zero and 255 (the dynamic range of a typical image on a computer).

P.S. Depending on your SLM's calibration and dynamic range, things may vary.

slm_range = 2*3.14\ndynamic_range = 255\nphase_hologram = calculate_phase(hologram)\nphase_only_hologram = (phase_hologram%slm_range)/(slm_range)*dynamic_range\n

It is now time for saving our hologram:

save_image('phase_only_hologram.png',phase_only_hologram)\n
True\n

In some cases, you may want to add a grating term to your hologram before displaying it on a spatial light modulator. There are various reasons for that, but the most obvious is getting rid of zeroth-order reflections that are not modulated by your hologram. In case you need it, it is as simple as below:

from odak.learn.wave import linear_grating\ngrating = linear_grating(resolution[0],resolution[1],axis='y').to(phase_hologram.device)\nphase_only_hologram_w_grating = phase_hologram+calculate_phase(grating)\n

And let's save what we got from this step:

phase_only_hologram_w_grating = (phase_only_hologram_w_grating%slm_range)/(slm_range)*dynamic_range\nsave_image('phase_only_hologram_w_grating.png',phase_only_hologram_w_grating)\n
True\n
"},{"location":"notes/optimizing_holograms_using_odak/#see-also","title":"See also","text":"

For more engineering notes, follow:

  • Computer-Generated Holography
"},{"location":"notes/using_metameric_loss/","title":"Using metameric loss","text":"

This engineering note will give you an idea about using the metameric perceptual loss in odak. This note is compiled by David Walton. If you have further questions regarding this note, please email David at david.walton.13@ucl.ac.uk.

Our metameric loss function works in a very similar way to the built-in loss functions in PyTorch, such as torch.nn.MSELoss(). However, it has a number of parameters which can be adjusted on creation (see the documentation). Additionally, when calculating the loss, a gaze location must be specified. For example:

loss_func = odak.learn.perception.MetamericLoss()\nloss = loss_func(my_image, gt_image, gaze=[0.7, 0.3])\n

The loss function caches some information, and performs most efficiently when repeatedly calculating losses for the same image size, with the same gaze location and foveation settings.

We recommend adjusting the parameters of the loss function to match your application. Most importantly, please set the real_image_width and real_viewing_distance parameters to correspond to how your image will be displayed to the user. The alpha parameter controls the intensity of the foveation effect. You should only need to set alpha once; you can then adjust the width and viewing distance to achieve the same apparent foveation effect on a range of displays and viewing conditions. Note that we assume the pixels in the displayed image are square and derive the height from the image dimensions. An example is given below.
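Putting those recommendations together, here is a short sketch. The numeric values below are illustrative placeholders for a roughly 0.6 m wide image viewed from 0.7 m away, and the batched [1 x 3 x H x W] image shape is an assumption on our part; adjust both to your own setup.

import torch
import odak

# Illustrative viewing geometry and foveation strength; replace with your own setup.
loss_func = odak.learn.perception.MetamericLoss(
                                                alpha = 0.2,
                                                real_image_width = 0.6,
                                                real_viewing_distance = 0.7
                                               )
my_image = torch.rand(1, 3, 512, 512)   # assumed image layout: [batch, channel, height, width]
gt_image = torch.rand(1, 3, 512, 512)
loss = loss_func(my_image, gt_image, gaze = [0.7, 0.3])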

We also provide two baseline loss functions, BlurLoss and MetamerMSELoss, which function in much the same way.

At the present time the loss functions are implemented only for images displayed to a user on a flat 2D display (e.g. an LCD computer monitor). Support for equirectangular 3D images is planned for the future.

"},{"location":"notes/using_metameric_loss/#see-also","title":"See also","text":"

Visual perception

"},{"location":"odak/fit/","title":"odak.fit","text":"

odak.fit

Provides functions to fit models to provided data. These functions could best be described as a catalog of machine learning models.

"},{"location":"odak/fit/#odak.fit.gradient_descent_1d","title":"gradient_descent_1d(input_data, ground_truth_data, parameters, function, gradient_function, loss_function, learning_rate=0.1, iteration_number=10)","text":"

Vanilla Gradient Descent algorithm for 1D data.

Parameters:

  • input_data \u2013
                One-dimensional input data.\n
  • ground_truth_data (array) \u2013
                One-dimensional ground truth data.\n
  • parameters \u2013
                Parameters to be optimized.\n
  • function \u2013
                Function to estimate an output using the parameters.\n
  • gradient_function (function) \u2013
                Function used in estimating gradient to update parameters at each iteration.\n
  • loss_function \u2013
                Function used to calculate the loss between the ground truth data and the estimate at each iteration.\n
  • learning_rate \u2013
                Learning rate.\n
  • iteration_number \u2013
                Iteration number.\n

Returns:

  • parameters ( array ) \u2013

    Optimized parameters.

Source code in odak/fit/__init__.py
def gradient_descent_1d(\n                        input_data,\n                        ground_truth_data,\n                        parameters,\n                        function,\n                        gradient_function,\n                        loss_function,\n                        learning_rate = 1e-1,\n                        iteration_number = 10\n                       ):\n    \"\"\"\n    Vanilla Gradient Descent algorithm for 1D data.\n\n    Parameters\n    ----------\n    input_data        : numpy.array\n                        One-dimensional input data.\n    ground_truth_data : numpy.array\n                        One-dimensional ground truth data.\n    parameters        : numpy.array\n                        Parameters to be optimized.\n    function          : function\n                        Function to estimate an output using the parameters.\n    gradient_function : function\n                        Function used in estimating gradient to update parameters at each iteration.\n    learning_rate     : float\n                        Learning rate.\n    iteration_number  : int\n                        Iteration number.\n\n\n    Returns\n    -------\n    parameters        : numpy.array\n                        Optimized parameters.\n    \"\"\"\n    t = tqdm(range(iteration_number))\n    for i in t:\n        gradient = np.zeros(parameters.shape[0])\n        for j in range(input_data.shape[0]):\n            x = input_data[j]\n            y = ground_truth_data[j]\n            gradient = gradient + gradient_function(x, y, function, parameters)\n        parameters = parameters - learning_rate * gradient / input_data.shape[0]\n        loss = loss_function(ground_truth_data, function(input_data, parameters))\n        description = 'Iteration number:{}, loss:{:0.4f}, parameters:{}'.format(i, loss, np.round(parameters, 2))\n        t.set_description(description)\n    return parameters\n
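As a usage sketch for the definition above, the snippet below fits a line y = mx + n to synthetic data. The model, gradient, and loss helpers are our own illustrative choices and not part of odak; only the gradient_descent_1d call itself follows the signature shown in the source.

import numpy as np
from odak.fit import gradient_descent_1d

def line(x, parameters):
    # Simple linear model y = m * x + n with parameters = [m, n].
    return parameters[0] * x + parameters[1]

def gradient(x, y, function, parameters):
    # Gradient of the squared error 0.5 * (f(x) - y)^2 with respect to [m, n].
    error = function(x, parameters) - y
    return np.array([error * x, error])

def mse(ground_truth, estimate):
    return np.mean((ground_truth - estimate) ** 2)

x = np.linspace(0., 1., 50)
y = 2. * x + 1.
parameters = gradient_descent_1d(
                                 input_data = x,
                                 ground_truth_data = y,
                                 parameters = np.array([0., 0.]),
                                 function = line,
                                 gradient_function = gradient,
                                 loss_function = mse,
                                 learning_rate = 0.5,
                                 iteration_number = 200
                                )
# parameters converges towards [2., 1.].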
"},{"location":"odak/fit/#odak.fit.least_square_1d","title":"least_square_1d(x, y)","text":"

A function to fit a line to given x and y data (y=mx+n). Inspired by: https://mmas.github.io/least-squares-fitting-numpy-scipy

Parameters:

  • x \u2013
         1D input data.\n
  • y \u2013
         1D output data.\n

Returns:

  • parameters ( array ) \u2013

    Parameters of m and n in a line (y=mx+n).

Source code in odak/fit/__init__.py
def least_square_1d(x, y):\n    \"\"\"\n    A function to fit a line to given x and y data (y=mx+n). Inspired from: https://mmas.github.io/least-squares-fitting-numpy-scipy\n\n    Parameters\n    ----------\n    x          : numpy.array\n                 1D input data.\n    y          : numpy.array\n                 1D output data.\n\n    Returns\n    -------\n    parameters : numpy.array\n                 Parameters of m and n in a line (y=mx+n).\n    \"\"\"\n    w = np.vstack([x, np.ones(x.shape[0])]).T\n    parameters = np.dot(np.linalg.inv(np.dot(w.T, w)), np.dot(w.T, y))\n    return parameters\n
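A minimal usage sketch for least_square_1d with synthetic data that lies exactly on a line:

import numpy as np
from odak.fit import least_square_1d

x = np.array([0., 1., 2., 3., 4.])
y = 2. * x + 1.
m, n = least_square_1d(x, y)  # recovers the slope m = 2 and intercept n = 1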
"},{"location":"odak/fit/#odak.fit.perceptron","title":"perceptron(x, y, learning_rate=0.1, iteration_number=100)","text":"

A function to train a perceptron model.

Parameters:

  • x \u2013
               Input X-Y pairs [m x 2].\n
  • y \u2013
               Labels for the input data [m x 1]\n
  • learning_rate \u2013
               Learning rate.\n
  • iteration_number (int, default: 100 ) \u2013
               Iteration number.\n

Returns:

  • weights ( array ) \u2013

    Trained weights of our model [3 x 1].

Source code in odak/fit/__init__.py
def perceptron(x, y, learning_rate = 0.1, iteration_number = 100):\n    \"\"\"\n    A function to train a perceptron model.\n\n    Parameters\n    ----------\n    x                : numpy.array\n                       Input X-Y pairs [m x 2].\n    y                : numpy.array\n                       Labels for the input data [m x 1]\n    learning_rate    : float\n                       Learning rate.\n    iteration_number : int\n                       Iteration number.\n\n    Returns\n    -------\n    weights          : numpy.array\n                       Trained weights of our model [3 x 1].\n    \"\"\"\n    weights = np.zeros((x.shape[1] + 1, 1))\n    t = tqdm(range(iteration_number))\n    for step in t:\n        unsuccessful = 0\n        for data_id in range(x.shape[0]):\n            x_i = np.insert(x[data_id], 0, 1).reshape(-1, 1)\n            y_i = y[data_id]\n            y_hat = threshold_linear_model(x_i, weights)\n            if y_hat - y_i != 0:\n                unsuccessful += 1\n                weights = weights + learning_rate * (y_i - y_hat) * x_i \n            description = 'Unsuccessful count: {}/{}'.format(unsuccessful, x.shape[0])\n    return weights\n
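A usage sketch for perceptron and threshold_linear_model on a small, linearly separable toy set that we make up here for illustration:

import numpy as np
from odak.fit import perceptron, threshold_linear_model

# Toy data: points above the line x2 = x1 get label one, points below get label zero.
x = np.array([[0., 1.], [1., 2.], [2., 3.], [1., 0.], [2., 1.], [3., 2.]])
y = np.array([[1.], [1.], [1.], [0.], [0.], [0.]])
weights = perceptron(x, y, learning_rate = 0.1, iteration_number = 100)

# Classify a new point by prepending a bias term of one, as perceptron() does internally.
sample = np.insert(np.array([0., 2.]), 0, 1).reshape(-1, 1)
estimated_class = threshold_linear_model(sample, weights)  # expected class: 1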
"},{"location":"odak/fit/#odak.fit.threshold_linear_model","title":"threshold_linear_model(x, w, threshold=0)","text":"

A function for thresholding a linear model described with a dot product.

Parameters:

  • x \u2013
               Input data [3 x 1].\n
  • w \u2013
               Weights [3 x 1].\n
  • threshold \u2013
               Value for thresholding.\n

Returns:

  • result ( int ) \u2013

    Estimated class of the input data. It could either be one or zero.

Source code in odak/fit/__init__.py
def threshold_linear_model(x, w, threshold = 0):\n    \"\"\"\n    A function for thresholding a linear model described with a dot product.\n\n    Parameters\n    ----------\n    x                : numpy.array\n                       Input data [3 x 1].\n    w                : numpy.array\n                       Weights [3 x 1].\n    threshold        : float\n                       Value for thresholding.\n\n    Returns\n    -------\n    result           : int\n                       Estimated class of the input data. It could either be one or zero.\n    \"\"\"\n    value = np.dot(x.T, w)\n    result = 0\n    if value >= threshold:\n       result = 1\n    return result\n
"},{"location":"odak/learn_lensless/","title":"odak.learn.lensless","text":""},{"location":"odak/learn_lensless/#odak.learn.lensless.models.spec_track","title":"spec_track","text":"

Bases: Module

The learned holography model used in the paper, Ziyang Chen and Mustafa Dogan and Josef Spjut and Kaan Ak\u015fit. \"SpecTrack: Learned Multi-Rotation Tracking via Speckle Imaging.\" In SIGGRAPH Asia 2024 Posters (SA Posters '24).

Parameters:

  • reduction (str, default: 'sum' ) \u2013
        Reduction used for torch.nn.MSELoss and torch.nn.L1Loss. The default is 'sum'.\n
  • device \u2013
        Device to run the model on. Default is CPU.\n
Source code in odak/learn/lensless/models.py
class spec_track(nn.Module):\n    \"\"\"\n    The learned holography model used in the paper, Ziyang Chen and Mustafa Dogan and Josef Spjut and Kaan Ak\u015fit. \"SpecTrack: Learned Multi-Rotation Tracking via Speckle Imaging.\" In SIGGRAPH Asia 2024 Posters (SA Posters '24).\n\n    Parameters\n    ----------\n    reduction : str\n                Reduction used for torch.nn.MSELoss and torch.nn.L1Loss. The default is 'sum'.\n    device    : torch.device\n                Device to run the model on. Default is CPU.\n    \"\"\"\n    def __init__(\n                 self,\n                 reduction = 'sum',\n                 device = torch.device('cpu')\n                ):\n        super(spec_track, self).__init__()\n        self.device = device\n        self.init_layers()\n        self.reduction = reduction\n        self.l2 = torch.nn.MSELoss(reduction = self.reduction)\n        self.l1 = torch.nn.L1Loss(reduction = self.reduction)\n        self.train_history = []\n        self.validation_history = []\n\n\n    def init_layers(self):\n        \"\"\"\n        Initialize the layers of the network.\n        \"\"\"\n        # Convolutional layers with batch normalization and pooling\n        self.network = nn.Sequential(OrderedDict([\n            ('conv1', nn.Conv2d(5, 32, kernel_size=3, padding=1)),\n            ('bn1', nn.BatchNorm2d(32)),\n            ('relu1', nn.ReLU()),\n            ('pool1', nn.MaxPool2d(kernel_size=3)),\n\n            ('conv2', nn.Conv2d(32, 64, kernel_size=5, padding=1)),\n            ('bn2', nn.BatchNorm2d(64)),\n            ('relu2', nn.ReLU()),\n            ('pool2', nn.MaxPool2d(kernel_size=3)),\n\n            ('conv3', nn.Conv2d(64, 128, kernel_size=7, padding=1)),\n            ('bn3', nn.BatchNorm2d(128)),\n            ('relu3', nn.ReLU()),\n            ('pool3', nn.MaxPool2d(kernel_size=3)),\n\n            ('flatten', nn.Flatten()),\n\n            ('fc1', nn.Linear(6400, 2048)),\n            ('fc_bn1', nn.BatchNorm1d(2048)),\n            ('relu_fc1', nn.ReLU()),\n\n            ('fc2', nn.Linear(2048, 1024)),\n            ('fc_bn2', nn.BatchNorm1d(1024)),\n            ('relu_fc2', nn.ReLU()),\n\n            ('fc3', nn.Linear(1024, 512)),\n            ('fc_bn3', nn.BatchNorm1d(512)),\n            ('relu_fc3', nn.ReLU()),\n\n            ('fc4', nn.Linear(512, 128)),\n            ('fc_bn4', nn.BatchNorm1d(128)),\n            ('relu_fc4', nn.ReLU()),\n\n            ('fc5', nn.Linear(128, 3))\n        ])).to(self.device)\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward pass of the network.\n\n        Parameters\n        ----------\n        x : torch.Tensor\n            Input tensor.\n\n        Returns\n        -------\n        torch.Tensor\n            Output tensor.\n        \"\"\"\n        return self.network(x)\n\n\n    def evaluate(self, input_data, ground_truth, weights = [100., 1.]):\n        \"\"\"\n        Evaluate the model's performance.\n\n        Parameters\n        ----------\n        input_data    : torch.Tensor\n                        Predicted data from the model.\n        ground_truth  : torch.Tensor\n                        Ground truth data.\n        weights       : list\n                        Weights for L2 and L1 losses. 
Default is [100., 1.].\n\n        Returns\n        -------\n        torch.Tensor\n            Combined weighted loss.\n        \"\"\"\n        loss = weights[0] * self.l2(input_data, ground_truth) + weights[1] * self.l1(input_data, ground_truth)\n        return loss\n\n\n    def fit(self, trainloader, testloader, number_of_epochs=100, learning_rate=1e-5, weight_decay=1e-5, directory='./output'):\n        \"\"\"\n        Train the model.\n\n        Parameters\n        ----------\n        trainloader      : torch.utils.data.DataLoader\n                           Training data loader.\n        testloader       : torch.utils.data.DataLoader\n                           Testing data loader.\n        number_of_epochs : int\n                           Number of epochs to train for. Default is 100.\n        learning_rate    : float\n                           Learning rate for the optimizer. Default is 1e-5.\n        weight_decay     : float\n                           Weight decay for the optimizer. Default is 1e-5.\n        directory        : str\n                           Directory to save the model weights. Default is './output'.\n        \"\"\"\n        makedirs(directory, exist_ok=True)\n        makedirs(join(directory, \"log\"), exist_ok=True)\n\n        self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate, weight_decay=weight_decay)\n        best_val_loss = float('inf')\n\n        for epoch in range(number_of_epochs):\n            # Training phase\n            self.train()\n            train_loss = 0.0\n            train_batches = 0\n            train_pbar = tqdm(trainloader, desc=f\"Epoch {epoch+1}/{number_of_epochs} [Train]\", leave=False, dynamic_ncols=True)\n\n            for batch, labels in train_pbar:\n                self.optimizer.zero_grad()\n                batch, labels = batch.to(self.device), labels.to(self.device)\n                predicts = torch.squeeze(self.forward(batch))\n                loss = self.evaluate(predicts, labels)\n                loss.backward()\n                self.optimizer.step()\n\n                train_loss += loss.item()\n                train_batches += 1\n                train_pbar.set_postfix({'Loss': f\"{loss.item():.4f}\"})\n\n            avg_train_loss = train_loss / train_batches\n            self.train_history.append(avg_train_loss)\n\n            # Validation phase\n            self.eval()\n            val_loss = 0.0\n            val_batches = 0\n            val_pbar = tqdm(testloader, desc=f\"Epoch {epoch+1}/{number_of_epochs} [Val]\", leave=False, dynamic_ncols=True)\n\n            with torch.no_grad():\n                for batch, labels in val_pbar:\n                    batch, labels = batch.to(self.device), labels.to(self.device)\n                    predicts = torch.squeeze(self.forward(batch), dim=1)\n                    loss = self.evaluate(predicts, labels)\n\n                    val_loss += loss.item()\n                    val_batches += 1\n                    val_pbar.set_postfix({'Loss': f\"{loss.item():.4f}\"})\n\n            avg_val_loss = val_loss / val_batches\n            self.validation_history.append(avg_val_loss)\n\n            # Print epoch summary\n            print(f\"Epoch {epoch+1}/{number_of_epochs} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}\")\n\n            # Save best model\n            if avg_val_loss < best_val_loss:\n                best_val_loss = avg_val_loss\n                self.save_weights(join(directory, f\"best_model_epoch_{epoch+1}.pt\"))\n                
print(f\"Best model saved at epoch {epoch+1}\")\n\n        # Save training history\n        torch.save(self.train_history, join(directory, \"log\", \"train_log.pt\"))\n        torch.save(self.validation_history, join(directory, \"log\", \"validation_log.pt\"))\n        print(\"Training completed. History saved.\")\n\n\n    def save_weights(self, filename = './weights.pt'):\n        \"\"\"\n        Save the current weights of the network to a file.\n\n        Parameters\n        ----------\n        filename : str\n                   Path to save the weights. Default is './weights.pt'.\n        \"\"\"\n        torch.save(self.network.state_dict(), os.path.expanduser(filename))\n\n\n    def load_weights(self, filename = './weights.pt'):\n        \"\"\"\n        Load weights for the network from a file.\n\n        Parameters\n        ----------\n        filename : str\n                   Path to load the weights from. Default is './weights.pt'.\n        \"\"\"\n        self.network.load_state_dict(torch.load(os.path.expanduser(filename), weights_only = True))\n        self.network.eval()\n
"},{"location":"odak/learn_lensless/#odak.learn.lensless.models.spec_track.evaluate","title":"evaluate(input_data, ground_truth, weights=[100.0, 1.0])","text":"

Evaluate the model's performance.

Parameters:

  • input_data \u2013
            Predicted data from the model.\n
  • ground_truth \u2013
            Ground truth data.\n
  • weights \u2013
            Weights for L2 and L1 losses. Default is [100., 1.].\n

Returns:

  • Tensor \u2013

    Combined weighted loss.

Source code in odak/learn/lensless/models.py
def evaluate(self, input_data, ground_truth, weights = [100., 1.]):\n    \"\"\"\n    Evaluate the model's performance.\n\n    Parameters\n    ----------\n    input_data    : torch.Tensor\n                    Predicted data from the model.\n    ground_truth  : torch.Tensor\n                    Ground truth data.\n    weights       : list\n                    Weights for L2 and L1 losses. Default is [100., 1.].\n\n    Returns\n    -------\n    torch.Tensor\n        Combined weighted loss.\n    \"\"\"\n    loss = weights[0] * self.l2(input_data, ground_truth) + weights[1] * self.l1(input_data, ground_truth)\n    return loss\n
"},{"location":"odak/learn_lensless/#odak.learn.lensless.models.spec_track.fit","title":"fit(trainloader, testloader, number_of_epochs=100, learning_rate=1e-05, weight_decay=1e-05, directory='./output')","text":"

Train the model.

Parameters:

  • trainloader \u2013
               Training data loader.\n
  • testloader \u2013
               Testing data loader.\n
  • number_of_epochs (int, default: 100 ) \u2013
               Number of epochs to train for. Default is 100.\n
  • learning_rate \u2013
               Learning rate for the optimizer. Default is 1e-5.\n
  • weight_decay \u2013
               Weight decay for the optimizer. Default is 1e-5.\n
  • directory \u2013
               Directory to save the model weights. Default is './output'.\n
Source code in odak/learn/lensless/models.py
def fit(self, trainloader, testloader, number_of_epochs=100, learning_rate=1e-5, weight_decay=1e-5, directory='./output'):\n    \"\"\"\n    Train the model.\n\n    Parameters\n    ----------\n    trainloader      : torch.utils.data.DataLoader\n                       Training data loader.\n    testloader       : torch.utils.data.DataLoader\n                       Testing data loader.\n    number_of_epochs : int\n                       Number of epochs to train for. Default is 100.\n    learning_rate    : float\n                       Learning rate for the optimizer. Default is 1e-5.\n    weight_decay     : float\n                       Weight decay for the optimizer. Default is 1e-5.\n    directory        : str\n                       Directory to save the model weights. Default is './output'.\n    \"\"\"\n    makedirs(directory, exist_ok=True)\n    makedirs(join(directory, \"log\"), exist_ok=True)\n\n    self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate, weight_decay=weight_decay)\n    best_val_loss = float('inf')\n\n    for epoch in range(number_of_epochs):\n        # Training phase\n        self.train()\n        train_loss = 0.0\n        train_batches = 0\n        train_pbar = tqdm(trainloader, desc=f\"Epoch {epoch+1}/{number_of_epochs} [Train]\", leave=False, dynamic_ncols=True)\n\n        for batch, labels in train_pbar:\n            self.optimizer.zero_grad()\n            batch, labels = batch.to(self.device), labels.to(self.device)\n            predicts = torch.squeeze(self.forward(batch))\n            loss = self.evaluate(predicts, labels)\n            loss.backward()\n            self.optimizer.step()\n\n            train_loss += loss.item()\n            train_batches += 1\n            train_pbar.set_postfix({'Loss': f\"{loss.item():.4f}\"})\n\n        avg_train_loss = train_loss / train_batches\n        self.train_history.append(avg_train_loss)\n\n        # Validation phase\n        self.eval()\n        val_loss = 0.0\n        val_batches = 0\n        val_pbar = tqdm(testloader, desc=f\"Epoch {epoch+1}/{number_of_epochs} [Val]\", leave=False, dynamic_ncols=True)\n\n        with torch.no_grad():\n            for batch, labels in val_pbar:\n                batch, labels = batch.to(self.device), labels.to(self.device)\n                predicts = torch.squeeze(self.forward(batch), dim=1)\n                loss = self.evaluate(predicts, labels)\n\n                val_loss += loss.item()\n                val_batches += 1\n                val_pbar.set_postfix({'Loss': f\"{loss.item():.4f}\"})\n\n        avg_val_loss = val_loss / val_batches\n        self.validation_history.append(avg_val_loss)\n\n        # Print epoch summary\n        print(f\"Epoch {epoch+1}/{number_of_epochs} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}\")\n\n        # Save best model\n        if avg_val_loss < best_val_loss:\n            best_val_loss = avg_val_loss\n            self.save_weights(join(directory, f\"best_model_epoch_{epoch+1}.pt\"))\n            print(f\"Best model saved at epoch {epoch+1}\")\n\n    # Save training history\n    torch.save(self.train_history, join(directory, \"log\", \"train_log.pt\"))\n    torch.save(self.validation_history, join(directory, \"log\", \"validation_log.pt\"))\n    print(\"Training completed. History saved.\")\n
"},{"location":"odak/learn_lensless/#odak.learn.lensless.models.spec_track.forward","title":"forward(x)","text":"

Forward pass of the network.

Parameters:

  • x (Tensor) \u2013

    Input tensor.

Returns:

  • Tensor \u2013

    Output tensor.

Source code in odak/learn/lensless/models.py
def forward(self, x):\n    \"\"\"\n    Forward pass of the network.\n\n    Parameters\n    ----------\n    x : torch.Tensor\n        Input tensor.\n\n    Returns\n    -------\n    torch.Tensor\n        Output tensor.\n    \"\"\"\n    return self.network(x)\n
"},{"location":"odak/learn_lensless/#odak.learn.lensless.models.spec_track.init_layers","title":"init_layers()","text":"

Initialize the layers of the network.

Source code in odak/learn/lensless/models.py
def init_layers(self):\n    \"\"\"\n    Initialize the layers of the network.\n    \"\"\"\n    # Convolutional layers with batch normalization and pooling\n    self.network = nn.Sequential(OrderedDict([\n        ('conv1', nn.Conv2d(5, 32, kernel_size=3, padding=1)),\n        ('bn1', nn.BatchNorm2d(32)),\n        ('relu1', nn.ReLU()),\n        ('pool1', nn.MaxPool2d(kernel_size=3)),\n\n        ('conv2', nn.Conv2d(32, 64, kernel_size=5, padding=1)),\n        ('bn2', nn.BatchNorm2d(64)),\n        ('relu2', nn.ReLU()),\n        ('pool2', nn.MaxPool2d(kernel_size=3)),\n\n        ('conv3', nn.Conv2d(64, 128, kernel_size=7, padding=1)),\n        ('bn3', nn.BatchNorm2d(128)),\n        ('relu3', nn.ReLU()),\n        ('pool3', nn.MaxPool2d(kernel_size=3)),\n\n        ('flatten', nn.Flatten()),\n\n        ('fc1', nn.Linear(6400, 2048)),\n        ('fc_bn1', nn.BatchNorm1d(2048)),\n        ('relu_fc1', nn.ReLU()),\n\n        ('fc2', nn.Linear(2048, 1024)),\n        ('fc_bn2', nn.BatchNorm1d(1024)),\n        ('relu_fc2', nn.ReLU()),\n\n        ('fc3', nn.Linear(1024, 512)),\n        ('fc_bn3', nn.BatchNorm1d(512)),\n        ('relu_fc3', nn.ReLU()),\n\n        ('fc4', nn.Linear(512, 128)),\n        ('fc_bn4', nn.BatchNorm1d(128)),\n        ('relu_fc4', nn.ReLU()),\n\n        ('fc5', nn.Linear(128, 3))\n    ])).to(self.device)\n
"},{"location":"odak/learn_lensless/#odak.learn.lensless.models.spec_track.load_weights","title":"load_weights(filename='./weights.pt')","text":"

Load weights for the network from a file.

Parameters:

  • filename (str, default: './weights.pt' ) \u2013
       Path to load the weights from. Default is './weights.pt'.\n
Source code in odak/learn/lensless/models.py
def load_weights(self, filename = './weights.pt'):\n    \"\"\"\n    Load weights for the network from a file.\n\n    Parameters\n    ----------\n    filename : str\n               Path to load the weights from. Default is './weights.pt'.\n    \"\"\"\n    self.network.load_state_dict(torch.load(os.path.expanduser(filename), weights_only = True))\n    self.network.eval()\n
"},{"location":"odak/learn_lensless/#odak.learn.lensless.models.spec_track.save_weights","title":"save_weights(filename='./weights.pt')","text":"

Save the current weights of the network to a file.

Parameters:

  • filename (str, default: './weights.pt' ) \u2013
       Path to save the weights. Default is './weights.pt'.\n
Source code in odak/learn/lensless/models.py
def save_weights(self, filename = './weights.pt'):\n    \"\"\"\n    Save the current weights of the network to a file.\n\n    Parameters\n    ----------\n    filename : str\n               Path to save the weights. Default is './weights.pt'.\n    \"\"\"\n    torch.save(self.network.state_dict(), os.path.expanduser(filename))\n
"},{"location":"odak/learn_models/","title":"odak.learn.models","text":"

odak.learn.models

Provides necessary definitions for components used in machine learning and deep learning.

"},{"location":"odak/learn_models/#odak.learn.models.channel_gate","title":"channel_gate","text":"

Bases: Module

Channel attention module with various pooling strategies. This class is heavily inspired by https://github.com/Jongchan/attention-module/commit/e4ee180f1335c09db14d39a65d97c8ca3d1f7b16 (MIT License).

Source code in odak/learn/models/components.py
class channel_gate(torch.nn.Module):\n    \"\"\"\n    Channel attention module with various pooling strategies.\n    This class is heavily inspired https://github.com/Jongchan/attention-module/commit/e4ee180f1335c09db14d39a65d97c8ca3d1f7b16 (MIT License).\n    \"\"\"\n    def __init__(\n                 self, \n                 gate_channels, \n                 reduction_ratio = 16, \n                 pool_types = ['avg', 'max']\n                ):\n        \"\"\"\n        Initializes the channel gate module.\n\n        Parameters\n        ----------\n        gate_channels   : int\n                          Number of channels of the input feature map.\n        reduction_ratio : int\n                          Reduction ratio for the intermediate layer.\n        pool_types      : list\n                          List of pooling operations to apply.\n        \"\"\"\n        super().__init__()\n        self.gate_channels = gate_channels\n        hidden_channels = gate_channels // reduction_ratio\n        if hidden_channels == 0:\n            hidden_channels = 1\n        self.mlp = torch.nn.Sequential(\n                                       convolutional_block_attention.Flatten(),\n                                       torch.nn.Linear(gate_channels, hidden_channels),\n                                       torch.nn.ReLU(),\n                                       torch.nn.Linear(hidden_channels, gate_channels)\n                                      )\n        self.pool_types = pool_types\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward pass of the ChannelGate module.\n\n        Applies channel-wise attention to the input tensor.\n\n        Parameters\n        ----------\n        x            : torch.tensor\n                       Input tensor to the ChannelGate module.\n\n        Returns\n        -------\n        output       : torch.tensor\n                       Output tensor after applying channel attention.\n        \"\"\"\n        channel_att_sum = None\n        for pool_type in self.pool_types:\n            if pool_type == 'avg':\n                pool = torch.nn.functional.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))\n            elif pool_type == 'max':\n                pool = torch.nn.functional.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))\n            channel_att_raw = self.mlp(pool)\n            channel_att_sum = channel_att_raw if channel_att_sum is None else channel_att_sum + channel_att_raw\n        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)\n        output = x * scale\n        return output\n
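A small usage sketch based on the definition above; the channel count and feature-map size are arbitrary choices for illustration.

import torch
from odak.learn.models import channel_gate

attention = channel_gate(gate_channels = 32, reduction_ratio = 16, pool_types = ['avg', 'max'])
features = torch.randn(4, 32, 24, 24)   # [batch, channels, height, width]
weighted = attention(features)          # same shape, channels re-weighted by attention
print(weighted.shape)                   # torch.Size([4, 32, 24, 24])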
"},{"location":"odak/learn_models/#odak.learn.models.channel_gate.__init__","title":"__init__(gate_channels, reduction_ratio=16, pool_types=['avg', 'max'])","text":"

Initializes the channel gate module.

Parameters:

  • gate_channels \u2013
              Number of channels of the input feature map.\n
  • reduction_ratio (int, default: 16 ) \u2013
              Reduction ratio for the intermediate layer.\n
  • pool_types \u2013
              List of pooling operations to apply.\n
Source code in odak/learn/models/components.py
def __init__(\n             self, \n             gate_channels, \n             reduction_ratio = 16, \n             pool_types = ['avg', 'max']\n            ):\n    \"\"\"\n    Initializes the channel gate module.\n\n    Parameters\n    ----------\n    gate_channels   : int\n                      Number of channels of the input feature map.\n    reduction_ratio : int\n                      Reduction ratio for the intermediate layer.\n    pool_types      : list\n                      List of pooling operations to apply.\n    \"\"\"\n    super().__init__()\n    self.gate_channels = gate_channels\n    hidden_channels = gate_channels // reduction_ratio\n    if hidden_channels == 0:\n        hidden_channels = 1\n    self.mlp = torch.nn.Sequential(\n                                   convolutional_block_attention.Flatten(),\n                                   torch.nn.Linear(gate_channels, hidden_channels),\n                                   torch.nn.ReLU(),\n                                   torch.nn.Linear(hidden_channels, gate_channels)\n                                  )\n    self.pool_types = pool_types\n
"},{"location":"odak/learn_models/#odak.learn.models.channel_gate.forward","title":"forward(x)","text":"

Forward pass of the ChannelGate module.

Applies channel-wise attention to the input tensor.

Parameters:

  • x \u2013
           Input tensor to the ChannelGate module.\n

Returns:

  • output ( tensor ) \u2013

    Output tensor after applying channel attention.

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward pass of the ChannelGate module.\n\n    Applies channel-wise attention to the input tensor.\n\n    Parameters\n    ----------\n    x            : torch.tensor\n                   Input tensor to the ChannelGate module.\n\n    Returns\n    -------\n    output       : torch.tensor\n                   Output tensor after applying channel attention.\n    \"\"\"\n    channel_att_sum = None\n    for pool_type in self.pool_types:\n        if pool_type == 'avg':\n            pool = torch.nn.functional.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))\n        elif pool_type == 'max':\n            pool = torch.nn.functional.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))\n        channel_att_raw = self.mlp(pool)\n        channel_att_sum = channel_att_raw if channel_att_sum is None else channel_att_sum + channel_att_raw\n    scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)\n    output = x * scale\n    return output\n
"},{"location":"odak/learn_models/#odak.learn.models.convolution_layer","title":"convolution_layer","text":"

Bases: Module

A convolution layer.

Source code in odak/learn/models/components.py
class convolution_layer(torch.nn.Module):\n    \"\"\"\n    A convolution layer.\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels = 2,\n                 output_channels = 2,\n                 kernel_size = 3,\n                 bias = False,\n                 stride = 1,\n                 normalization = True,\n                 activation = torch.nn.ReLU()\n                ):\n        \"\"\"\n        A convolutional layer class.\n\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        output_channels : int\n                          Number of output channels.\n        kernel_size     : int\n                          Kernel size.\n        bias            : bool\n                          Set to True to let convolutional layers have bias term.\n        normalization   : bool\n                          If True, adds a Batch Normalization layer after the convolutional layer.\n        activation      : torch.nn\n                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n        \"\"\"\n        super().__init__()\n        layers = [\n            torch.nn.Conv2d(\n                            input_channels,\n                            output_channels,\n                            kernel_size = kernel_size,\n                            stride = stride,\n                            padding = kernel_size // 2,\n                            bias = bias\n                           )\n        ]\n        if normalization:\n            layers.append(torch.nn.BatchNorm2d(output_channels))\n        if activation:\n            layers.append(activation)\n        self.model = torch.nn.Sequential(*layers)\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x             : torch.tensor\n                        Input data.\n\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Estimated output.\n        \"\"\"\n        result = self.model(x)\n        return result\n
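A usage sketch for convolution_layer following the definition above; the channel counts and input size are illustrative.

import torch
from odak.learn.models import convolution_layer

layer = convolution_layer(
                          input_channels = 3,
                          output_channels = 8,
                          kernel_size = 3,
                          normalization = True,
                          activation = torch.nn.ReLU()
                         )
x = torch.randn(1, 3, 64, 64)
y = layer(x)
print(y.shape)  # torch.Size([1, 8, 64, 64]) -- padding of kernel_size // 2 keeps the spatial size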
"},{"location":"odak/learn_models/#odak.learn.models.convolution_layer.__init__","title":"__init__(input_channels=2, output_channels=2, kernel_size=3, bias=False, stride=1, normalization=True, activation=torch.nn.ReLU())","text":"

A convolutional layer class.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • output_channels (int, default: 2 ) \u2013
              Number of output channels.\n
  • kernel_size \u2013
              Kernel size.\n
  • bias \u2013
              Set to True to let convolutional layers have bias term.\n
  • normalization \u2013
              If True, adds a Batch Normalization layer after the convolutional layer.\n
  • activation \u2013
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels = 2,\n             output_channels = 2,\n             kernel_size = 3,\n             bias = False,\n             stride = 1,\n             normalization = True,\n             activation = torch.nn.ReLU()\n            ):\n    \"\"\"\n    A convolutional layer class.\n\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    output_channels : int\n                      Number of output channels.\n    kernel_size     : int\n                      Kernel size.\n    bias            : bool\n                      Set to True to let convolutional layers have bias term.\n    normalization   : bool\n                      If True, adds a Batch Normalization layer after the convolutional layer.\n    activation      : torch.nn\n                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n    \"\"\"\n    super().__init__()\n    layers = [\n        torch.nn.Conv2d(\n                        input_channels,\n                        output_channels,\n                        kernel_size = kernel_size,\n                        stride = stride,\n                        padding = kernel_size // 2,\n                        bias = bias\n                       )\n    ]\n    if normalization:\n        layers.append(torch.nn.BatchNorm2d(output_channels))\n    if activation:\n        layers.append(activation)\n    self.model = torch.nn.Sequential(*layers)\n
"},{"location":"odak/learn_models/#odak.learn.models.convolution_layer.forward","title":"forward(x)","text":"

Forward model.

Parameters:

  • x \u2013
            Input data.\n

Returns:

  • result ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x             : torch.tensor\n                    Input data.\n\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Estimated output.\n    \"\"\"\n    result = self.model(x)\n    return result\n
"},{"location":"odak/learn_models/#odak.learn.models.convolutional_block_attention","title":"convolutional_block_attention","text":"

Bases: Module

Convolutional Block Attention Module (CBAM) class. This class is heavily inspired by https://github.com/Jongchan/attention-module/commit/e4ee180f1335c09db14d39a65d97c8ca3d1f7b16 (MIT License).

Source code in odak/learn/models/components.py
class convolutional_block_attention(torch.nn.Module):\n    \"\"\"\n    Convolutional Block Attention Module (CBAM) class. \n    This class is heavily inspired https://github.com/Jongchan/attention-module/commit/e4ee180f1335c09db14d39a65d97c8ca3d1f7b16 (MIT License).\n    \"\"\"\n    def __init__(\n                 self, \n                 gate_channels, \n                 reduction_ratio = 16, \n                 pool_types = ['avg', 'max'], \n                 no_spatial = False\n                ):\n        \"\"\"\n        Initializes the convolutional block attention module.\n\n        Parameters\n        ----------\n        gate_channels   : int\n                          Number of channels of the input feature map.\n        reduction_ratio : int\n                          Reduction ratio for the channel attention.\n        pool_types      : list\n                          List of pooling operations to apply for channel attention.\n        no_spatial      : bool\n                          If True, spatial attention is not applied.\n        \"\"\"\n        super(convolutional_block_attention, self).__init__()\n        self.channel_gate = channel_gate(gate_channels, reduction_ratio, pool_types)\n        self.no_spatial = no_spatial\n        if not no_spatial:\n            self.spatial_gate = spatial_gate()\n\n\n    class Flatten(torch.nn.Module):\n        \"\"\"\n        Flattens the input tensor to a 2D matrix.\n        \"\"\"\n        def forward(self, x):\n            return x.view(x.size(0), -1)\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward pass of the convolutional block attention module.\n\n        Parameters\n        ----------\n        x            : torch.tensor\n                       Input tensor to the CBAM module.\n\n        Returns\n        -------\n        x_out        : torch.tensor\n                       Output tensor after applying channel and spatial attention.\n        \"\"\"\n        x_out = self.channel_gate(x)\n        if not self.no_spatial:\n            x_out = self.spatial_gate(x_out)\n        return x_out\n
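A usage sketch for convolutional_block_attention following the definition above; the channel count, reduction ratio, and feature-map size are illustrative.

import torch
from odak.learn.models import convolutional_block_attention

cbam = convolutional_block_attention(gate_channels = 16, reduction_ratio = 4)
features = torch.randn(2, 16, 32, 32)
refined = cbam(features)        # channel attention followed by spatial attention
print(refined.shape)            # torch.Size([2, 16, 32, 32])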
"},{"location":"odak/learn_models/#odak.learn.models.convolutional_block_attention.Flatten","title":"Flatten","text":"

Bases: Module

Flattens the input tensor to a 2D matrix.

Source code in odak/learn/models/components.py
class Flatten(torch.nn.Module):\n    \"\"\"\n    Flattens the input tensor to a 2D matrix.\n    \"\"\"\n    def forward(self, x):\n        return x.view(x.size(0), -1)\n
"},{"location":"odak/learn_models/#odak.learn.models.convolutional_block_attention.__init__","title":"__init__(gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False)","text":"

Initializes the convolutional block attention module.

Parameters:

  • gate_channels \u2013
              Number of channels of the input feature map.\n
  • reduction_ratio (int, default: 16 ) \u2013
              Reduction ratio for the channel attention.\n
  • pool_types \u2013
              List of pooling operations to apply for channel attention.\n
  • no_spatial \u2013
              If True, spatial attention is not applied.\n
Source code in odak/learn/models/components.py
def __init__(\n             self, \n             gate_channels, \n             reduction_ratio = 16, \n             pool_types = ['avg', 'max'], \n             no_spatial = False\n            ):\n    \"\"\"\n    Initializes the convolutional block attention module.\n\n    Parameters\n    ----------\n    gate_channels   : int\n                      Number of channels of the input feature map.\n    reduction_ratio : int\n                      Reduction ratio for the channel attention.\n    pool_types      : list\n                      List of pooling operations to apply for channel attention.\n    no_spatial      : bool\n                      If True, spatial attention is not applied.\n    \"\"\"\n    super(convolutional_block_attention, self).__init__()\n    self.channel_gate = channel_gate(gate_channels, reduction_ratio, pool_types)\n    self.no_spatial = no_spatial\n    if not no_spatial:\n        self.spatial_gate = spatial_gate()\n
"},{"location":"odak/learn_models/#odak.learn.models.convolutional_block_attention.forward","title":"forward(x)","text":"

Forward pass of the convolutional block attention module.

Parameters:

  • x \u2013
           Input tensor to the CBAM module.\n

Returns:

  • x_out ( tensor ) \u2013

    Output tensor after applying channel and spatial attention.

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward pass of the convolutional block attention module.\n\n    Parameters\n    ----------\n    x            : torch.tensor\n                   Input tensor to the CBAM module.\n\n    Returns\n    -------\n    x_out        : torch.tensor\n                   Output tensor after applying channel and spatial attention.\n    \"\"\"\n    x_out = self.channel_gate(x)\n    if not self.no_spatial:\n        x_out = self.spatial_gate(x_out)\n    return x_out\n
"},{"location":"odak/learn_models/#odak.learn.models.double_convolution","title":"double_convolution","text":"

Bases: Module

A double convolution layer.

Source code in odak/learn/models/components.py
class double_convolution(torch.nn.Module):\n    \"\"\"\n    A double convolution layer.\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels = 2,\n                 mid_channels = None,\n                 output_channels = 2,\n                 kernel_size = 3, \n                 bias = False,\n                 normalization = True,\n                 activation = torch.nn.ReLU()\n                ):\n        \"\"\"\n        Double convolution model.\n\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        mid_channels    : int\n                          Number of channels in the hidden layer between two convolutions.\n        output_channels : int\n                          Number of output channels.\n        kernel_size     : int\n                          Kernel size.\n        bias            : bool \n                          Set to True to let convolutional layers have bias term.\n        normalization   : bool\n                          If True, adds a Batch Normalization layer after the convolutional layer.\n        activation      : torch.nn\n                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n        \"\"\"\n        super().__init__()\n        if isinstance(mid_channels, type(None)):\n            mid_channels = output_channels\n        self.activation = activation\n        self.model = torch.nn.Sequential(\n                                         convolution_layer(\n                                                           input_channels = input_channels,\n                                                           output_channels = mid_channels,\n                                                           kernel_size = kernel_size,\n                                                           bias = bias,\n                                                           normalization = normalization,\n                                                           activation = self.activation\n                                                          ),\n                                         convolution_layer(\n                                                           input_channels = mid_channels,\n                                                           output_channels = output_channels,\n                                                           kernel_size = kernel_size,\n                                                           bias = bias,\n                                                           normalization = normalization,\n                                                           activation = self.activation\n                                                          )\n                                        )\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x             : torch.tensor\n                        Input data.\n\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Estimated output.      \n        \"\"\"\n        result = self.model(x)\n        return result\n
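A minimal usage sketch (assumptions: odak is installed, the class is importable from odak.learn.models, and the underlying convolution_layer pads so that the spatial size is preserved):

import torch
from odak.learn.models import double_convolution

layer = double_convolution(input_channels = 3, output_channels = 8)   # two convolution_layer blocks back to back
x = torch.randn(1, 3, 32, 32)
y = layer(x)                                                          # expected shape: [1, 8, 32, 32]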
"},{"location":"odak/learn_models/#odak.learn.models.double_convolution.__init__","title":"__init__(input_channels=2, mid_channels=None, output_channels=2, kernel_size=3, bias=False, normalization=True, activation=torch.nn.ReLU())","text":"

Double convolution model.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • mid_channels \u2013
              Number of channels in the hidden layer between two convolutions.\n
  • output_channels (int, default: 2 ) \u2013
              Number of output channels.\n
  • kernel_size \u2013
              Kernel size.\n
  • bias \u2013
              Set to True to let convolutional layers have bias term.\n
  • normalization \u2013
              If True, adds a Batch Normalization layer after the convolutional layer.\n
  • activation \u2013
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels = 2,\n             mid_channels = None,\n             output_channels = 2,\n             kernel_size = 3, \n             bias = False,\n             normalization = True,\n             activation = torch.nn.ReLU()\n            ):\n    \"\"\"\n    Double convolution model.\n\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    mid_channels    : int\n                      Number of channels in the hidden layer between two convolutions.\n    output_channels : int\n                      Number of output channels.\n    kernel_size     : int\n                      Kernel size.\n    bias            : bool \n                      Set to True to let convolutional layers have bias term.\n    normalization   : bool\n                      If True, adds a Batch Normalization layer after the convolutional layer.\n    activation      : torch.nn\n                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n    \"\"\"\n    super().__init__()\n    if isinstance(mid_channels, type(None)):\n        mid_channels = output_channels\n    self.activation = activation\n    self.model = torch.nn.Sequential(\n                                     convolution_layer(\n                                                       input_channels = input_channels,\n                                                       output_channels = mid_channels,\n                                                       kernel_size = kernel_size,\n                                                       bias = bias,\n                                                       normalization = normalization,\n                                                       activation = self.activation\n                                                      ),\n                                     convolution_layer(\n                                                       input_channels = mid_channels,\n                                                       output_channels = output_channels,\n                                                       kernel_size = kernel_size,\n                                                       bias = bias,\n                                                       normalization = normalization,\n                                                       activation = self.activation\n                                                      )\n                                    )\n
"},{"location":"odak/learn_models/#odak.learn.models.double_convolution.forward","title":"forward(x)","text":"

Forward model.

Parameters:

  • x \u2013
            Input data.\n

Returns:

  • result ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x             : torch.tensor\n                    Input data.\n\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Estimated output.      \n    \"\"\"\n    result = self.model(x)\n    return result\n
"},{"location":"odak/learn_models/#odak.learn.models.downsample_layer","title":"downsample_layer","text":"

Bases: Module

A downscaling component followed by a double convolution.

Source code in odak/learn/models/components.py
class downsample_layer(torch.nn.Module):\n    \"\"\"\n    A downscaling component followed by a double convolution.\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels,\n                 output_channels,\n                 kernel_size = 3,\n                 bias = False,\n                 activation = torch.nn.ReLU()\n                ):\n        \"\"\"\n        A downscaling component with a double convolution.\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        output_channels : int\n                          Number of output channels.\n        kernel_size     : int\n                          Kernel size.\n        bias            : bool \n                          Set to True to let convolutional layers have bias term.\n        activation      : torch.nn\n                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n        \"\"\"\n        super().__init__()\n        self.maxpool_conv = torch.nn.Sequential(\n                                                torch.nn.MaxPool2d(2),\n                                                double_convolution(\n                                                                   input_channels = input_channels,\n                                                                   mid_channels = output_channels,\n                                                                   output_channels = output_channels,\n                                                                   kernel_size = kernel_size,\n                                                                   bias = bias,\n                                                                   activation = activation\n                                                                  )\n                                               )\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x              : torch.tensor\n                         First input data.\n\n\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Estimated output.      \n        \"\"\"\n        result = self.maxpool_conv(x)\n        return result\n
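A minimal usage sketch (assumes odak is installed and the class is importable from odak.learn.models; shapes are illustrative):

import torch
from odak.learn.models import downsample_layer

down = downsample_layer(input_channels = 8, output_channels = 16)   # MaxPool2d(2) followed by a double convolution
x = torch.randn(1, 8, 64, 64)
y = down(x)                                                         # spatial resolution halved, e.g. [1, 16, 32, 32]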
"},{"location":"odak/learn_models/#odak.learn.models.downsample_layer.__init__","title":"__init__(input_channels, output_channels, kernel_size=3, bias=False, activation=torch.nn.ReLU())","text":"

A downscaling component with a double convolution.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • output_channels (int) \u2013
              Number of output channels.\n
  • kernel_size \u2013
              Kernel size.\n
  • bias \u2013
              Set to True to let convolutional layers have bias term.\n
  • activation \u2013
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels,\n             output_channels,\n             kernel_size = 3,\n             bias = False,\n             activation = torch.nn.ReLU()\n            ):\n    \"\"\"\n    A downscaling component with a double convolution.\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    output_channels : int\n                      Number of output channels.\n    kernel_size     : int\n                      Kernel size.\n    bias            : bool \n                      Set to True to let convolutional layers have bias term.\n    activation      : torch.nn\n                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n    \"\"\"\n    super().__init__()\n    self.maxpool_conv = torch.nn.Sequential(\n                                            torch.nn.MaxPool2d(2),\n                                            double_convolution(\n                                                               input_channels = input_channels,\n                                                               mid_channels = output_channels,\n                                                               output_channels = output_channels,\n                                                               kernel_size = kernel_size,\n                                                               bias = bias,\n                                                               activation = activation\n                                                              )\n                                           )\n
"},{"location":"odak/learn_models/#odak.learn.models.downsample_layer.forward","title":"forward(x)","text":"

Forward model.

Parameters:

  • x \u2013
             First input data.\n

Returns:

  • result ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x              : torch.tensor\n                     First input data.\n\n\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Estimated output.      \n    \"\"\"\n    result = self.maxpool_conv(x)\n    return result\n
"},{"location":"odak/learn_models/#odak.learn.models.global_feature_module","title":"global_feature_module","text":"

Bases: Module

A global feature layer that processes global features from input channels and applies them to another input tensor via learned transformations.

Source code in odak/learn/models/components.py
class global_feature_module(torch.nn.Module):\n    \"\"\"\n    A global feature layer that processes global features from input channels and\n    applies them to another input tensor via learned transformations.\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels,\n                 mid_channels,\n                 output_channels,\n                 kernel_size,\n                 bias = False,\n                 normalization = False,\n                 activation = torch.nn.ReLU()\n                ):\n        \"\"\"\n        A global feature layer.\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        mid_channels  : int\n                          Number of mid channels.\n        output_channels : int\n                          Number of output channels.\n        kernel_size     : int\n                          Kernel size.\n        bias            : bool\n                          Set to True to let convolutional layers have bias term.\n        normalization   : bool\n                          If True, adds a Batch Normalization layer after the convolutional layer.\n        activation      : torch.nn\n                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n        \"\"\"\n        super().__init__()\n        self.transformations_1 = global_transformations(input_channels, output_channels)\n        self.global_features_1 = double_convolution(\n                                                    input_channels = input_channels,\n                                                    mid_channels = mid_channels,\n                                                    output_channels = output_channels,\n                                                    kernel_size = kernel_size,\n                                                    bias = bias,\n                                                    normalization = normalization,\n                                                    activation = activation\n                                                   )\n        self.global_features_2 = double_convolution(\n                                                    input_channels = input_channels,\n                                                    mid_channels = mid_channels,\n                                                    output_channels = output_channels,\n                                                    kernel_size = kernel_size,\n                                                    bias = bias,\n                                                    normalization = normalization,\n                                                    activation = activation\n                                                   )\n        self.transformations_2 = global_transformations(input_channels, output_channels)\n\n\n    def forward(self, x1, x2):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x1             : torch.tensor\n                         First input data.\n        x2             : torch.tensor\n                         Second input data.\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Estimated output.\n        \"\"\"\n        global_tensor_1 = self.transformations_1(x1, x2)\n        y1 = self.global_features_1(global_tensor_1)\n        y2 = self.global_features_2(y1)\n        global_tensor_2 = self.transformations_2(y1, y2)\n        return 
global_tensor_2\n
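A minimal usage sketch (assumes odak is installed and the class is importable from odak.learn.models; the channel counts are kept equal here so that the internally pooled global features broadcast back onto the feature maps):

import torch
from odak.learn.models import global_feature_module

gfm = global_feature_module(input_channels = 8, mid_channels = 8, output_channels = 8, kernel_size = 3)
x1 = torch.randn(1, 8, 32, 32)   # features to be modulated
x2 = torch.randn(1, 8, 32, 32)   # features pooled into global statistics
y = gfm(x1, x2)                  # same shape as x1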
"},{"location":"odak/learn_models/#odak.learn.models.global_feature_module.__init__","title":"__init__(input_channels, mid_channels, output_channels, kernel_size, bias=False, normalization=False, activation=torch.nn.ReLU())","text":"

A global feature layer.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • mid_channels \u2013
              Number of mid channels.\n
  • output_channels (int) \u2013
              Number of output channels.\n
  • kernel_size \u2013
              Kernel size.\n
  • bias \u2013
              Set to True to let convolutional layers have bias term.\n
  • normalization \u2013
              If True, adds a Batch Normalization layer after the convolutional layer.\n
  • activation \u2013
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels,\n             mid_channels,\n             output_channels,\n             kernel_size,\n             bias = False,\n             normalization = False,\n             activation = torch.nn.ReLU()\n            ):\n    \"\"\"\n    A global feature layer.\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    mid_channels  : int\n                      Number of mid channels.\n    output_channels : int\n                      Number of output channels.\n    kernel_size     : int\n                      Kernel size.\n    bias            : bool\n                      Set to True to let convolutional layers have bias term.\n    normalization   : bool\n                      If True, adds a Batch Normalization layer after the convolutional layer.\n    activation      : torch.nn\n                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n    \"\"\"\n    super().__init__()\n    self.transformations_1 = global_transformations(input_channels, output_channels)\n    self.global_features_1 = double_convolution(\n                                                input_channels = input_channels,\n                                                mid_channels = mid_channels,\n                                                output_channels = output_channels,\n                                                kernel_size = kernel_size,\n                                                bias = bias,\n                                                normalization = normalization,\n                                                activation = activation\n                                               )\n    self.global_features_2 = double_convolution(\n                                                input_channels = input_channels,\n                                                mid_channels = mid_channels,\n                                                output_channels = output_channels,\n                                                kernel_size = kernel_size,\n                                                bias = bias,\n                                                normalization = normalization,\n                                                activation = activation\n                                               )\n    self.transformations_2 = global_transformations(input_channels, output_channels)\n
"},{"location":"odak/learn_models/#odak.learn.models.global_feature_module.forward","title":"forward(x1, x2)","text":"

Forward model.

Parameters:

  • x1 \u2013
             First input data.\n
  • x2 \u2013
             Second input data.\n

Returns:

  • result ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x1, x2):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x1             : torch.tensor\n                     First input data.\n    x2             : torch.tensor\n                     Second input data.\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Estimated output.\n    \"\"\"\n    global_tensor_1 = self.transformations_1(x1, x2)\n    y1 = self.global_features_1(global_tensor_1)\n    y2 = self.global_features_2(y1)\n    global_tensor_2 = self.transformations_2(y1, y2)\n    return global_tensor_2\n
"},{"location":"odak/learn_models/#odak.learn.models.global_transformations","title":"global_transformations","text":"

Bases: Module

A global feature layer that processes global features from input channels and applies learned transformations to another input tensor.

This implementation is adapted from RSGUnet: https://github.com/MTLab/rsgunet_image_enhance.

Reference: J. Huang, P. Zhu, M. Geng et al. \"Range Scaling Global U-Net for Perceptual Image Enhancement on Mobile Devices.\"

Source code in odak/learn/models/components.py
class global_transformations(torch.nn.Module):\n    \"\"\"\n    A global feature layer that processes global features from input channels and\n    applies learned transformations to another input tensor.\n\n    This implementation is adapted from RSGUnet:\n    https://github.com/MTLab/rsgunet_image_enhance.\n\n    Reference:\n    J. Huang, P. Zhu, M. Geng et al. \"Range Scaling Global U-Net for Perceptual Image Enhancement on Mobile Devices.\"\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels,\n                 output_channels\n                ):\n        \"\"\"\n        A global feature layer.\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        output_channels : int\n                          Number of output channels.\n        \"\"\"\n        super().__init__()\n        self.global_feature_1 = torch.nn.Sequential(\n            torch.nn.Linear(input_channels, output_channels),\n            torch.nn.LeakyReLU(0.2, inplace = True),\n        )\n        self.global_feature_2 = torch.nn.Sequential(\n            torch.nn.Linear(output_channels, output_channels),\n            torch.nn.LeakyReLU(0.2, inplace = True)\n        )\n\n\n    def forward(self, x1, x2):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x1             : torch.tensor\n                         First input data.\n        x2             : torch.tensor\n                         Second input data.\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Estimated output.\n        \"\"\"\n        y = torch.mean(x2, dim = (2, 3))\n        y1 = self.global_feature_1(y)\n        y2 = self.global_feature_2(y1)\n        y1 = y1.unsqueeze(2).unsqueeze(3)\n        y2 = y2.unsqueeze(2).unsqueeze(3)\n        result = x1 * y1 + y2\n        return result\n
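A minimal usage sketch (assumes odak is installed and the class is importable from odak.learn.models; x2 supplies the globally pooled features, while x1 is scaled and shifted by them):

import torch
from odak.learn.models import global_transformations

gt = global_transformations(input_channels = 8, output_channels = 8)
x1 = torch.randn(1, 8, 32, 32)
x2 = torch.randn(1, 8, 32, 32)   # averaged over height and width before the linear layers
y = gt(x1, x2)                   # x1 * y1 + y2, same shape as x1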
"},{"location":"odak/learn_models/#odak.learn.models.global_transformations.__init__","title":"__init__(input_channels, output_channels)","text":"

A global feature layer.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • output_channels (int) \u2013
              Number of output channels.\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels,\n             output_channels\n            ):\n    \"\"\"\n    A global feature layer.\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    output_channels : int\n                      Number of output channels.\n    \"\"\"\n    super().__init__()\n    self.global_feature_1 = torch.nn.Sequential(\n        torch.nn.Linear(input_channels, output_channels),\n        torch.nn.LeakyReLU(0.2, inplace = True),\n    )\n    self.global_feature_2 = torch.nn.Sequential(\n        torch.nn.Linear(output_channels, output_channels),\n        torch.nn.LeakyReLU(0.2, inplace = True)\n    )\n
"},{"location":"odak/learn_models/#odak.learn.models.global_transformations.forward","title":"forward(x1, x2)","text":"

Forward model.

Parameters:

  • x1 \u2013
             First input data.\n
  • x2 \u2013
             Second input data.\n

Returns:

  • result ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x1, x2):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x1             : torch.tensor\n                     First input data.\n    x2             : torch.tensor\n                     Second input data.\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Estimated output.\n    \"\"\"\n    y = torch.mean(x2, dim = (2, 3))\n    y1 = self.global_feature_1(y)\n    y2 = self.global_feature_2(y1)\n    y1 = y1.unsqueeze(2).unsqueeze(3)\n    y2 = y2.unsqueeze(2).unsqueeze(3)\n    result = x1 * y1 + y2\n    return result\n
"},{"location":"odak/learn_models/#odak.learn.models.multi_layer_perceptron","title":"multi_layer_perceptron","text":"

Bases: Module

A multi-layer perceptron model.

Source code in odak/learn/models/models.py
class multi_layer_perceptron(torch.nn.Module):\n    \"\"\"\n    A multi-layer perceptron model.\n    \"\"\"\n\n    def __init__(self,\n                 dimensions,\n                 activation = torch.nn.ReLU(),\n                 bias = False,\n                 model_type = 'conventional',\n                 siren_multiplier = 1.,\n                 input_multiplier = None\n                ):\n        \"\"\"\n        Parameters\n        ----------\n        dimensions        : list\n                            List of integers representing the dimensions of each layer (e.g., [2, 10, 1], where the first layer has two channels and last one has one channel.).\n        activation        : torch.nn\n                            Nonlinear activation function.\n                            Default is `torch.nn.ReLU()`.\n        bias              : bool\n                            If set to True, linear layers will include biases.\n        siren_multiplier  : float\n                            When using `SIREN` model type, this parameter functions as a hyperparameter.\n                            The original SIREN work uses 30.\n                            You can bypass this parameter by providing input that are not normalized and larger then one.\n        input_multiplier  : float\n                            Initial value of the input multiplier before the very first layer.\n        model_type        : str\n                            Model type: `conventional`, `swish`, `SIREN`, `FILM SIREN`, `Gaussian`.\n                            `conventional` refers to a standard multi layer perceptron.\n                            For `SIREN,` see: Sitzmann, Vincent, et al. \"Implicit neural representations with periodic activation functions.\" Advances in neural information processing systems 33 (2020): 7462-7473.\n                            For `Swish,` see: Ramachandran, Prajit, Barret Zoph, and Quoc V. Le. \"Searching for activation functions.\" arXiv preprint arXiv:1710.05941 (2017). \n                            For `FILM SIREN,` see: Chan, Eric R., et al. \"pi-gan: Periodic implicit generative adversarial networks for 3d-aware image synthesis.\" Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2021.\n                            For `Gaussian,` see: Ramasinghe, Sameera, and Simon Lucey. \"Beyond periodicity: Towards a unifying framework for activations in coordinate-mlps.\" In European Conference on Computer Vision, pp. 142-158. 
Cham: Springer Nature Switzerland, 2022.\n        \"\"\"\n        super(multi_layer_perceptron, self).__init__()\n        self.activation = activation\n        self.bias = bias\n        self.model_type = model_type\n        self.layers = torch.nn.ModuleList()\n        self.siren_multiplier = siren_multiplier\n        self.dimensions = dimensions\n        for i in range(len(self.dimensions) - 1):\n            self.layers.append(torch.nn.Linear(self.dimensions[i], self.dimensions[i + 1], bias = self.bias))\n        if not isinstance(input_multiplier, type(None)):\n            self.input_multiplier = torch.nn.ParameterList()\n            self.input_multiplier.append(torch.nn.Parameter(torch.ones(1, self.dimensions[0]) * input_multiplier))\n        if self.model_type == 'FILM SIREN':\n            self.alpha = torch.nn.ParameterList()\n            for j in self.dimensions[1:-1]:\n                self.alpha.append(torch.nn.Parameter(torch.randn(2, 1, j)))\n        if self.model_type == 'Gaussian':\n            self.alpha = torch.nn.ParameterList()\n            for j in self.dimensions[1:-1]:\n                self.alpha.append(torch.nn.Parameter(torch.randn(1, 1, j)))\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x             : torch.tensor\n                        Input data.\n\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Estimated output.      \n        \"\"\"\n        if hasattr(self, 'input_multiplier'):\n            result = x * self.input_multiplier[0]\n        else:\n            result = x\n        for layer_id, layer in enumerate(self.layers[:-1]):\n            result = layer(result)\n            if self.model_type == 'conventional':\n                result = self.activation(result)\n            elif self.model_type == 'swish':\n                result = swish(result)\n            elif self.model_type == 'SIREN':\n                result = torch.sin(result * self.siren_multiplier)\n            elif self.model_type == 'FILM SIREN':\n                result = torch.sin(self.alpha[layer_id][0] * result + self.alpha[layer_id][1])\n            elif self.model_type == 'Gaussian': \n                result = gaussian(result, self.alpha[layer_id][0])\n        result = self.layers[-1](result)\n        return result\n
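A minimal usage sketch (assumes odak is installed and the class is importable from odak.learn.models; the dimensions and coordinates are illustrative):

import torch
from odak.learn.models import multi_layer_perceptron

mlp = multi_layer_perceptron(dimensions = [2, 64, 64, 1], model_type = 'SIREN', siren_multiplier = 30.)
x = torch.rand(128, 2)   # e.g. normalized 2D coordinates
y = mlp(x)               # torch.Size([128, 1])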
"},{"location":"odak/learn_models/#odak.learn.models.multi_layer_perceptron.__init__","title":"__init__(dimensions, activation=torch.nn.ReLU(), bias=False, model_type='conventional', siren_multiplier=1.0, input_multiplier=None)","text":"

Parameters:

  • dimensions \u2013
                List of integers representing the dimensions of each layer (e.g., [2, 10, 1], where the first layer has two channels and last one has one channel.).\n
  • activation \u2013
                Nonlinear activation function.\n            Default is `torch.nn.ReLU()`.\n
  • bias \u2013
                If set to True, linear layers will include biases.\n
  • siren_multiplier \u2013
                 When using `SIREN` model type, this parameter functions as a hyperparameter.\n            The original SIREN work uses 30.\n            You can bypass this parameter by providing inputs that are not normalized and larger than one.\n
  • input_multiplier \u2013
                Initial value of the input multiplier before the very first layer.\n
  • model_type \u2013
                Model type: `conventional`, `swish`, `SIREN`, `FILM SIREN`, `Gaussian`.\n            `conventional` refers to a standard multi layer perceptron.\n            For `SIREN,` see: Sitzmann, Vincent, et al. \"Implicit neural representations with periodic activation functions.\" Advances in neural information processing systems 33 (2020): 7462-7473.\n            For `Swish,` see: Ramachandran, Prajit, Barret Zoph, and Quoc V. Le. \"Searching for activation functions.\" arXiv preprint arXiv:1710.05941 (2017). \n            For `FILM SIREN,` see: Chan, Eric R., et al. \"pi-gan: Periodic implicit generative adversarial networks for 3d-aware image synthesis.\" Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2021.\n            For `Gaussian,` see: Ramasinghe, Sameera, and Simon Lucey. \"Beyond periodicity: Towards a unifying framework for activations in coordinate-mlps.\" In European Conference on Computer Vision, pp. 142-158. Cham: Springer Nature Switzerland, 2022.\n
Source code in odak/learn/models/models.py
def __init__(self,\n             dimensions,\n             activation = torch.nn.ReLU(),\n             bias = False,\n             model_type = 'conventional',\n             siren_multiplier = 1.,\n             input_multiplier = None\n            ):\n    \"\"\"\n    Parameters\n    ----------\n    dimensions        : list\n                        List of integers representing the dimensions of each layer (e.g., [2, 10, 1], where the first layer has two channels and last one has one channel.).\n    activation        : torch.nn\n                        Nonlinear activation function.\n                        Default is `torch.nn.ReLU()`.\n    bias              : bool\n                        If set to True, linear layers will include biases.\n    siren_multiplier  : float\n                        When using `SIREN` model type, this parameter functions as a hyperparameter.\n                        The original SIREN work uses 30.\n                        You can bypass this parameter by providing input that are not normalized and larger then one.\n    input_multiplier  : float\n                        Initial value of the input multiplier before the very first layer.\n    model_type        : str\n                        Model type: `conventional`, `swish`, `SIREN`, `FILM SIREN`, `Gaussian`.\n                        `conventional` refers to a standard multi layer perceptron.\n                        For `SIREN,` see: Sitzmann, Vincent, et al. \"Implicit neural representations with periodic activation functions.\" Advances in neural information processing systems 33 (2020): 7462-7473.\n                        For `Swish,` see: Ramachandran, Prajit, Barret Zoph, and Quoc V. Le. \"Searching for activation functions.\" arXiv preprint arXiv:1710.05941 (2017). \n                        For `FILM SIREN,` see: Chan, Eric R., et al. \"pi-gan: Periodic implicit generative adversarial networks for 3d-aware image synthesis.\" Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2021.\n                        For `Gaussian,` see: Ramasinghe, Sameera, and Simon Lucey. \"Beyond periodicity: Towards a unifying framework for activations in coordinate-mlps.\" In European Conference on Computer Vision, pp. 142-158. Cham: Springer Nature Switzerland, 2022.\n    \"\"\"\n    super(multi_layer_perceptron, self).__init__()\n    self.activation = activation\n    self.bias = bias\n    self.model_type = model_type\n    self.layers = torch.nn.ModuleList()\n    self.siren_multiplier = siren_multiplier\n    self.dimensions = dimensions\n    for i in range(len(self.dimensions) - 1):\n        self.layers.append(torch.nn.Linear(self.dimensions[i], self.dimensions[i + 1], bias = self.bias))\n    if not isinstance(input_multiplier, type(None)):\n        self.input_multiplier = torch.nn.ParameterList()\n        self.input_multiplier.append(torch.nn.Parameter(torch.ones(1, self.dimensions[0]) * input_multiplier))\n    if self.model_type == 'FILM SIREN':\n        self.alpha = torch.nn.ParameterList()\n        for j in self.dimensions[1:-1]:\n            self.alpha.append(torch.nn.Parameter(torch.randn(2, 1, j)))\n    if self.model_type == 'Gaussian':\n        self.alpha = torch.nn.ParameterList()\n        for j in self.dimensions[1:-1]:\n            self.alpha.append(torch.nn.Parameter(torch.randn(1, 1, j)))\n
"},{"location":"odak/learn_models/#odak.learn.models.multi_layer_perceptron.forward","title":"forward(x)","text":"

Forward model.

Parameters:

  • x \u2013
            Input data.\n

Returns:

  • result ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/models.py
def forward(self, x):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x             : torch.tensor\n                    Input data.\n\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Estimated output.      \n    \"\"\"\n    if hasattr(self, 'input_multiplier'):\n        result = x * self.input_multiplier[0]\n    else:\n        result = x\n    for layer_id, layer in enumerate(self.layers[:-1]):\n        result = layer(result)\n        if self.model_type == 'conventional':\n            result = self.activation(result)\n        elif self.model_type == 'swish':\n            result = swish(result)\n        elif self.model_type == 'SIREN':\n            result = torch.sin(result * self.siren_multiplier)\n        elif self.model_type == 'FILM SIREN':\n            result = torch.sin(self.alpha[layer_id][0] * result + self.alpha[layer_id][1])\n        elif self.model_type == 'Gaussian': \n            result = gaussian(result, self.alpha[layer_id][0])\n    result = self.layers[-1](result)\n    return result\n
"},{"location":"odak/learn_models/#odak.learn.models.non_local_layer","title":"non_local_layer","text":"

Bases: Module

Self-Attention Layer [zi = Wzyi + xi] (non-local block : ref https://arxiv.org/abs/1711.07971)

Source code in odak/learn/models/components.py
class non_local_layer(torch.nn.Module):\n    \"\"\"\n    Self-Attention Layer [zi = Wzyi + xi] (non-local block : ref https://arxiv.org/abs/1711.07971)\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels = 1024,\n                 bottleneck_channels = 512,\n                 kernel_size = 1,\n                 bias = False,\n                ):\n        \"\"\"\n\n        Parameters\n        ----------\n        input_channels      : int\n                              Number of input channels.\n        bottleneck_channels : int\n                              Number of middle channels.\n        kernel_size         : int\n                              Kernel size.\n        bias                : bool \n                              Set to True to let convolutional layers have bias term.\n        \"\"\"\n        super(non_local_layer, self).__init__()\n        self.input_channels = input_channels\n        self.bottleneck_channels = bottleneck_channels\n        self.g = torch.nn.Conv2d(\n                                 self.input_channels, \n                                 self.bottleneck_channels,\n                                 kernel_size = kernel_size,\n                                 padding = kernel_size // 2,\n                                 bias = bias\n                                )\n        self.W_z = torch.nn.Sequential(\n                                       torch.nn.Conv2d(\n                                                       self.bottleneck_channels,\n                                                       self.input_channels, \n                                                       kernel_size = kernel_size,\n                                                       bias = bias,\n                                                       padding = kernel_size // 2\n                                                      ),\n                                       torch.nn.BatchNorm2d(self.input_channels)\n                                      )\n        torch.nn.init.constant_(self.W_z[1].weight, 0)   \n        torch.nn.init.constant_(self.W_z[1].bias, 0)\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward model [zi = Wzyi + xi]\n\n        Parameters\n        ----------\n        x               : torch.tensor\n                          First input data.                       \n\n\n        Returns\n        ----------\n        z               : torch.tensor\n                          Estimated output.\n        \"\"\"\n        batch_size, channels, height, width = x.size()\n        theta = x.view(batch_size, channels, -1).permute(0, 2, 1)\n        phi = x.view(batch_size, channels, -1).permute(0, 2, 1)\n        g = self.g(x).view(batch_size, self.bottleneck_channels, -1).permute(0, 2, 1)\n        attn = torch.bmm(theta, phi.transpose(1, 2)) / (height * width)\n        attn = torch.nn.functional.softmax(attn, dim=-1)\n        y = torch.bmm(attn, g).permute(0, 2, 1).contiguous().view(batch_size, self.bottleneck_channels, height, width)\n        W_y = self.W_z(y)\n        z = W_y + x\n        return z\n
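A minimal usage sketch (assumes odak is installed and the class is importable from odak.learn.models; smaller channel counts than the defaults keep the example light):

import torch
from odak.learn.models import non_local_layer

nl = non_local_layer(input_channels = 64, bottleneck_channels = 32)
x = torch.randn(1, 64, 16, 16)
z = nl(x)   # z = W_z y + x, same shape as the input thanks to the residual connection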
"},{"location":"odak/learn_models/#odak.learn.models.non_local_layer.__init__","title":"__init__(input_channels=1024, bottleneck_channels=512, kernel_size=1, bias=False)","text":"

Parameters:

  • input_channels \u2013
                  Number of input channels.\n
  • bottleneck_channels (int, default: 512 ) \u2013
                  Number of middle channels.\n
  • kernel_size \u2013
                  Kernel size.\n
  • bias \u2013
                  Set to True to let convolutional layers have bias term.\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels = 1024,\n             bottleneck_channels = 512,\n             kernel_size = 1,\n             bias = False,\n            ):\n    \"\"\"\n\n    Parameters\n    ----------\n    input_channels      : int\n                          Number of input channels.\n    bottleneck_channels : int\n                          Number of middle channels.\n    kernel_size         : int\n                          Kernel size.\n    bias                : bool \n                          Set to True to let convolutional layers have bias term.\n    \"\"\"\n    super(non_local_layer, self).__init__()\n    self.input_channels = input_channels\n    self.bottleneck_channels = bottleneck_channels\n    self.g = torch.nn.Conv2d(\n                             self.input_channels, \n                             self.bottleneck_channels,\n                             kernel_size = kernel_size,\n                             padding = kernel_size // 2,\n                             bias = bias\n                            )\n    self.W_z = torch.nn.Sequential(\n                                   torch.nn.Conv2d(\n                                                   self.bottleneck_channels,\n                                                   self.input_channels, \n                                                   kernel_size = kernel_size,\n                                                   bias = bias,\n                                                   padding = kernel_size // 2\n                                                  ),\n                                   torch.nn.BatchNorm2d(self.input_channels)\n                                  )\n    torch.nn.init.constant_(self.W_z[1].weight, 0)   \n    torch.nn.init.constant_(self.W_z[1].bias, 0)\n
"},{"location":"odak/learn_models/#odak.learn.models.non_local_layer.forward","title":"forward(x)","text":"

Forward model [zi = Wzyi + xi]

Parameters:

  • x \u2013
              First input data.\n

Returns:

  • z ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward model [zi = Wzyi + xi]\n\n    Parameters\n    ----------\n    x               : torch.tensor\n                      First input data.                       \n\n\n    Returns\n    ----------\n    z               : torch.tensor\n                      Estimated output.\n    \"\"\"\n    batch_size, channels, height, width = x.size()\n    theta = x.view(batch_size, channels, -1).permute(0, 2, 1)\n    phi = x.view(batch_size, channels, -1).permute(0, 2, 1)\n    g = self.g(x).view(batch_size, self.bottleneck_channels, -1).permute(0, 2, 1)\n    attn = torch.bmm(theta, phi.transpose(1, 2)) / (height * width)\n    attn = torch.nn.functional.softmax(attn, dim=-1)\n    y = torch.bmm(attn, g).permute(0, 2, 1).contiguous().view(batch_size, self.bottleneck_channels, height, width)\n    W_y = self.W_z(y)\n    z = W_y + x\n    return z\n
"},{"location":"odak/learn_models/#odak.learn.models.normalization","title":"normalization","text":"

Bases: Module

A normalization layer.

Source code in odak/learn/models/components.py
class normalization(torch.nn.Module):\n    \"\"\"\n    A normalization layer.\n    \"\"\"\n    def __init__(\n                 self,\n                 dim = 1,\n                ):\n        \"\"\"\n        Normalization layer.\n\n\n        Parameters\n        ----------\n        dim             : int\n                          Dimension (axis) to normalize.\n        \"\"\"\n        super().__init__()\n        self.k = torch.nn.Parameter(torch.ones(1, dim, 1, 1))\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x             : torch.tensor\n                        Input data.\n\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Estimated output.      \n        \"\"\"\n        eps = 1e-5 if x.dtype == torch.float32 else 1e-3\n        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)\n        mean = torch.mean(x, dim = 1, keepdim = True)\n        result =  (x - mean) * (var + eps).rsqrt() * self.k\n        return result \n
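A minimal usage sketch (assumes odak is installed and the class is importable from odak.learn.models; dim is set to the channel count so the learnable gain broadcasts over the input):

import torch
from odak.learn.models import normalization

norm = normalization(dim = 8)
x = torch.randn(1, 8, 32, 32)
y = norm(x)   # zero-mean, unit-variance along the channel axis, scaled by the learnable gain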
"},{"location":"odak/learn_models/#odak.learn.models.normalization.__init__","title":"__init__(dim=1)","text":"

Normalization layer.

Parameters:

  • dim \u2013
              Dimension (axis) to normalize.\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             dim = 1,\n            ):\n    \"\"\"\n    Normalization layer.\n\n\n    Parameters\n    ----------\n    dim             : int\n                      Dimension (axis) to normalize.\n    \"\"\"\n    super().__init__()\n    self.k = torch.nn.Parameter(torch.ones(1, dim, 1, 1))\n
"},{"location":"odak/learn_models/#odak.learn.models.normalization.forward","title":"forward(x)","text":"

Forward model.

Parameters:

  • x \u2013
            Input data.\n

Returns:

  • result ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x             : torch.tensor\n                    Input data.\n\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Estimated output.      \n    \"\"\"\n    eps = 1e-5 if x.dtype == torch.float32 else 1e-3\n    var = torch.var(x, dim = 1, unbiased = False, keepdim = True)\n    mean = torch.mean(x, dim = 1, keepdim = True)\n    result =  (x - mean) * (var + eps).rsqrt() * self.k\n    return result \n
"},{"location":"odak/learn_models/#odak.learn.models.positional_encoder","title":"positional_encoder","text":"

Bases: Module

A positional encoder module.

Source code in odak/learn/models/components.py
class positional_encoder(torch.nn.Module):\n    \"\"\"\n    A positional encoder module.\n    \"\"\"\n\n    def __init__(self, L):\n        \"\"\"\n        A positional encoder module.\n\n        Parameters\n        ----------\n        L                   : int\n                              Positional encoding level.\n        \"\"\"\n        super(positional_encoder, self).__init__()\n        self.L = L\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x               : torch.tensor\n                          Input data.\n\n        Returns\n        ----------\n        result          : torch.tensor\n                          Result of the forward operation\n        \"\"\"\n        B, C = x.shape\n        x = x.view(B, C, 1)\n        results = [x]\n        for i in range(1, self.L + 1):\n            freq = (2 ** i) * math.pi\n            cos_x = torch.cos(freq * x)\n            sin_x = torch.sin(freq * x)\n            results.append(cos_x)\n            results.append(sin_x)\n        results = torch.cat(results, dim=2)\n        results = results.permute(0, 2, 1)\n        results = results.reshape(B, -1)\n        return results\n
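A minimal usage sketch (assumes odak is installed and the class is importable from odak.learn.models); with L levels and C input channels the output carries C * (2 * L + 1) features per sample:

import torch
from odak.learn.models import positional_encoder

pe = positional_encoder(L = 4)
x = torch.rand(128, 2)   # e.g. 2D coordinates in [0, 1]
y = pe(x)                # torch.Size([128, 18]), i.e. 2 * (2 * 4 + 1) encoded features per sample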
"},{"location":"odak/learn_models/#odak.learn.models.positional_encoder.__init__","title":"__init__(L)","text":"

A positional encoder module.

Parameters:

  • L \u2013
                  Positional encoding level.\n
Source code in odak/learn/models/components.py
def __init__(self, L):\n    \"\"\"\n    A positional encoder module.\n\n    Parameters\n    ----------\n    L                   : int\n                          Positional encoding level.\n    \"\"\"\n    super(positional_encoder, self).__init__()\n    self.L = L\n
"},{"location":"odak/learn_models/#odak.learn.models.positional_encoder.forward","title":"forward(x)","text":"

Forward model.

Parameters:

  • x \u2013
              Input data.\n

Returns:

  • result ( tensor ) \u2013

    Result of the forward operation

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x               : torch.tensor\n                      Input data.\n\n    Returns\n    ----------\n    result          : torch.tensor\n                      Result of the forward operation\n    \"\"\"\n    B, C = x.shape\n    x = x.view(B, C, 1)\n    results = [x]\n    for i in range(1, self.L + 1):\n        freq = (2 ** i) * math.pi\n        cos_x = torch.cos(freq * x)\n        sin_x = torch.sin(freq * x)\n        results.append(cos_x)\n        results.append(sin_x)\n    results = torch.cat(results, dim=2)\n    results = results.permute(0, 2, 1)\n    results = results.reshape(B, -1)\n    return results\n
"},{"location":"odak/learn_models/#odak.learn.models.residual_attention_layer","title":"residual_attention_layer","text":"

Bases: Module

A residual block with an attention layer.

Source code in odak/learn/models/components.py
class residual_attention_layer(torch.nn.Module):\n    \"\"\"\n    A residual block with an attention layer.\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels = 2,\n                 output_channels = 2,\n                 kernel_size = 1,\n                 bias = False,\n                 activation = torch.nn.ReLU()\n                ):\n        \"\"\"\n        An attention layer class.\n\n\n        Parameters\n        ----------\n        input_channels  : int or optioal\n                          Number of input channels.\n        output_channels : int or optional\n                          Number of middle channels.\n        kernel_size     : int or optional\n                          Kernel size.\n        bias            : bool or optional\n                          Set to True to let convolutional layers have bias term.\n        activation      : torch.nn or optional\n                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n        \"\"\"\n        super().__init__()\n        self.activation = activation\n        self.convolution0 = torch.nn.Sequential(\n                                                torch.nn.Conv2d(\n                                                                input_channels,\n                                                                output_channels,\n                                                                kernel_size = kernel_size,\n                                                                padding = kernel_size // 2,\n                                                                bias = bias\n                                                               ),\n                                                torch.nn.BatchNorm2d(output_channels)\n                                               )\n        self.convolution1 = torch.nn.Sequential(\n                                                torch.nn.Conv2d(\n                                                                input_channels,\n                                                                output_channels,\n                                                                kernel_size = kernel_size,\n                                                                padding = kernel_size // 2,\n                                                                bias = bias\n                                                               ),\n                                                torch.nn.BatchNorm2d(output_channels)\n                                               )\n        self.final_layer = torch.nn.Sequential(\n                                               self.activation,\n                                               torch.nn.Conv2d(\n                                                               output_channels,\n                                                               output_channels,\n                                                               kernel_size = kernel_size,\n                                                               padding = kernel_size // 2,\n                                                               bias = bias\n                                                              )\n                                              )\n\n\n    def forward(self, x0, x1):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x0             : torch.tensor\n                         First input data.\n\n        x1             : torch.tensor\n          
               Second input data.\n\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Estimated output.      \n        \"\"\"\n        y0 = self.convolution0(x0)\n        y1 = self.convolution1(x1)\n        y2 = torch.add(y0, y1)\n        result = self.final_layer(y2) * x0\n        return result\n
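A minimal usage sketch (assumes odak is installed and the class is importable from odak.learn.models; input and output channel counts are kept equal here so the attention map multiplies x0 directly):

import torch
from odak.learn.models import residual_attention_layer

attention = residual_attention_layer(input_channels = 8, output_channels = 8)
x0 = torch.randn(1, 8, 32, 32)   # features to be gated
x1 = torch.randn(1, 8, 32, 32)   # features contributing to the attention map
y = attention(x0, x1)            # final_layer(conv(x0) + conv(x1)) * x0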
"},{"location":"odak/learn_models/#odak.learn.models.residual_attention_layer.__init__","title":"__init__(input_channels=2, output_channels=2, kernel_size=1, bias=False, activation=torch.nn.ReLU())","text":"

An attention layer class.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • output_channels (int or optional, default: 2 ) \u2013
              Number of middle channels.\n
  • kernel_size \u2013
              Kernel size.\n
  • bias \u2013
              Set to True to let convolutional layers have bias term.\n
  • activation \u2013
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels = 2,\n             output_channels = 2,\n             kernel_size = 1,\n             bias = False,\n             activation = torch.nn.ReLU()\n            ):\n    \"\"\"\n    An attention layer class.\n\n\n    Parameters\n    ----------\n    input_channels  : int or optioal\n                      Number of input channels.\n    output_channels : int or optional\n                      Number of middle channels.\n    kernel_size     : int or optional\n                      Kernel size.\n    bias            : bool or optional\n                      Set to True to let convolutional layers have bias term.\n    activation      : torch.nn or optional\n                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n    \"\"\"\n    super().__init__()\n    self.activation = activation\n    self.convolution0 = torch.nn.Sequential(\n                                            torch.nn.Conv2d(\n                                                            input_channels,\n                                                            output_channels,\n                                                            kernel_size = kernel_size,\n                                                            padding = kernel_size // 2,\n                                                            bias = bias\n                                                           ),\n                                            torch.nn.BatchNorm2d(output_channels)\n                                           )\n    self.convolution1 = torch.nn.Sequential(\n                                            torch.nn.Conv2d(\n                                                            input_channels,\n                                                            output_channels,\n                                                            kernel_size = kernel_size,\n                                                            padding = kernel_size // 2,\n                                                            bias = bias\n                                                           ),\n                                            torch.nn.BatchNorm2d(output_channels)\n                                           )\n    self.final_layer = torch.nn.Sequential(\n                                           self.activation,\n                                           torch.nn.Conv2d(\n                                                           output_channels,\n                                                           output_channels,\n                                                           kernel_size = kernel_size,\n                                                           padding = kernel_size // 2,\n                                                           bias = bias\n                                                          )\n                                          )\n
"},{"location":"odak/learn_models/#odak.learn.models.residual_attention_layer.forward","title":"forward(x0, x1)","text":"

Forward model.

Parameters:

  • x0 \u2013
             First input data.\n
  • x1 \u2013
             Second input data.\n

Returns:

  • result ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x0, x1):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x0             : torch.tensor\n                     First input data.\n\n    x1             : torch.tensor\n                     Second input data.\n\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Estimated output.      \n    \"\"\"\n    y0 = self.convolution0(x0)\n    y1 = self.convolution1(x1)\n    y2 = torch.add(y0, y1)\n    result = self.final_layer(y2) * x0\n    return result\n
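The snippet below is a minimal usage sketch for this layer. It assumes the class is importable from odak.learn.models (as the anchors above suggest) and uses illustrative tensor shapes; the elementwise gating at the end expects output_channels to match the channel count of x0.

import torch
from odak.learn.models import residual_attention_layer  # assumed import path

layer = residual_attention_layer(input_channels = 2, output_channels = 2, kernel_size = 1, bias = False)
x0 = torch.randn(1, 2, 64, 64)  # feature map to be gated
x1 = torch.randn(1, 2, 64, 64)  # second feature map contributing to the attention mask
y = layer(x0, x1)               # same shape as x0: (1, 2, 64, 64)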
"},{"location":"odak/learn_models/#odak.learn.models.residual_layer","title":"residual_layer","text":"

Bases: Module

A residual layer.

Source code in odak/learn/models/components.py
class residual_layer(torch.nn.Module):\n    \"\"\"\n    A residual layer.\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels = 2,\n                 mid_channels = 16,\n                 kernel_size = 3,\n                 bias = False,\n                 activation = torch.nn.ReLU()\n                ):\n        \"\"\"\n        A convolutional layer class.\n\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        mid_channels    : int\n                          Number of middle channels.\n        kernel_size     : int\n                          Kernel size.\n        bias            : bool \n                          Set to True to let convolutional layers have bias term.\n        activation      : torch.nn\n                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n        \"\"\"\n        super().__init__()\n        self.activation = activation\n        self.convolution = double_convolution(\n                                              input_channels,\n                                              mid_channels = mid_channels,\n                                              output_channels = input_channels,\n                                              kernel_size = kernel_size,\n                                              bias = bias,\n                                              activation = activation\n                                             )\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x             : torch.tensor\n                        Input data.\n\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Estimated output.      \n        \"\"\"\n        x0 = self.convolution(x)\n        return x + x0\n
"},{"location":"odak/learn_models/#odak.learn.models.residual_layer.__init__","title":"__init__(input_channels=2, mid_channels=16, kernel_size=3, bias=False, activation=torch.nn.ReLU())","text":"

A convolutional layer class.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • mid_channels \u2013
              Number of middle channels.\n
  • kernel_size \u2013
              Kernel size.\n
  • bias \u2013
              Set to True to let convolutional layers have bias term.\n
  • activation \u2013
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels = 2,\n             mid_channels = 16,\n             kernel_size = 3,\n             bias = False,\n             activation = torch.nn.ReLU()\n            ):\n    \"\"\"\n    A convolutional layer class.\n\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    mid_channels    : int\n                      Number of middle channels.\n    kernel_size     : int\n                      Kernel size.\n    bias            : bool \n                      Set to True to let convolutional layers have bias term.\n    activation      : torch.nn\n                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n    \"\"\"\n    super().__init__()\n    self.activation = activation\n    self.convolution = double_convolution(\n                                          input_channels,\n                                          mid_channels = mid_channels,\n                                          output_channels = input_channels,\n                                          kernel_size = kernel_size,\n                                          bias = bias,\n                                          activation = activation\n                                         )\n
"},{"location":"odak/learn_models/#odak.learn.models.residual_layer.forward","title":"forward(x)","text":"

Forward model.

Parameters:

  • x \u2013
            Input data.\n

Returns:

  • result ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x             : torch.tensor\n                    Input data.\n\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Estimated output.      \n    \"\"\"\n    x0 = self.convolution(x)\n    return x + x0\n
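A minimal usage sketch, assuming the class is exported at odak.learn.models; because the inner double convolution maps back to input_channels, the output keeps the input shape.

import torch
from odak.learn.models import residual_layer  # assumed import path

layer = residual_layer(input_channels = 2, mid_channels = 16, kernel_size = 3)
x = torch.randn(1, 2, 64, 64)
y = layer(x)                    # residual connection: y = x + double_convolution(x)
assert y.shape == x.shape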
"},{"location":"odak/learn_models/#odak.learn.models.spatial_gate","title":"spatial_gate","text":"

Bases: Module

Spatial attention module that applies a convolution layer after channel pooling. This class is heavily inspired by https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py.

Source code in odak/learn/models/components.py
class spatial_gate(torch.nn.Module):\n    \"\"\"\n    Spatial attention module that applies a convolution layer after channel pooling.\n    This class is heavily inspired by https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py.\n    \"\"\"\n    def __init__(self):\n        \"\"\"\n        Initializes the spatial gate module.\n        \"\"\"\n        super().__init__()\n        kernel_size = 7\n        self.spatial = convolution_layer(2, 1, kernel_size, bias = False, activation = torch.nn.Identity())\n\n\n    def channel_pool(self, x):\n        \"\"\"\n        Applies max and average pooling on the channels.\n\n        Parameters\n        ----------\n        x             : torch.tensor\n                        Input tensor.\n\n        Returns\n        -------\n        output        : torch.tensor\n                        Output tensor.\n        \"\"\"\n        max_pool = torch.max(x, 1)[0].unsqueeze(1)\n        avg_pool = torch.mean(x, 1).unsqueeze(1)\n        output = torch.cat((max_pool, avg_pool), dim=1)\n        return output\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward pass of the SpatialGate module.\n\n        Applies spatial attention to the input tensor.\n\n        Parameters\n        ----------\n        x            : torch.tensor\n                       Input tensor to the SpatialGate module.\n\n        Returns\n        -------\n        scaled_x     : torch.tensor\n                       Output tensor after applying spatial attention.\n        \"\"\"\n        x_compress = self.channel_pool(x)\n        x_out = self.spatial(x_compress)\n        scale = torch.sigmoid(x_out)\n        scaled_x = x * scale\n        return scaled_x\n
"},{"location":"odak/learn_models/#odak.learn.models.spatial_gate.__init__","title":"__init__()","text":"

Initializes the spatial gate module.

Source code in odak/learn/models/components.py
def __init__(self):\n    \"\"\"\n    Initializes the spatial gate module.\n    \"\"\"\n    super().__init__()\n    kernel_size = 7\n    self.spatial = convolution_layer(2, 1, kernel_size, bias = False, activation = torch.nn.Identity())\n
"},{"location":"odak/learn_models/#odak.learn.models.spatial_gate.channel_pool","title":"channel_pool(x)","text":"

Applies max and average pooling on the channels.

Parameters:

  • x \u2013
            Input tensor.\n

Returns:

  • output ( tensor ) \u2013

    Output tensor.

Source code in odak/learn/models/components.py
def channel_pool(self, x):\n    \"\"\"\n    Applies max and average pooling on the channels.\n\n    Parameters\n    ----------\n    x             : torch.tensor\n                    Input tensor.\n\n    Returns\n    -------\n    output        : torch.tensor\n                    Output tensor.\n    \"\"\"\n    max_pool = torch.max(x, 1)[0].unsqueeze(1)\n    avg_pool = torch.mean(x, 1).unsqueeze(1)\n    output = torch.cat((max_pool, avg_pool), dim=1)\n    return output\n
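To make the pooling step concrete: channel_pool reduces the channel dimension to two maps (per-pixel maximum and mean), which is what the 7 x 7 convolution of the gate consumes. A small shape check, assuming the import path below:

import torch
from odak.learn.models import spatial_gate  # assumed import path

gate = spatial_gate()
x = torch.randn(1, 8, 32, 32)
pooled = gate.channel_pool(x)   # max pooling and mean pooling over channels, concatenated
assert pooled.shape == (1, 2, 32, 32)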
"},{"location":"odak/learn_models/#odak.learn.models.spatial_gate.forward","title":"forward(x)","text":"

Forward pass of the SpatialGate module.

Applies spatial attention to the input tensor.

Parameters:

  • x \u2013
           Input tensor to the SpatialGate module.\n

Returns:

  • scaled_x ( tensor ) \u2013

    Output tensor after applying spatial attention.

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward pass of the SpatialGate module.\n\n    Applies spatial attention to the input tensor.\n\n    Parameters\n    ----------\n    x            : torch.tensor\n                   Input tensor to the SpatialGate module.\n\n    Returns\n    -------\n    scaled_x     : torch.tensor\n                   Output tensor after applying spatial attention.\n    \"\"\"\n    x_compress = self.channel_pool(x)\n    x_out = self.spatial(x_compress)\n    scale = torch.sigmoid(x_out)\n    scaled_x = x * scale\n    return scaled_x\n
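A short usage sketch of the full gate, assuming the import path below; the output keeps the input shape because the sigmoid mask only rescales each spatial location.

import torch
from odak.learn.models import spatial_gate  # assumed import path

gate = spatial_gate()
x = torch.randn(1, 8, 32, 32)
y = gate(x)                     # x scaled by a per-pixel attention mask in (0, 1)
assert y.shape == x.shape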
"},{"location":"odak/learn_models/#odak.learn.models.spatially_adaptive_convolution","title":"spatially_adaptive_convolution","text":"

Bases: Module

A spatially adaptive convolution layer.

References

C. Zheng et al. \"Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions.\" C. Xu et al. \"Squeezesegv3: Spatially-adaptive Convolution for Efficient Point-Cloud Segmentation.\" C. Zheng et al. \"Windowing Decomposition Convolutional Neural Network for Image Enhancement.\"

Source code in odak/learn/models/components.py
class spatially_adaptive_convolution(torch.nn.Module):\n    \"\"\"\n    A spatially adaptive convolution layer.\n\n    References\n    ----------\n\n    C. Zheng et al. \"Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions.\"\n    C. Xu et al. \"Squeezesegv3: Spatially-adaptive Convolution for Efficient Point-Cloud Segmentation.\"\n    C. Zheng et al. \"Windowing Decomposition Convolutional Neural Network for Image Enhancement.\"\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels = 2,\n                 output_channels = 2,\n                 kernel_size = 3,\n                 stride = 1,\n                 padding = 1,\n                 bias = False,\n                 activation = torch.nn.LeakyReLU(0.2, inplace = True)\n                ):\n        \"\"\"\n        Initializes a spatially adaptive convolution layer.\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        output_channels : int\n                          Number of output channels.\n        kernel_size     : int\n                          Size of the convolution kernel.\n        stride          : int\n                          Stride of the convolution.\n        padding         : int\n                          Padding added to both sides of the input.\n        bias            : bool\n                          If True, includes a bias term in the convolution.\n        activation      : torch.nn.Module\n                          Activation function to apply. If None, no activation is applied.\n        \"\"\"\n        super(spatially_adaptive_convolution, self).__init__()\n        self.kernel_size = kernel_size\n        self.input_channels = input_channels\n        self.output_channels = output_channels\n        self.stride = stride\n        self.padding = padding\n        self.standard_convolution = torch.nn.Conv2d(\n                                                    in_channels = input_channels,\n                                                    out_channels = self.output_channels,\n                                                    kernel_size = kernel_size,\n                                                    stride = stride,\n                                                    padding = padding,\n                                                    bias = bias\n                                                   )\n        self.weight = torch.nn.Parameter(data = self.standard_convolution.weight, requires_grad = True)\n        self.activation = activation\n\n\n    def forward(self, x, sv_kernel_feature):\n        \"\"\"\n        Forward pass for the spatially adaptive convolution layer.\n\n        Parameters\n        ----------\n        x                  : torch.tensor\n                            Input data tensor.\n                            Dimension: (1, C, H, W)\n        sv_kernel_feature   : torch.tensor\n                            Spatially varying kernel features.\n                            Dimension: (1, C_i * kernel_size * kernel_size, H, W)\n\n        Returns\n        -------\n        sa_output          : torch.tensor\n                            Estimated output tensor.\n                            Dimension: (1, output_channels, H_out, W_out)\n        \"\"\"\n        # Pad input and sv_kernel_feature if necessary\n        if sv_kernel_feature.size(-1) * self.stride != x.size(-1) or sv_kernel_feature.size(\n                -2) * self.stride != x.size(-2):\n            diffY = sv_kernel_feature.size(-2) % self.stride\n            diffX = sv_kernel_feature.size(-1) % self.stride\n            sv_kernel_feature = torch.nn.functional.pad(sv_kernel_feature, (diffX // 2, diffX - diffX // 2,\n                                                                            diffY // 2, diffY - diffY // 2))\n            diffY = x.size(-2) % self.stride\n            diffX = x.size(-1) % self.stride\n            x = torch.nn.functional.pad(x, (diffX // 2, diffX - diffX // 2,\n                                            diffY // 2, diffY - diffY // 2))\n\n        # Unfold the input tensor for matrix multiplication\n        input_feature = torch.nn.functional.unfold(\n                                                   x,\n                                                   kernel_size = (self.kernel_size, self.kernel_size),\n                                                   stride = self.stride,\n                                                   padding = self.padding\n                                                  )\n\n        # Resize sv_kernel_feature to match the input feature\n        sv_kernel = sv_kernel_feature.reshape(\n                                              1,\n                                              self.input_channels * self.kernel_size * self.kernel_size,\n                                              (x.size(-2) // self.stride) * (x.size(-1) // self.stride)\n                                             )\n\n        # Resize weight to match the input channels and kernel size\n        si_kernel = self.weight.reshape(\n                                        self.weight_output_channels,\n                                        self.input_channels * self.kernel_size * self.kernel_size\n                                       )\n\n        # Apply spatially varying kernels\n        sv_feature = input_feature * sv_kernel\n\n        # Perform matrix multiplication\n        sa_output = torch.matmul(si_kernel, sv_feature).reshape(\n                                                                1, self.weight_output_channels,\n                                                                (x.size(-2) // self.stride),\n                                                                (x.size(-1) // self.stride)\n                                                               )\n        return sa_output\n
"},{"location":"odak/learn_models/#odak.learn.models.spatially_adaptive_convolution.__init__","title":"__init__(input_channels=2, output_channels=2, kernel_size=3, stride=1, padding=1, bias=False, activation=torch.nn.LeakyReLU(0.2, inplace=True))","text":"

Initializes a spatially adaptive convolution layer.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • output_channels (int, default: 2 ) \u2013
              Number of output channels.\n
  • kernel_size \u2013
              Size of the convolution kernel.\n
  • stride \u2013
              Stride of the convolution.\n
  • padding \u2013
              Padding added to both sides of the input.\n
  • bias \u2013
              If True, includes a bias term in the convolution.\n
  • activation \u2013
              Activation function to apply. If None, no activation is applied.\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels = 2,\n             output_channels = 2,\n             kernel_size = 3,\n             stride = 1,\n             padding = 1,\n             bias = False,\n             activation = torch.nn.LeakyReLU(0.2, inplace = True)\n            ):\n    \"\"\"\n    Initializes a spatially adaptive convolution layer.\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    output_channels : int\n                      Number of output channels.\n    kernel_size     : int\n                      Size of the convolution kernel.\n    stride          : int\n                      Stride of the convolution.\n    padding         : int\n                      Padding added to both sides of the input.\n    bias            : bool\n                      If True, includes a bias term in the convolution.\n    activation      : torch.nn.Module\n                      Activation function to apply. If None, no activation is applied.\n    \"\"\"\n    super(spatially_adaptive_convolution, self).__init__()\n    self.kernel_size = kernel_size\n    self.input_channels = input_channels\n    self.output_channels = output_channels\n    self.stride = stride\n    self.padding = padding\n    self.standard_convolution = torch.nn.Conv2d(\n                                                in_channels = input_channels,\n                                                out_channels = self.output_channels,\n                                                kernel_size = kernel_size,\n                                                stride = stride,\n                                                padding = padding,\n                                                bias = bias\n                                               )\n    self.weight = torch.nn.Parameter(data = self.standard_convolution.weight, requires_grad = True)\n    self.activation = activation\n
"},{"location":"odak/learn_models/#odak.learn.models.spatially_adaptive_convolution.forward","title":"forward(x, sv_kernel_feature)","text":"

Forward pass for the spatially adaptive convolution layer.

Parameters:

  • x \u2013
                Input data tensor.\n            Dimension: (1, C, H, W)\n
  • sv_kernel_feature \u2013
                Spatially varying kernel features.\n            Dimension: (1, C_i * kernel_size * kernel_size, H, W)\n

Returns:

  • sa_output ( tensor ) \u2013

    Estimated output tensor. Dimension: (1, output_channels, H_out, W_out)

Source code in odak/learn/models/components.py
def forward(self, x, sv_kernel_feature):\n    \"\"\"\n    Forward pass for the spatially adaptive convolution layer.\n\n    Parameters\n    ----------\n    x                  : torch.tensor\n                        Input data tensor.\n                        Dimension: (1, C, H, W)\n    sv_kernel_feature   : torch.tensor\n                        Spatially varying kernel features.\n                        Dimension: (1, C_i * kernel_size * kernel_size, H, W)\n\n    Returns\n    -------\n    sa_output          : torch.tensor\n                        Estimated output tensor.\n                        Dimension: (1, output_channels, H_out, W_out)\n    \"\"\"\n    # Pad input and sv_kernel_feature if necessary\n    if sv_kernel_feature.size(-1) * self.stride != x.size(-1) or sv_kernel_feature.size(\n            -2) * self.stride != x.size(-2):\n        diffY = sv_kernel_feature.size(-2) % self.stride\n        diffX = sv_kernel_feature.size(-1) % self.stride\n        sv_kernel_feature = torch.nn.functional.pad(sv_kernel_feature, (diffX // 2, diffX - diffX // 2,\n                                                                        diffY // 2, diffY - diffY // 2))\n        diffY = x.size(-2) % self.stride\n        diffX = x.size(-1) % self.stride\n        x = torch.nn.functional.pad(x, (diffX // 2, diffX - diffX // 2,\n                                        diffY // 2, diffY - diffY // 2))\n\n    # Unfold the input tensor for matrix multiplication\n    input_feature = torch.nn.functional.unfold(\n                                               x,\n                                               kernel_size = (self.kernel_size, self.kernel_size),\n                                               stride = self.stride,\n                                               padding = self.padding\n                                              )\n\n    # Resize sv_kernel_feature to match the input feature\n    sv_kernel = sv_kernel_feature.reshape(\n                                          1,\n                                          self.input_channels * self.kernel_size * self.kernel_size,\n                                          (x.size(-2) // self.stride) * (x.size(-1) // self.stride)\n                                         )\n\n    # Resize weight to match the input channels and kernel size\n    si_kernel = self.weight.reshape(\n                                    self.weight_output_channels,\n                                    self.input_channels * self.kernel_size * self.kernel_size\n                                   )\n\n    # Apply spatially varying kernels\n    sv_feature = input_feature * sv_kernel\n\n    # Perform matrix multiplication\n    sa_output = torch.matmul(si_kernel, sv_feature).reshape(\n                                                            1, self.weight_output_channels,\n                                                            (x.size(-2) // self.stride),\n                                                            (x.size(-1) // self.stride)\n                                                           )\n    return sa_output\n
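The unfold-based bookkeeping above can be hard to follow, so here is a minimal, self-contained sketch of the underlying idea written directly with torch.nn.functional.unfold: every output pixel is filtered by its own kernel taken from sv_kernel_feature. This illustrates the mechanism only and does not instantiate the class; all names and shapes below are assumptions.

import torch

batch, channels, height, width, kernel_size = 1, 2, 8, 8, 3
x = torch.randn(batch, channels, height, width)
# One kernel of length channels * kernel_size * kernel_size for every output pixel.
sv_kernel_feature = torch.randn(batch, channels * kernel_size * kernel_size, height, width)

# Unfold extracts the local patch around every pixel as a column vector.
patches = torch.nn.functional.unfold(x, kernel_size = kernel_size, padding = kernel_size // 2)
kernels = sv_kernel_feature.reshape(batch, channels * kernel_size * kernel_size, height * width)

# Elementwise product followed by a sum over the patch dimension is a spatially varying convolution.
output = (patches * kernels).sum(dim = 1).reshape(batch, 1, height, width)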
"},{"location":"odak/learn_models/#odak.learn.models.spatially_adaptive_module","title":"spatially_adaptive_module","text":"

Bases: Module

A spatially adaptive module that combines learned spatially adaptive convolutions.

References

Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Ak\u015fit, \"Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions,\" SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24), December, 2024.

Source code in odak/learn/models/components.py
class spatially_adaptive_module(torch.nn.Module):\n    \"\"\"\n    A spatially adaptive module that combines learned spatially adaptive convolutions.\n\n    References\n    ----------\n\n    Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Ak\u015fit, \"Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions,\" SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24), December, 2024.\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels = 2,\n                 output_channels = 2,\n                 kernel_size = 3,\n                 stride = 1,\n                 padding = 1,\n                 bias = False,\n                 activation = torch.nn.LeakyReLU(0.2, inplace = True)\n                ):\n        \"\"\"\n        Initializes a spatially adaptive module.\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        output_channels : int\n                          Number of output channels.\n        kernel_size     : int\n                          Size of the convolution kernel.\n        stride          : int\n                          Stride of the convolution.\n        padding         : int\n                          Padding added to both sides of the input.\n        bias            : bool\n                          If True, includes a bias term in the convolution.\n        activation      : torch.nn\n                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n        \"\"\"\n        super(spatially_adaptive_module, self).__init__()\n        self.kernel_size = kernel_size\n        self.input_channels = input_channels\n        self.output_channels = output_channels\n        self.stride = stride\n        self.padding = padding\n        self.weight_output_channels = self.output_channels - 1\n        self.standard_convolution = torch.nn.Conv2d(\n                                                    in_channels = input_channels,\n                                                    out_channels = self.weight_output_channels,\n                                                    kernel_size = kernel_size,\n                                                    stride = stride,\n                                                    padding = padding,\n                                                    bias = bias\n                                                   )\n        self.weight = torch.nn.Parameter(data = self.standard_convolution.weight, requires_grad = True)\n        self.activation = activation\n\n\n    def forward(self, x, sv_kernel_feature):\n        \"\"\"\n        Forward pass for the spatially adaptive module.\n\n        Parameters\n        ----------\n        x                  : torch.tensor\n                            Input data tensor.\n                            Dimension: (1, C, H, W)\n        sv_kernel_feature   : torch.tensor\n                            Spatially varying kernel features.\n                            Dimension: (1, C_i * kernel_size * kernel_size, H, W)\n\n        Returns\n        -------\n        output             : torch.tensor\n                            Combined output tensor from standard and spatially adaptive convolutions.\n                            Dimension: (1, output_channels, H_out, W_out)\n        \"\"\"\n        # Pad input and sv_kernel_feature if necessary\n        if sv_kernel_feature.size(-1) * self.stride != x.size(-1) or sv_kernel_feature.size(\n                -2) * self.stride != x.size(-2):\n            diffY = sv_kernel_feature.size(-2) % self.stride\n            diffX = sv_kernel_feature.size(-1) % self.stride\n            sv_kernel_feature = torch.nn.functional.pad(sv_kernel_feature, (diffX // 2, diffX - diffX // 2,\n                                                                            diffY // 2, diffY - diffY // 2))\n            diffY = x.size(-2) % self.stride\n            diffX = x.size(-1) % self.stride\n            x = torch.nn.functional.pad(x, (diffX // 2, diffX - diffX // 2,\n                                            diffY // 2, diffY - diffY // 2))\n\n        # Unfold the input tensor for matrix multiplication\n        input_feature = torch.nn.functional.unfold(\n                                                   x,\n                                                   kernel_size = (self.kernel_size, self.kernel_size),\n                                                   stride = self.stride,\n                                                   padding = self.padding\n                                                  )\n\n        # Resize sv_kernel_feature to match the input feature\n        sv_kernel = sv_kernel_feature.reshape(\n                                              1,\n                                              self.input_channels * self.kernel_size * self.kernel_size,\n                                              (x.size(-2) // self.stride) * (x.size(-1) // self.stride)\n                                             )\n\n        # Apply sv_kernel to the input_feature\n        sv_feature = input_feature * sv_kernel\n\n        # Original spatially varying convolution output\n        sv_output = torch.sum(sv_feature, dim = 1).reshape(\n                                                           1,\n                                                            1,\n                                                            (x.size(-2) // self.stride),\n                                                            (x.size(-1) // self.stride)\n                                                           )\n\n        # Reshape weight for spatially adaptive convolution\n        si_kernel = self.weight.reshape(\n                                        self.weight_output_channels,\n                                        self.input_channels * self.kernel_size * self.kernel_size\n                                       )\n\n        # Apply si_kernel on sv convolution output\n        sa_output = torch.matmul(si_kernel, sv_feature).reshape(\n                                                                1, self.weight_output_channels,\n                                                                (x.size(-2) // self.stride),\n                                                                (x.size(-1) // self.stride)\n                                                               )\n\n        # Combine the outputs and apply activation function\n        output = self.activation(torch.cat((sv_output, sa_output), dim = 1))\n        return output\n
"},{"location":"odak/learn_models/#odak.learn.models.spatially_adaptive_module.__init__","title":"__init__(input_channels=2, output_channels=2, kernel_size=3, stride=1, padding=1, bias=False, activation=torch.nn.LeakyReLU(0.2, inplace=True))","text":"

Initializes a spatially adaptive module.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • output_channels (int, default: 2 ) \u2013
              Number of output channels.\n
  • kernel_size \u2013
              Size of the convolution kernel.\n
  • stride \u2013
              Stride of the convolution.\n
  • padding \u2013
              Padding added to both sides of the input.\n
  • bias \u2013
              If True, includes a bias term in the convolution.\n
  • activation \u2013
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels = 2,\n             output_channels = 2,\n             kernel_size = 3,\n             stride = 1,\n             padding = 1,\n             bias = False,\n             activation = torch.nn.LeakyReLU(0.2, inplace = True)\n            ):\n    \"\"\"\n    Initializes a spatially adaptive module.\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    output_channels : int\n                      Number of output channels.\n    kernel_size     : int\n                      Size of the convolution kernel.\n    stride          : int\n                      Stride of the convolution.\n    padding         : int\n                      Padding added to both sides of the input.\n    bias            : bool\n                      If True, includes a bias term in the convolution.\n    activation      : torch.nn\n                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n    \"\"\"\n    super(spatially_adaptive_module, self).__init__()\n    self.kernel_size = kernel_size\n    self.input_channels = input_channels\n    self.output_channels = output_channels\n    self.stride = stride\n    self.padding = padding\n    self.weight_output_channels = self.output_channels - 1\n    self.standard_convolution = torch.nn.Conv2d(\n                                                in_channels = input_channels,\n                                                out_channels = self.weight_output_channels,\n                                                kernel_size = kernel_size,\n                                                stride = stride,\n                                                padding = padding,\n                                                bias = bias\n                                               )\n    self.weight = torch.nn.Parameter(data = self.standard_convolution.weight, requires_grad = True)\n    self.activation = activation\n
"},{"location":"odak/learn_models/#odak.learn.models.spatially_adaptive_module.forward","title":"forward(x, sv_kernel_feature)","text":"

Forward pass for the spatially adaptive module.

Parameters:

  • x \u2013
                Input data tensor.\n            Dimension: (1, C, H, W)\n
  • sv_kernel_feature \u2013
                Spatially varying kernel features.\n            Dimension: (1, C_i * kernel_size * kernel_size, H, W)\n

Returns:

  • output ( tensor ) \u2013

    Combined output tensor from standard and spatially adaptive convolutions. Dimension: (1, output_channels, H_out, W_out)

Source code in odak/learn/models/components.py
def forward(self, x, sv_kernel_feature):\n    \"\"\"\n    Forward pass for the spatially adaptive module.\n\n    Parameters\n    ----------\n    x                  : torch.tensor\n                        Input data tensor.\n                        Dimension: (1, C, H, W)\n    sv_kernel_feature   : torch.tensor\n                        Spatially varying kernel features.\n                        Dimension: (1, C_i * kernel_size * kernel_size, H, W)\n\n    Returns\n    -------\n    output             : torch.tensor\n                        Combined output tensor from standard and spatially adaptive convolutions.\n                        Dimension: (1, output_channels, H_out, W_out)\n    \"\"\"\n    # Pad input and sv_kernel_feature if necessary\n    if sv_kernel_feature.size(-1) * self.stride != x.size(-1) or sv_kernel_feature.size(\n            -2) * self.stride != x.size(-2):\n        diffY = sv_kernel_feature.size(-2) % self.stride\n        diffX = sv_kernel_feature.size(-1) % self.stride\n        sv_kernel_feature = torch.nn.functional.pad(sv_kernel_feature, (diffX // 2, diffX - diffX // 2,\n                                                                        diffY // 2, diffY - diffY // 2))\n        diffY = x.size(-2) % self.stride\n        diffX = x.size(-1) % self.stride\n        x = torch.nn.functional.pad(x, (diffX // 2, diffX - diffX // 2,\n                                        diffY // 2, diffY - diffY // 2))\n\n    # Unfold the input tensor for matrix multiplication\n    input_feature = torch.nn.functional.unfold(\n                                               x,\n                                               kernel_size = (self.kernel_size, self.kernel_size),\n                                               stride = self.stride,\n                                               padding = self.padding\n                                              )\n\n    # Resize sv_kernel_feature to match the input feature\n    sv_kernel = sv_kernel_feature.reshape(\n                                          1,\n                                          self.input_channels * self.kernel_size * self.kernel_size,\n                                          (x.size(-2) // self.stride) * (x.size(-1) // self.stride)\n                                         )\n\n    # Apply sv_kernel to the input_feature\n    sv_feature = input_feature * sv_kernel\n\n    # Original spatially varying convolution output\n    sv_output = torch.sum(sv_feature, dim = 1).reshape(\n                                                       1,\n                                                        1,\n                                                        (x.size(-2) // self.stride),\n                                                        (x.size(-1) // self.stride)\n                                                       )\n\n    # Reshape weight for spatially adaptive convolution\n    si_kernel = self.weight.reshape(\n                                    self.weight_output_channels,\n                                    self.input_channels * self.kernel_size * self.kernel_size\n                                   )\n\n    # Apply si_kernel on sv convolution output\n    sa_output = torch.matmul(si_kernel, sv_feature).reshape(\n                                                            1, self.weight_output_channels,\n                                                            (x.size(-2) // self.stride),\n                                                            (x.size(-1) // self.stride)\n                                                           )\n\n    # Combine the outputs and apply activation function\n    output = self.activation(torch.cat((sv_output, sa_output), dim = 1))\n    return output\n
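A usage sketch with shapes that satisfy the checks above, assuming the module is importable from odak.learn.models: with kernel_size = 3 and stride = 1, sv_kernel_feature needs input_channels * 9 channels at the input resolution.

import torch
from odak.learn.models import spatially_adaptive_module  # assumed import path

sam = spatially_adaptive_module(input_channels = 2, output_channels = 2, kernel_size = 3, stride = 1, padding = 1)
x = torch.randn(1, 2, 32, 32)
sv_kernel_feature = torch.randn(1, 2 * 3 * 3, 32, 32)  # per-pixel kernels
y = sam(x, sv_kernel_feature)                           # (1, 2, 32, 32): one sv channel plus output_channels - 1 sa channels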
"},{"location":"odak/learn_models/#odak.learn.models.spatially_adaptive_unet","title":"spatially_adaptive_unet","text":"

Bases: Module

Spatially varying U-Net model based on spatially adaptive convolution.

References

Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Ak\u015fit, \"Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions,\" SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24), December, 2024.

Source code in odak/learn/models/models.py
class spatially_adaptive_unet(torch.nn.Module):\n    \"\"\"\n    Spatially varying U-Net model based on spatially adaptive convolution.\n\n    References\n    ----------\n\n    Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Ak\u015fit, \"Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions,\" SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24), December, 2024.\n    \"\"\"\n    def __init__(\n                 self,\n                 depth=3,\n                 dimensions=8,\n                 input_channels=6,\n                 out_channels=6,\n                 kernel_size=3,\n                 bias=True,\n                 normalization=False,\n                 activation=torch.nn.LeakyReLU(0.2, inplace=True)\n                ):\n        \"\"\"\n        U-Net model.\n\n        Parameters\n        ----------\n        depth          : int\n                         Number of upsampling and downsampling layers.\n        dimensions     : int\n                         Number of dimensions.\n        input_channels : int\n                         Number of input channels.\n        out_channels   : int\n                         Number of output channels.\n        bias           : bool\n                         Set to True to let convolutional layers learn a bias term.\n        normalization  : bool\n                         If True, adds a Batch Normalization layer after the convolutional layer.\n        activation     : torch.nn\n                         Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).\n        \"\"\"\n        super().__init__()\n        self.depth = depth\n        self.out_channels = out_channels\n        self.inc = convolution_layer(\n                                     input_channels=input_channels,\n                                     output_channels=dimensions,\n                                     kernel_size=kernel_size,\n                                     bias=bias,\n                                     normalization=normalization,\n                                     activation=activation\n                                    )\n\n        self.encoder = torch.nn.ModuleList()\n        for i in range(self.depth + 1):  # Downsampling layers\n            down_in_channels = dimensions * (2 ** i)\n            down_out_channels = 2 * down_in_channels\n            pooling_layer = torch.nn.AvgPool2d(2)\n            double_convolution_layer = double_convolution(\n                                                          input_channels=down_in_channels,\n                                                          mid_channels=down_in_channels,\n                                                          output_channels=down_in_channels,\n                                                          kernel_size=kernel_size,\n                                                          bias=bias,\n                                                          normalization=normalization,\n                                                          activation=activation\n                                                         )\n            sam = spatially_adaptive_module(\n                                            input_channels=down_in_channels,\n                                            output_channels=down_out_channels,\n                                            kernel_size=kernel_size,\n                                            bias=bias,\n                                            activation=activation\n                                           )\n            self.encoder.append(torch.nn.ModuleList([pooling_layer, double_convolution_layer, sam]))\n        self.global_feature_module = torch.nn.ModuleList()\n        double_convolution_layer = double_convolution(\n                                                      input_channels=dimensions * (2 ** (depth + 1)),\n                                                      mid_channels=dimensions * (2 ** (depth + 1)),\n                                                      output_channels=dimensions * (2 ** (depth + 1)),\n                                                      kernel_size=kernel_size,\n                                                      bias=bias,\n                                                      normalization=normalization,\n                                                      activation=activation\n                                                     )\n        global_feature_layer = global_feature_module(\n                                                     input_channels=dimensions * (2 ** (depth + 1)),\n                                                     mid_channels=dimensions * (2 ** (depth + 1)),\n                                                     output_channels=dimensions * (2 ** (depth + 1)),\n                                                     kernel_size=kernel_size,\n                                                     bias=bias,\n                                                     activation=torch.nn.LeakyReLU(0.2, inplace=True)\n                                                    )\n        self.global_feature_module.append(torch.nn.ModuleList([double_convolution_layer, global_feature_layer]))\n        self.decoder = torch.nn.ModuleList()\n        for i in range(depth, -1, -1):\n            up_in_channels = dimensions * (2 ** (i + 1))\n            up_mid_channels = up_in_channels // 2\n            if i == 0:\n                up_out_channels = self.out_channels\n                upsample_layer = upsample_convtranspose2d_layer(\n                                                                input_channels=up_in_channels,\n                                                                output_channels=up_mid_channels,\n                                                                kernel_size=2,\n                                                                stride=2,\n                                                                bias=bias,\n                                                               )\n                conv_layer = torch.nn.Sequential(\n                    convolution_layer(\n                                      input_channels=up_mid_channels,\n                                      output_channels=up_mid_channels,\n                                      kernel_size=kernel_size,\n                                      bias=bias,\n                                      normalization=normalization,\n                                      activation=activation,\n                                     ),\n                    convolution_layer(\n                                      input_channels=up_mid_channels,\n                                      output_channels=up_out_channels,\n                                      kernel_size=1,\n                                      bias=bias,\n                                      normalization=normalization,\n                                      activation=None,\n                                     )\n                )\n                self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))\n            else:\n                up_out_channels = up_in_channels // 2\n                upsample_layer = upsample_convtranspose2d_layer(\n                                                                input_channels=up_in_channels,\n                                                                output_channels=up_mid_channels,\n                                                                kernel_size=2,\n                                                                stride=2,\n                                                                bias=bias,\n                                                               )\n                conv_layer = double_convolution(\n                                                input_channels=up_mid_channels,\n                                                mid_channels=up_mid_channels,\n                                                output_channels=up_out_channels,\n                                                kernel_size=kernel_size,\n                                                bias=bias,\n                                                normalization=normalization,\n                                                activation=activation,\n                                               )\n                self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))\n\n\n    def forward(self, sv_kernel, field):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        sv_kernel : list of torch.tensor\n                    Learned spatially varying kernels.\n                    Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i),\n                    where C_i, H_i, and W_i represent the channel, height, and width\n                    of each feature at a certain scale.\n\n        field     : torch.tensor\n                    Input field data.\n                    Dimension: (1, 6, H, W)\n\n        Returns\n        -------\n        target_field : torch.tensor\n                       Estimated output.\n                       Dimension: (1, 6, H, W)\n        \"\"\"\n        x = self.inc(field)\n        downsampling_outputs = [x]\n        for i, down_layer in enumerate(self.encoder):\n            x_down = down_layer[0](downsampling_outputs[-1])\n            downsampling_outputs.append(x_down)\n            sam_output = down_layer[2](x_down + down_layer[1](x_down), sv_kernel[self.depth - i])\n            downsampling_outputs.append(sam_output)\n        global_feature = self.global_feature_module[0][0](downsampling_outputs[-1])\n        global_feature = self.global_feature_module[0][1](downsampling_outputs[-1], global_feature)\n        downsampling_outputs.append(global_feature)\n        x_up = downsampling_outputs[-1]\n        for i, up_layer in enumerate(self.decoder):\n            x_up = up_layer[0](x_up, downsampling_outputs[2 * (self.depth - i)])\n            x_up = up_layer[1](x_up)\n        result = x_up\n        return result\n
"},{"location":"odak/learn_models/#odak.learn.models.spatially_adaptive_unet.__init__","title":"__init__(depth=3, dimensions=8, input_channels=6, out_channels=6, kernel_size=3, bias=True, normalization=False, activation=torch.nn.LeakyReLU(0.2, inplace=True))","text":"

U-Net model.

Parameters:

  • depth \u2013
             Number of upsampling and downsampling layers.\n
  • dimensions \u2013
             Number of dimensions.\n
  • input_channels (int, default: 6 ) \u2013
             Number of input channels.\n
  • out_channels \u2013
             Number of output channels.\n
  • bias \u2013
             Set to True to let convolutional layers learn a bias term.\n
  • normalization \u2013
             If True, adds a Batch Normalization layer after the convolutional layer.\n
  • activation \u2013
             Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).\n
Source code in odak/learn/models/models.py
def __init__(\n             self,\n             depth=3,\n             dimensions=8,\n             input_channels=6,\n             out_channels=6,\n             kernel_size=3,\n             bias=True,\n             normalization=False,\n             activation=torch.nn.LeakyReLU(0.2, inplace=True)\n            ):\n    \"\"\"\n    U-Net model.\n\n    Parameters\n    ----------\n    depth          : int\n                     Number of upsampling and downsampling layers.\n    dimensions     : int\n                     Number of dimensions.\n    input_channels : int\n                     Number of input channels.\n    out_channels   : int\n                     Number of output channels.\n    bias           : bool\n                     Set to True to let convolutional layers learn a bias term.\n    normalization  : bool\n                     If True, adds a Batch Normalization layer after the convolutional layer.\n    activation     : torch.nn\n                     Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).\n    \"\"\"\n    super().__init__()\n    self.depth = depth\n    self.out_channels = out_channels\n    self.inc = convolution_layer(\n                                 input_channels=input_channels,\n                                 output_channels=dimensions,\n                                 kernel_size=kernel_size,\n                                 bias=bias,\n                                 normalization=normalization,\n                                 activation=activation\n                                )\n\n    self.encoder = torch.nn.ModuleList()\n    for i in range(self.depth + 1):  # Downsampling layers\n        down_in_channels = dimensions * (2 ** i)\n        down_out_channels = 2 * down_in_channels\n        pooling_layer = torch.nn.AvgPool2d(2)\n        double_convolution_layer = double_convolution(\n                                                      input_channels=down_in_channels,\n                                                      mid_channels=down_in_channels,\n                                                      output_channels=down_in_channels,\n                                                      kernel_size=kernel_size,\n                                                      bias=bias,\n                                                      normalization=normalization,\n                                                      activation=activation\n                                                     )\n        sam = spatially_adaptive_module(\n                                        input_channels=down_in_channels,\n                                        output_channels=down_out_channels,\n                                        kernel_size=kernel_size,\n                                        bias=bias,\n                                        activation=activation\n                                       )\n        self.encoder.append(torch.nn.ModuleList([pooling_layer, double_convolution_layer, sam]))\n    self.global_feature_module = torch.nn.ModuleList()\n    double_convolution_layer = double_convolution(\n                                                  input_channels=dimensions * (2 ** (depth + 1)),\n                                                  mid_channels=dimensions * (2 ** (depth + 1)),\n                                                  output_channels=dimensions * (2 ** (depth + 1)),\n                                                  kernel_size=kernel_size,\n                                                  bias=bias,\n                                                  normalization=normalization,\n                                                  activation=activation\n                                                 )\n    global_feature_layer = global_feature_module(\n                                                 input_channels=dimensions * (2 ** (depth + 1)),\n                                                 mid_channels=dimensions * (2 ** (depth + 1)),\n                                                 output_channels=dimensions * (2 ** (depth + 1)),\n                                                 kernel_size=kernel_size,\n                                                 bias=bias,\n                                                 activation=torch.nn.LeakyReLU(0.2, inplace=True)\n                                                )\n    self.global_feature_module.append(torch.nn.ModuleList([double_convolution_layer, global_feature_layer]))\n    self.decoder = torch.nn.ModuleList()\n    for i in range(depth, -1, -1):\n        up_in_channels = dimensions * (2 ** (i + 1))\n        up_mid_channels = up_in_channels // 2\n        if i == 0:\n            up_out_channels = self.out_channels\n            upsample_layer = upsample_convtranspose2d_layer(\n                                                            input_channels=up_in_channels,\n                                                            output_channels=up_mid_channels,\n                                                            kernel_size=2,\n                                                            stride=2,\n                                                            bias=bias,\n                                                           )\n            conv_layer = torch.nn.Sequential(\n                convolution_layer(\n                                  input_channels=up_mid_channels,\n                                  output_channels=up_mid_channels,\n                                  kernel_size=kernel_size,\n                                  bias=bias,\n                                  normalization=normalization,\n                                  activation=activation,\n                                 ),\n                convolution_layer(\n                                  input_channels=up_mid_channels,\n                                  output_channels=up_out_channels,\n                                  kernel_size=1,\n                                  bias=bias,\n                                  normalization=normalization,\n                                  activation=None,\n                                 )\n            )\n            self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))\n        else:\n            up_out_channels = up_in_channels // 2\n            upsample_layer = upsample_convtranspose2d_layer(\n                                                            input_channels=up_in_channels,\n                                                            output_channels=up_mid_channels,\n                                                            kernel_size=2,\n                                                            stride=2,\n                                                            bias=bias,\n                                                           )\n            conv_layer = double_convolution(\n                                            input_channels=up_mid_channels,\n                                            mid_channels=up_mid_channels,\n                                            output_channels=up_out_channels,\n                                            kernel_size=kernel_size,\n                                            bias=bias,\n                                            normalization=normalization,\n                                            activation=activation,\n                                           )\n            self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))\n
"},{"location":"odak/learn_models/#odak.learn.models.spatially_adaptive_unet.forward","title":"forward(sv_kernel, field)","text":"

Forward model.

Parameters:

  • sv_kernel (list of torch.tensor) \u2013
        Learned spatially varying kernels.\n    Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i),\n    where C_i, H_i, and W_i represent the channel, height, and width\n    of each feature at a certain scale.\n
  • field \u2013
        Input field data.\n    Dimension: (1, 6, H, W)\n

Returns:

  • target_field ( tensor ) \u2013

    Estimated output. Dimension: (1, 6, H, W)

Source code in odak/learn/models/models.py
def forward(self, sv_kernel, field):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    sv_kernel : list of torch.tensor\n                Learned spatially varying kernels.\n                Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i),\n                where C_i, H_i, and W_i represent the channel, height, and width\n                of each feature at a certain scale.\n\n    field     : torch.tensor\n                Input field data.\n                Dimension: (1, 6, H, W)\n\n    Returns\n    -------\n    target_field : torch.tensor\n                   Estimated output.\n                   Dimension: (1, 6, H, W)\n    \"\"\"\n    x = self.inc(field)\n    downsampling_outputs = [x]\n    for i, down_layer in enumerate(self.encoder):\n        x_down = down_layer[0](downsampling_outputs[-1])\n        downsampling_outputs.append(x_down)\n        sam_output = down_layer[2](x_down + down_layer[1](x_down), sv_kernel[self.depth - i])\n        downsampling_outputs.append(sam_output)\n    global_feature = self.global_feature_module[0][0](downsampling_outputs[-1])\n    global_feature = self.global_feature_module[0][1](downsampling_outputs[-1], global_feature)\n    downsampling_outputs.append(global_feature)\n    x_up = downsampling_outputs[-1]\n    for i, up_layer in enumerate(self.decoder):\n        x_up = up_layer[0](x_up, downsampling_outputs[2 * (self.depth - i)])\n        x_up = up_layer[1](x_up)\n    result = x_up\n    return result\n
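A minimal usage sketch for this forward pass (not part of the generated source). It assumes odak is installed, that spatially_adaptive_unet and spatially_varying_kernel_generation_model are importable from odak.learn.models, that their default constructor arguments are mutually compatible, and that 256 x 256 is a placeholder resolution.

import torch
from odak.learn.models import spatially_adaptive_unet, spatially_varying_kernel_generation_model  # assumed import path

focal_surface = torch.randn(1, 1, 256, 256)                 # (1, 1, H, W)
field = torch.randn(1, 6, 256, 256)                         # (1, 6, H, W)
kernel_model = spatially_varying_kernel_generation_model()  # assumed defaults compatible with the network below
network = spatially_adaptive_unet()                         # assumed defaults
with torch.no_grad():
    sv_kernel = kernel_model(focal_surface, field)          # list of per-scale kernel tensors
    target_field = network(sv_kernel, field)                # expected shape: (1, 6, H, W)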
"},{"location":"odak/learn_models/#odak.learn.models.spatially_varying_kernel_generation_model","title":"spatially_varying_kernel_generation_model","text":"

Bases: Module

A spatially varying kernel generation model, revised from RSGUnet: https://github.com/MTLab/rsgunet_image_enhance.

Refer to: J. Huang, P. Zhu, M. Geng et al. Range Scaling Global U-Net for Perceptual Image Enhancement on Mobile Devices.

Source code in odak/learn/models/models.py
class spatially_varying_kernel_generation_model(torch.nn.Module):\n    \"\"\"\n    Spatially_varying_kernel_generation_model revised from RSGUnet:\n    https://github.com/MTLab/rsgunet_image_enhance.\n\n    Refer to:\n    J. Huang, P. Zhu, M. Geng et al. Range Scaling Global U-Net for Perceptual Image Enhancement on Mobile Devices.\n    \"\"\"\n\n    def __init__(\n                 self,\n                 depth = 3,\n                 dimensions = 8,\n                 input_channels = 7,\n                 kernel_size = 3,\n                 bias = True,\n                 normalization = False,\n                 activation = torch.nn.LeakyReLU(0.2, inplace = True)\n                ):\n        \"\"\"\n        U-Net model.\n\n        Parameters\n        ----------\n        depth          : int\n                         Number of upsampling and downsampling layers.\n        dimensions     : int\n                         Number of dimensions.\n        input_channels : int\n                         Number of input channels.\n        bias           : bool\n                         Set to True to let convolutional layers learn a bias term.\n        normalization  : bool\n                         If True, adds a Batch Normalization layer after the convolutional layer.\n        activation     : torch.nn\n                         Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).\n        \"\"\"\n        super().__init__()\n        self.depth = depth\n        self.inc = convolution_layer(\n                                     input_channels = input_channels,\n                                     output_channels = dimensions,\n                                     kernel_size = kernel_size,\n                                     bias = bias,\n                                     normalization = normalization,\n                                     activation = activation\n                                    )\n        self.encoder = torch.nn.ModuleList()\n        for i in range(depth + 1):  # downsampling layers\n            if i == 0:\n                in_channels = dimensions * (2 ** i)\n                out_channels = dimensions * (2 ** i)\n            elif i == depth:\n                in_channels = dimensions * (2 ** (i - 1))\n                out_channels = dimensions * (2 ** (i - 1))\n            else:\n                in_channels = dimensions * (2 ** (i - 1))\n                out_channels = 2 * in_channels\n            pooling_layer = torch.nn.AvgPool2d(2)\n            double_convolution_layer = double_convolution(\n                                                          input_channels = in_channels,\n                                                          mid_channels = in_channels,\n                                                          output_channels = out_channels,\n                                                          kernel_size = kernel_size,\n                                                          bias = bias,\n                                                          normalization = normalization,\n                                                          activation = activation\n                                                         )\n            self.encoder.append(pooling_layer)\n            self.encoder.append(double_convolution_layer)\n        self.spatially_varying_feature = torch.nn.ModuleList()  # for kernel generation\n        for i in range(depth, -1, -1):\n            if i == 1:\n                svf_in_channels = dimensions + 2 ** (self.depth + 
i) + 1\n            else:\n                svf_in_channels = 2 ** (self.depth + i) + 1\n            svf_out_channels = (2 ** (self.depth + i)) * (kernel_size * kernel_size)\n            svf_mid_channels = dimensions * (2 ** (self.depth - 1))\n            spatially_varying_kernel_generation = torch.nn.ModuleList()\n            for j in range(i, -1, -1):\n                pooling_layer = torch.nn.AvgPool2d(2 ** (j + 1))\n                spatially_varying_kernel_generation.append(pooling_layer)\n            kernel_generation_block = torch.nn.Sequential(\n                torch.nn.Conv2d(\n                                in_channels = svf_in_channels,\n                                out_channels = svf_mid_channels,\n                                kernel_size = kernel_size,\n                                padding = kernel_size // 2,\n                                bias = bias\n                               ),\n                activation,\n                torch.nn.Conv2d(\n                                in_channels = svf_mid_channels,\n                                out_channels = svf_mid_channels,\n                                kernel_size = kernel_size,\n                                padding = kernel_size // 2,\n                                bias = bias\n                               ),\n                activation,\n                torch.nn.Conv2d(\n                                in_channels = svf_mid_channels,\n                                out_channels = svf_out_channels,\n                                kernel_size = kernel_size,\n                                padding = kernel_size // 2,\n                                bias = bias\n                               ),\n            )\n            spatially_varying_kernel_generation.append(kernel_generation_block)\n            self.spatially_varying_feature.append(spatially_varying_kernel_generation)\n        self.decoder = torch.nn.ModuleList()\n        global_feature_layer = global_feature_module(  # global feature layer\n                                                     input_channels = dimensions * (2 ** (depth - 1)),\n                                                     mid_channels = dimensions * (2 ** (depth - 1)),\n                                                     output_channels = dimensions * (2 ** (depth - 1)),\n                                                     kernel_size = kernel_size,\n                                                     bias = bias,\n                                                     activation = torch.nn.LeakyReLU(0.2, inplace = True)\n                                                    )\n        self.decoder.append(global_feature_layer)\n        for i in range(depth, 0, -1):\n            if i == 2:\n                up_in_channels = (dimensions // 2) * (2 ** i)\n                up_out_channels = up_in_channels\n                up_mid_channels = up_in_channels\n            elif i == 1:\n                up_in_channels = dimensions * 2\n                up_out_channels = dimensions\n                up_mid_channels = up_out_channels\n            else:\n                up_in_channels = (dimensions // 2) * (2 ** i)\n                up_out_channels = up_in_channels // 2\n                up_mid_channels = up_in_channels\n            upsample_layer = upsample_convtranspose2d_layer(\n                                                            input_channels = up_in_channels,\n                                                            output_channels = up_mid_channels,\n                         
                                   kernel_size = 2,\n                                                            stride = 2,\n                                                            bias = bias,\n                                                           )\n            conv_layer = double_convolution(\n                                            input_channels = up_mid_channels,\n                                            output_channels = up_out_channels,\n                                            kernel_size = kernel_size,\n                                            bias = bias,\n                                            normalization = normalization,\n                                            activation = activation,\n                                           )\n            self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))\n\n\n    def forward(self, focal_surface, field):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        focal_surface : torch.tensor\n                        Input focal surface data.\n                        Dimension: (1, 1, H, W)\n\n        field         : torch.tensor\n                        Input field data.\n                        Dimension: (1, 6, H, W)\n\n        Returns\n        -------\n        sv_kernel : list of torch.tensor\n                    Learned spatially varying kernels.\n                    Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i),\n                    where C_i, H_i, and W_i represent the channel, height, and width\n                    of each feature at a certain scale.\n        \"\"\"\n        x = self.inc(torch.cat((focal_surface, field), dim = 1))\n        downsampling_outputs = [focal_surface]\n        downsampling_outputs.append(x)\n        for i, down_layer in enumerate(self.encoder):\n            x_down = down_layer(downsampling_outputs[-1])\n            downsampling_outputs.append(x_down)\n        sv_kernels = []\n        for i, (up_layer, svf_layer) in enumerate(zip(self.decoder, self.spatially_varying_feature)):\n            if i == 0:\n                global_feature = up_layer(downsampling_outputs[-2], downsampling_outputs[-1])\n                downsampling_outputs[-1] = global_feature\n                sv_feature = [global_feature, downsampling_outputs[0]]\n                for j in range(self.depth - i + 1):\n                    sv_feature[1] = svf_layer[self.depth - i](sv_feature[1])\n                    if j > 0:\n                        sv_feature.append(svf_layer[j](downsampling_outputs[2 * j]))\n                sv_feature = [sv_feature[0], sv_feature[1], sv_feature[4], sv_feature[2],\n                              sv_feature[3]]\n                sv_kernel = svf_layer[-1](torch.cat(sv_feature, dim = 1))\n                sv_kernels.append(sv_kernel)\n            else:\n                x_up = up_layer[0](downsampling_outputs[-1],\n                                   downsampling_outputs[2 * (self.depth + 1 - i) + 1])\n                x_up = up_layer[1](x_up)\n                downsampling_outputs[-1] = x_up\n                sv_feature = [x_up, downsampling_outputs[0]]\n                for j in range(self.depth - i + 1):\n                    sv_feature[1] = svf_layer[self.depth - i](sv_feature[1])\n                    if j > 0:\n                        sv_feature.append(svf_layer[j](downsampling_outputs[2 * j]))\n                if i == 1:\n                    sv_feature = [sv_feature[0], sv_feature[1], sv_feature[3], 
sv_feature[2]]\n                sv_kernel = svf_layer[-1](torch.cat(sv_feature, dim = 1))\n                sv_kernels.append(sv_kernel)\n        return sv_kernels\n
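A construction sketch (illustrative only; the values shown mirror the defaults in the signature above). The seven input channels correspond to the one-channel focal surface concatenated with the six-channel field inside forward.

import torch
from odak.learn.models import spatially_varying_kernel_generation_model  # assumed import path

kernel_model = spatially_varying_kernel_generation_model(
                                                         depth = 3,
                                                         dimensions = 8,
                                                         input_channels = 7,  # 1 (focal surface) + 6 (field), concatenated in forward
                                                         kernel_size = 3,
                                                         bias = True,
                                                         normalization = False,
                                                         activation = torch.nn.LeakyReLU(0.2, inplace = True)
                                                        )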
"},{"location":"odak/learn_models/#odak.learn.models.spatially_varying_kernel_generation_model.__init__","title":"__init__(depth=3, dimensions=8, input_channels=7, kernel_size=3, bias=True, normalization=False, activation=torch.nn.LeakyReLU(0.2, inplace=True))","text":"

U-Net model.

Parameters:

  • depth \u2013
             Number of upsampling and downsampling layers.\n
  • dimensions \u2013
             Number of base feature channels used by the first convolution layer.\n
  • input_channels (int, default: 7 ) \u2013
             Number of input channels.\n
  • bias \u2013
             Set to True to let convolutional layers learn a bias term.\n
  • normalization \u2013
             If True, adds a Batch Normalization layer after the convolutional layer.\n
  • activation \u2013
             Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).\n
Source code in odak/learn/models/models.py
def __init__(\n             self,\n             depth = 3,\n             dimensions = 8,\n             input_channels = 7,\n             kernel_size = 3,\n             bias = True,\n             normalization = False,\n             activation = torch.nn.LeakyReLU(0.2, inplace = True)\n            ):\n    \"\"\"\n    U-Net model.\n\n    Parameters\n    ----------\n    depth          : int\n                     Number of upsampling and downsampling layers.\n    dimensions     : int\n                     Number of dimensions.\n    input_channels : int\n                     Number of input channels.\n    bias           : bool\n                     Set to True to let convolutional layers learn a bias term.\n    normalization  : bool\n                     If True, adds a Batch Normalization layer after the convolutional layer.\n    activation     : torch.nn\n                     Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).\n    \"\"\"\n    super().__init__()\n    self.depth = depth\n    self.inc = convolution_layer(\n                                 input_channels = input_channels,\n                                 output_channels = dimensions,\n                                 kernel_size = kernel_size,\n                                 bias = bias,\n                                 normalization = normalization,\n                                 activation = activation\n                                )\n    self.encoder = torch.nn.ModuleList()\n    for i in range(depth + 1):  # downsampling layers\n        if i == 0:\n            in_channels = dimensions * (2 ** i)\n            out_channels = dimensions * (2 ** i)\n        elif i == depth:\n            in_channels = dimensions * (2 ** (i - 1))\n            out_channels = dimensions * (2 ** (i - 1))\n        else:\n            in_channels = dimensions * (2 ** (i - 1))\n            out_channels = 2 * in_channels\n        pooling_layer = torch.nn.AvgPool2d(2)\n        double_convolution_layer = double_convolution(\n                                                      input_channels = in_channels,\n                                                      mid_channels = in_channels,\n                                                      output_channels = out_channels,\n                                                      kernel_size = kernel_size,\n                                                      bias = bias,\n                                                      normalization = normalization,\n                                                      activation = activation\n                                                     )\n        self.encoder.append(pooling_layer)\n        self.encoder.append(double_convolution_layer)\n    self.spatially_varying_feature = torch.nn.ModuleList()  # for kernel generation\n    for i in range(depth, -1, -1):\n        if i == 1:\n            svf_in_channels = dimensions + 2 ** (self.depth + i) + 1\n        else:\n            svf_in_channels = 2 ** (self.depth + i) + 1\n        svf_out_channels = (2 ** (self.depth + i)) * (kernel_size * kernel_size)\n        svf_mid_channels = dimensions * (2 ** (self.depth - 1))\n        spatially_varying_kernel_generation = torch.nn.ModuleList()\n        for j in range(i, -1, -1):\n            pooling_layer = torch.nn.AvgPool2d(2 ** (j + 1))\n            spatially_varying_kernel_generation.append(pooling_layer)\n        kernel_generation_block = torch.nn.Sequential(\n            torch.nn.Conv2d(\n                            in_channels = 
svf_in_channels,\n                            out_channels = svf_mid_channels,\n                            kernel_size = kernel_size,\n                            padding = kernel_size // 2,\n                            bias = bias\n                           ),\n            activation,\n            torch.nn.Conv2d(\n                            in_channels = svf_mid_channels,\n                            out_channels = svf_mid_channels,\n                            kernel_size = kernel_size,\n                            padding = kernel_size // 2,\n                            bias = bias\n                           ),\n            activation,\n            torch.nn.Conv2d(\n                            in_channels = svf_mid_channels,\n                            out_channels = svf_out_channels,\n                            kernel_size = kernel_size,\n                            padding = kernel_size // 2,\n                            bias = bias\n                           ),\n        )\n        spatially_varying_kernel_generation.append(kernel_generation_block)\n        self.spatially_varying_feature.append(spatially_varying_kernel_generation)\n    self.decoder = torch.nn.ModuleList()\n    global_feature_layer = global_feature_module(  # global feature layer\n                                                 input_channels = dimensions * (2 ** (depth - 1)),\n                                                 mid_channels = dimensions * (2 ** (depth - 1)),\n                                                 output_channels = dimensions * (2 ** (depth - 1)),\n                                                 kernel_size = kernel_size,\n                                                 bias = bias,\n                                                 activation = torch.nn.LeakyReLU(0.2, inplace = True)\n                                                )\n    self.decoder.append(global_feature_layer)\n    for i in range(depth, 0, -1):\n        if i == 2:\n            up_in_channels = (dimensions // 2) * (2 ** i)\n            up_out_channels = up_in_channels\n            up_mid_channels = up_in_channels\n        elif i == 1:\n            up_in_channels = dimensions * 2\n            up_out_channels = dimensions\n            up_mid_channels = up_out_channels\n        else:\n            up_in_channels = (dimensions // 2) * (2 ** i)\n            up_out_channels = up_in_channels // 2\n            up_mid_channels = up_in_channels\n        upsample_layer = upsample_convtranspose2d_layer(\n                                                        input_channels = up_in_channels,\n                                                        output_channels = up_mid_channels,\n                                                        kernel_size = 2,\n                                                        stride = 2,\n                                                        bias = bias,\n                                                       )\n        conv_layer = double_convolution(\n                                        input_channels = up_mid_channels,\n                                        output_channels = up_out_channels,\n                                        kernel_size = kernel_size,\n                                        bias = bias,\n                                        normalization = normalization,\n                                        activation = activation,\n                                       )\n        self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))\n
"},{"location":"odak/learn_models/#odak.learn.models.spatially_varying_kernel_generation_model.forward","title":"forward(focal_surface, field)","text":"

Forward model.

Parameters:

  • focal_surface (tensor) \u2013
            Input focal surface data.\n        Dimension: (1, 1, H, W)\n
  • field \u2013
            Input field data.\n        Dimension: (1, 6, H, W)\n

Returns:

  • sv_kernel ( list of torch.tensor ) \u2013

    Learned spatially varying kernels. Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i), where C_i, H_i, and W_i represent the channel, height, and width of each feature at a certain scale.

Source code in odak/learn/models/models.py
def forward(self, focal_surface, field):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    focal_surface : torch.tensor\n                    Input focal surface data.\n                    Dimension: (1, 1, H, W)\n\n    field         : torch.tensor\n                    Input field data.\n                    Dimension: (1, 6, H, W)\n\n    Returns\n    -------\n    sv_kernel : list of torch.tensor\n                Learned spatially varying kernels.\n                Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i),\n                where C_i, H_i, and W_i represent the channel, height, and width\n                of each feature at a certain scale.\n    \"\"\"\n    x = self.inc(torch.cat((focal_surface, field), dim = 1))\n    downsampling_outputs = [focal_surface]\n    downsampling_outputs.append(x)\n    for i, down_layer in enumerate(self.encoder):\n        x_down = down_layer(downsampling_outputs[-1])\n        downsampling_outputs.append(x_down)\n    sv_kernels = []\n    for i, (up_layer, svf_layer) in enumerate(zip(self.decoder, self.spatially_varying_feature)):\n        if i == 0:\n            global_feature = up_layer(downsampling_outputs[-2], downsampling_outputs[-1])\n            downsampling_outputs[-1] = global_feature\n            sv_feature = [global_feature, downsampling_outputs[0]]\n            for j in range(self.depth - i + 1):\n                sv_feature[1] = svf_layer[self.depth - i](sv_feature[1])\n                if j > 0:\n                    sv_feature.append(svf_layer[j](downsampling_outputs[2 * j]))\n            sv_feature = [sv_feature[0], sv_feature[1], sv_feature[4], sv_feature[2],\n                          sv_feature[3]]\n            sv_kernel = svf_layer[-1](torch.cat(sv_feature, dim = 1))\n            sv_kernels.append(sv_kernel)\n        else:\n            x_up = up_layer[0](downsampling_outputs[-1],\n                               downsampling_outputs[2 * (self.depth + 1 - i) + 1])\n            x_up = up_layer[1](x_up)\n            downsampling_outputs[-1] = x_up\n            sv_feature = [x_up, downsampling_outputs[0]]\n            for j in range(self.depth - i + 1):\n                sv_feature[1] = svf_layer[self.depth - i](sv_feature[1])\n                if j > 0:\n                    sv_feature.append(svf_layer[j](downsampling_outputs[2 * j]))\n            if i == 1:\n                sv_feature = [sv_feature[0], sv_feature[1], sv_feature[3], sv_feature[2]]\n            sv_kernel = svf_layer[-1](torch.cat(sv_feature, dim = 1))\n            sv_kernels.append(sv_kernel)\n    return sv_kernels\n
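A sketch of calling this forward pass and inspecting the returned kernels (illustrative; the import path and resolution are assumptions, and the spatial size is chosen divisible by 2 ** (depth + 1) so the internal pooling stages line up).

import torch
from odak.learn.models import spatially_varying_kernel_generation_model  # assumed import path

model = spatially_varying_kernel_generation_model()  # depth = 3 by default
focal_surface = torch.randn(1, 1, 256, 256)          # H and W divisible by 2 ** (depth + 1)
field = torch.randn(1, 6, 256, 256)
with torch.no_grad():
    sv_kernel = model(focal_surface, field)
for kernel in sv_kernel:                             # one tensor per scale
    print(kernel.shape)                              # (1, C_i * kernel_size * kernel_size, H_i, W_i)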
"},{"location":"odak/learn_models/#odak.learn.models.unet","title":"unet","text":"

Bases: Module

A U-Net model, heavily inspired by https://github.com/milesial/Pytorch-UNet/tree/master/unet; more can be read in Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional networks for biomedical image segmentation." Medical Image Computing and Computer-Assisted Intervention\u2013MICCAI 2015: 18th International Conference, Munich, Germany, October 5-9, 2015, Proceedings, Part III 18. Springer International Publishing, 2015.

Source code in odak/learn/models/models.py
class unet(torch.nn.Module):\n    \"\"\"\n    A U-Net model, heavily inspired from `https://github.com/milesial/Pytorch-UNet/tree/master/unet` and more can be read from Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. \"U-net: Convolutional networks for biomedical image segmentation.\" Medical Image Computing and Computer-Assisted Intervention\u2013MICCAI 2015: 18th International Conference, Munich, Germany, October 5-9, 2015, Proceedings, Part III 18. Springer International Publishing, 2015.\n    \"\"\"\n\n    def __init__(\n                 self, \n                 depth = 4,\n                 dimensions = 64, \n                 input_channels = 2, \n                 output_channels = 1, \n                 bilinear = False,\n                 kernel_size = 3,\n                 bias = False,\n                 activation = torch.nn.ReLU(inplace = True),\n                ):\n        \"\"\"\n        U-Net model.\n\n        Parameters\n        ----------\n        depth             : int\n                            Number of upsampling and downsampling\n        dimensions        : int\n                            Number of dimensions.\n        input_channels    : int\n                            Number of input channels.\n        output_channels   : int\n                            Number of output channels.\n        bilinear          : bool\n                            Uses bilinear upsampling in upsampling layers when set True.\n        bias              : bool\n                            Set True to let convolutional layers learn a bias term.\n        activation        : torch.nn\n                            Non-linear activation layer to be used (e.g., torch.nn.ReLU(), torch.nn.Sigmoid().\n        \"\"\"\n        super(unet, self).__init__()\n        self.inc = double_convolution(\n                                      input_channels = input_channels,\n                                      mid_channels = dimensions,\n                                      output_channels = dimensions,\n                                      kernel_size = kernel_size,\n                                      bias = bias,\n                                      activation = activation\n                                     )      \n\n        self.downsampling_layers = torch.nn.ModuleList()\n        self.upsampling_layers = torch.nn.ModuleList()\n        for i in range(depth): # downsampling layers\n            in_channels = dimensions * (2 ** i)\n            out_channels = dimensions * (2 ** (i + 1))\n            down_layer = downsample_layer(in_channels,\n                                            out_channels,\n                                            kernel_size=kernel_size,\n                                            bias=bias,\n                                            activation=activation\n                                            )\n            self.downsampling_layers.append(down_layer)      \n\n        for i in range(depth - 1, -1, -1):  # upsampling layers\n            up_in_channels = dimensions * (2 ** (i + 1))  \n            up_out_channels = dimensions * (2 ** i) \n            up_layer = upsample_layer(up_in_channels, up_out_channels, kernel_size=kernel_size, bias=bias, activation=activation, bilinear=bilinear)\n            self.upsampling_layers.append(up_layer)\n        self.outc = torch.nn.Conv2d(\n                                    dimensions, \n                                    output_channels,\n                                    kernel_size = kernel_size,\n                    
                padding = kernel_size // 2,\n                                    bias = bias\n                                   )\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x             : torch.tensor\n                        Input data.\n\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Estimated output.      \n        \"\"\"\n        downsampling_outputs = [self.inc(x)]\n        for down_layer in self.downsampling_layers:\n            x_down = down_layer(downsampling_outputs[-1])\n            downsampling_outputs.append(x_down)\n        x_up = downsampling_outputs[-1]\n        for i, up_layer in enumerate((self.upsampling_layers)):\n            x_up = up_layer(x_up, downsampling_outputs[-(i + 2)])       \n        result = self.outc(x_up)\n        return result\n
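A minimal usage sketch (not part of the generated source; assumes the class is importable as odak.learn.models.unet and that the input resolution is divisible by 2 ** depth).

import torch
from odak.learn.models import unet  # assumed import path

model = unet(depth = 4, dimensions = 16, input_channels = 2, output_channels = 1)
x = torch.randn(1, 2, 128, 128)  # batch of one, two input channels, resolution divisible by 2 ** depth
with torch.no_grad():
    y = model(x)
print(y.shape)                   # expected: torch.Size([1, 1, 128, 128])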
"},{"location":"odak/learn_models/#odak.learn.models.unet.__init__","title":"__init__(depth=4, dimensions=64, input_channels=2, output_channels=1, bilinear=False, kernel_size=3, bias=False, activation=torch.nn.ReLU(inplace=True))","text":"

U-Net model.

Parameters:

  • depth \u2013
                Number of upsampling and downsampling layers.\n
  • dimensions \u2013
                Number of feature channels at the first level; channel counts double at each downsampling layer.\n
  • input_channels \u2013
                Number of input channels.\n
  • output_channels \u2013
                Number of output channels.\n
  • bilinear \u2013
                Uses bilinear upsampling in upsampling layers when set True.\n
  • bias \u2013
                Set to True to let convolutional layers learn a bias term.\n
  • activation \u2013
                Non-linear activation layer to be used (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).\n
Source code in odak/learn/models/models.py
def __init__(\n             self, \n             depth = 4,\n             dimensions = 64, \n             input_channels = 2, \n             output_channels = 1, \n             bilinear = False,\n             kernel_size = 3,\n             bias = False,\n             activation = torch.nn.ReLU(inplace = True),\n            ):\n    \"\"\"\n    U-Net model.\n\n    Parameters\n    ----------\n    depth             : int\n                        Number of upsampling and downsampling\n    dimensions        : int\n                        Number of dimensions.\n    input_channels    : int\n                        Number of input channels.\n    output_channels   : int\n                        Number of output channels.\n    bilinear          : bool\n                        Uses bilinear upsampling in upsampling layers when set True.\n    bias              : bool\n                        Set True to let convolutional layers learn a bias term.\n    activation        : torch.nn\n                        Non-linear activation layer to be used (e.g., torch.nn.ReLU(), torch.nn.Sigmoid().\n    \"\"\"\n    super(unet, self).__init__()\n    self.inc = double_convolution(\n                                  input_channels = input_channels,\n                                  mid_channels = dimensions,\n                                  output_channels = dimensions,\n                                  kernel_size = kernel_size,\n                                  bias = bias,\n                                  activation = activation\n                                 )      \n\n    self.downsampling_layers = torch.nn.ModuleList()\n    self.upsampling_layers = torch.nn.ModuleList()\n    for i in range(depth): # downsampling layers\n        in_channels = dimensions * (2 ** i)\n        out_channels = dimensions * (2 ** (i + 1))\n        down_layer = downsample_layer(in_channels,\n                                        out_channels,\n                                        kernel_size=kernel_size,\n                                        bias=bias,\n                                        activation=activation\n                                        )\n        self.downsampling_layers.append(down_layer)      \n\n    for i in range(depth - 1, -1, -1):  # upsampling layers\n        up_in_channels = dimensions * (2 ** (i + 1))  \n        up_out_channels = dimensions * (2 ** i) \n        up_layer = upsample_layer(up_in_channels, up_out_channels, kernel_size=kernel_size, bias=bias, activation=activation, bilinear=bilinear)\n        self.upsampling_layers.append(up_layer)\n    self.outc = torch.nn.Conv2d(\n                                dimensions, \n                                output_channels,\n                                kernel_size = kernel_size,\n                                padding = kernel_size // 2,\n                                bias = bias\n                               )\n
"},{"location":"odak/learn_models/#odak.learn.models.unet.forward","title":"forward(x)","text":"

Forward model.

Parameters:

  • x \u2013
            Input data.\n

Returns:

  • result ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/models.py
def forward(self, x):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x             : torch.tensor\n                    Input data.\n\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Estimated output.      \n    \"\"\"\n    downsampling_outputs = [self.inc(x)]\n    for down_layer in self.downsampling_layers:\n        x_down = down_layer(downsampling_outputs[-1])\n        downsampling_outputs.append(x_down)\n    x_up = downsampling_outputs[-1]\n    for i, up_layer in enumerate((self.upsampling_layers)):\n        x_up = up_layer(x_up, downsampling_outputs[-(i + 2)])       \n    result = self.outc(x_up)\n    return result\n
"},{"location":"odak/learn_models/#odak.learn.models.upsample_convtranspose2d_layer","title":"upsample_convtranspose2d_layer","text":"

Bases: Module

An upsampling convtranspose2d layer.

Source code in odak/learn/models/components.py
class upsample_convtranspose2d_layer(torch.nn.Module):\n    \"\"\"\n    An upsampling convtranspose2d layer.\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels,\n                 output_channels,\n                 kernel_size = 2,\n                 stride = 2,\n                 bias = False,\n                ):\n        \"\"\"\n        A downscaling component with a double convolution.\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        output_channels : int\n                          Number of output channels.\n        kernel_size     : int\n                          Kernel size.\n        bias            : bool\n                          Set to True to let convolutional layers have bias term.\n        \"\"\"\n        super().__init__()\n        self.up = torch.nn.ConvTranspose2d(\n                                           in_channels = input_channels,\n                                           out_channels = output_channels,\n                                           bias = bias,\n                                           kernel_size = kernel_size,\n                                           stride = stride\n                                          )\n\n    def forward(self, x1, x2):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x1             : torch.tensor\n                         First input data.\n        x2             : torch.tensor\n                         Second input data.\n\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Result of the forward operation\n        \"\"\"\n        x1 = self.up(x1)\n        diffY = x2.size()[2] - x1.size()[2]\n        diffX = x2.size()[3] - x1.size()[3]\n        x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,\n                                          diffY // 2, diffY - diffY // 2])\n        result = x1 + x2\n        return result\n
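A short usage sketch (assumes the class is importable as odak.learn.models.upsample_convtranspose2d_layer; shapes are placeholders). The skip tensor x2 must have the same channel count as output_channels, since the upsampled x1 is added to it.

import torch
from odak.learn.models import upsample_convtranspose2d_layer  # assumed import path

layer = upsample_convtranspose2d_layer(input_channels = 32, output_channels = 16, bias = True)
x1 = torch.randn(1, 32, 32, 32)  # low-resolution feature map
x2 = torch.randn(1, 16, 64, 64)  # skip connection at twice the resolution, matching channel count
y = layer(x1, x2)                # x1 is upsampled, padded to x2's size and added to x2
print(y.shape)                   # expected: torch.Size([1, 16, 64, 64])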
"},{"location":"odak/learn_models/#odak.learn.models.upsample_convtranspose2d_layer.__init__","title":"__init__(input_channels, output_channels, kernel_size=2, stride=2, bias=False)","text":"

An upsampling component based on a transposed convolution.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • output_channels (int) \u2013
              Number of output channels.\n
  • kernel_size \u2013
              Kernel size.\n
  • bias \u2013
              Set to True to let convolutional layers have a bias term.\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels,\n             output_channels,\n             kernel_size = 2,\n             stride = 2,\n             bias = False,\n            ):\n    \"\"\"\n    A downscaling component with a double convolution.\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    output_channels : int\n                      Number of output channels.\n    kernel_size     : int\n                      Kernel size.\n    bias            : bool\n                      Set to True to let convolutional layers have bias term.\n    \"\"\"\n    super().__init__()\n    self.up = torch.nn.ConvTranspose2d(\n                                       in_channels = input_channels,\n                                       out_channels = output_channels,\n                                       bias = bias,\n                                       kernel_size = kernel_size,\n                                       stride = stride\n                                      )\n
"},{"location":"odak/learn_models/#odak.learn.models.upsample_convtranspose2d_layer.forward","title":"forward(x1, x2)","text":"

Forward model.

Parameters:

  • x1 \u2013
             First input data.\n
  • x2 \u2013
             Second input data.\n

Returns:

  • result ( tensor ) \u2013

    Result of the forward operation

Source code in odak/learn/models/components.py
def forward(self, x1, x2):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x1             : torch.tensor\n                     First input data.\n    x2             : torch.tensor\n                     Second input data.\n\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Result of the forward operation\n    \"\"\"\n    x1 = self.up(x1)\n    diffY = x2.size()[2] - x1.size()[2]\n    diffX = x2.size()[3] - x1.size()[3]\n    x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,\n                                      diffY // 2, diffY - diffY // 2])\n    result = x1 + x2\n    return result\n
"},{"location":"odak/learn_models/#odak.learn.models.upsample_layer","title":"upsample_layer","text":"

Bases: Module

An upsampling convolutional layer.

Source code in odak/learn/models/components.py
class upsample_layer(torch.nn.Module):\n    \"\"\"\n    An upsampling convolutional layer.\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels,\n                 output_channels,\n                 kernel_size = 3,\n                 bias = False,\n                 activation = torch.nn.ReLU(),\n                 bilinear = True\n                ):\n        \"\"\"\n        A downscaling component with a double convolution.\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        output_channels : int\n                          Number of output channels.\n        kernel_size     : int\n                          Kernel size.\n        bias            : bool \n                          Set to True to let convolutional layers have bias term.\n        activation      : torch.nn\n                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n        bilinear        : bool\n                          If set to True, bilinear sampling is used.\n        \"\"\"\n        super(upsample_layer, self).__init__()\n        if bilinear:\n            self.up = torch.nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = True)\n            self.conv = double_convolution(\n                                           input_channels = input_channels + output_channels,\n                                           mid_channels = input_channels // 2,\n                                           output_channels = output_channels,\n                                           kernel_size = kernel_size,\n                                           bias = bias,\n                                           activation = activation\n                                          )\n        else:\n            self.up = torch.nn.ConvTranspose2d(input_channels , input_channels // 2, kernel_size = 2, stride = 2)\n            self.conv = double_convolution(\n                                           input_channels = input_channels,\n                                           mid_channels = output_channels,\n                                           output_channels = output_channels,\n                                           kernel_size = kernel_size,\n                                           bias = bias,\n                                           activation = activation\n                                          )\n\n\n    def forward(self, x1, x2):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x1             : torch.tensor\n                         First input data.\n        x2             : torch.tensor\n                         Second input data.\n\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Result of the forward operation\n        \"\"\" \n        x1 = self.up(x1)\n        diffY = x2.size()[2] - x1.size()[2]\n        diffX = x2.size()[3] - x1.size()[3]\n        x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,\n                                          diffY // 2, diffY - diffY // 2])\n        x = torch.cat([x2, x1], dim = 1)\n        result = self.conv(x)\n        return result\n
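A short usage sketch for the bilinear path (assumes the class is importable as odak.learn.models.upsample_layer; shapes are placeholders). With bilinear = True, x1 is expected to carry input_channels channels and the skip x2 to carry output_channels channels, so the concatenation matches the double convolution.

import torch
from odak.learn.models import upsample_layer  # assumed import path

layer = upsample_layer(input_channels = 32, output_channels = 16, bilinear = True)
x1 = torch.randn(1, 32, 32, 32)  # decoder feature map to be upsampled
x2 = torch.randn(1, 16, 64, 64)  # encoder skip connection
y = layer(x1, x2)                # upsample x1, pad, concatenate with x2, then a double convolution
print(y.shape)                   # expected: torch.Size([1, 16, 64, 64])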
"},{"location":"odak/learn_models/#odak.learn.models.upsample_layer.__init__","title":"__init__(input_channels, output_channels, kernel_size=3, bias=False, activation=torch.nn.ReLU(), bilinear=True)","text":"

An upsampling component with a double convolution.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • output_channels (int) \u2013
              Number of output channels.\n
  • kernel_size \u2013
              Kernel size.\n
  • bias \u2013
              Set to True to let convolutional layers have a bias term.\n
  • activation \u2013
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n
  • bilinear \u2013
              If set to True, bilinear sampling is used.\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels,\n             output_channels,\n             kernel_size = 3,\n             bias = False,\n             activation = torch.nn.ReLU(),\n             bilinear = True\n            ):\n    \"\"\"\n    A downscaling component with a double convolution.\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    output_channels : int\n                      Number of output channels.\n    kernel_size     : int\n                      Kernel size.\n    bias            : bool \n                      Set to True to let convolutional layers have bias term.\n    activation      : torch.nn\n                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n    bilinear        : bool\n                      If set to True, bilinear sampling is used.\n    \"\"\"\n    super(upsample_layer, self).__init__()\n    if bilinear:\n        self.up = torch.nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = True)\n        self.conv = double_convolution(\n                                       input_channels = input_channels + output_channels,\n                                       mid_channels = input_channels // 2,\n                                       output_channels = output_channels,\n                                       kernel_size = kernel_size,\n                                       bias = bias,\n                                       activation = activation\n                                      )\n    else:\n        self.up = torch.nn.ConvTranspose2d(input_channels , input_channels // 2, kernel_size = 2, stride = 2)\n        self.conv = double_convolution(\n                                       input_channels = input_channels,\n                                       mid_channels = output_channels,\n                                       output_channels = output_channels,\n                                       kernel_size = kernel_size,\n                                       bias = bias,\n                                       activation = activation\n                                      )\n
"},{"location":"odak/learn_models/#odak.learn.models.upsample_layer.forward","title":"forward(x1, x2)","text":"

Forward model.

Parameters:

  • x1 \u2013
             First input data.\n
  • x2 \u2013
             Second input data.\n

Returns:

  • result ( tensor ) \u2013

    Result of the forward operation

Source code in odak/learn/models/components.py
def forward(self, x1, x2):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x1             : torch.tensor\n                     First input data.\n    x2             : torch.tensor\n                     Second input data.\n\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Result of the forward operation\n    \"\"\" \n    x1 = self.up(x1)\n    diffY = x2.size()[2] - x1.size()[2]\n    diffX = x2.size()[3] - x1.size()[3]\n    x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,\n                                      diffY // 2, diffY - diffY // 2])\n    x = torch.cat([x2, x1], dim = 1)\n    result = self.conv(x)\n    return result\n
"},{"location":"odak/learn_models/#odak.learn.models.gaussian","title":"gaussian(x, multiplier=1.0)","text":"

A Gaussian non-linear activation. For more details: Ramasinghe, Sameera, and Simon Lucey. \"Beyond periodicity: Towards a unifying framework for activations in coordinate-mlps.\" In European Conference on Computer Vision, pp. 142-158. Cham: Springer Nature Switzerland, 2022.

Parameters:

  • x \u2013
           Input data.\n
  • multiplier \u2013
           Multiplier.\n

Returns:

  • result ( float or tensor ) \u2013

    Output data.

Source code in odak/learn/models/components.py
def gaussian(x, multiplier = 1.):\n    \"\"\"\n    A Gaussian non-linear activation.\n    For more details: Ramasinghe, Sameera, and Simon Lucey. \"Beyond periodicity: Towards a unifying framework for activations in coordinate-mlps.\" In European Conference on Computer Vision, pp. 142-158. Cham: Springer Nature Switzerland, 2022.\n\n    Parameters\n    ----------\n    x            : float or torch.tensor\n                   Input data.\n    multiplier   : float or torch.tensor\n                   Multiplier.\n\n    Returns\n    -------\n    result       : float or torch.tensor\n                   Ouput data.\n    \"\"\"\n    result = torch.exp(- (multiplier * x) ** 2)\n    return result\n
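A short usage sketch (assumes the function is importable as odak.learn.models.gaussian).

import torch
from odak.learn.models import gaussian  # assumed import path

x = torch.linspace(-3., 3., 7)
y = gaussian(x, multiplier = 1.)  # torch.exp(-(multiplier * x) ** 2); equals 1 at x = 0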
"},{"location":"odak/learn_models/#odak.learn.models.swish","title":"swish(x)","text":"

A swish non-linear activation. For more details: https://en.wikipedia.org/wiki/Swish_function

Parameters:

  • x \u2013
             Input.\n

Returns:

  • out ( float or tensor ) \u2013

    Output.

Source code in odak/learn/models/components.py
def swish(x):\n    \"\"\"\n    A swish non-linear activation.\n    For more details: https://en.wikipedia.org/wiki/Swish_function\n\n    Parameters\n    -----------\n    x              : float or torch.tensor\n                     Input.\n\n    Returns\n    -------\n    out            : float or torch.tensor\n                     Output.\n    \"\"\"\n    out = x * torch.sigmoid(x)\n    return out\n
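A short usage sketch (assumes the function is importable as odak.learn.models.swish).

import torch
from odak.learn.models import swish  # assumed import path

x = torch.linspace(-3., 3., 7)
y = swish(x)  # x * torch.sigmoid(x)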
"},{"location":"odak/learn_models/#odak.learn.models.components.channel_gate","title":"channel_gate","text":"

Bases: Module

Channel attention module with various pooling strategies. This class is heavily inspired by https://github.com/Jongchan/attention-module/commit/e4ee180f1335c09db14d39a65d97c8ca3d1f7b16 (MIT License).

Source code in odak/learn/models/components.py
class channel_gate(torch.nn.Module):\n    \"\"\"\n    Channel attention module with various pooling strategies.\n    This class is heavily inspired https://github.com/Jongchan/attention-module/commit/e4ee180f1335c09db14d39a65d97c8ca3d1f7b16 (MIT License).\n    \"\"\"\n    def __init__(\n                 self, \n                 gate_channels, \n                 reduction_ratio = 16, \n                 pool_types = ['avg', 'max']\n                ):\n        \"\"\"\n        Initializes the channel gate module.\n\n        Parameters\n        ----------\n        gate_channels   : int\n                          Number of channels of the input feature map.\n        reduction_ratio : int\n                          Reduction ratio for the intermediate layer.\n        pool_types      : list\n                          List of pooling operations to apply.\n        \"\"\"\n        super().__init__()\n        self.gate_channels = gate_channels\n        hidden_channels = gate_channels // reduction_ratio\n        if hidden_channels == 0:\n            hidden_channels = 1\n        self.mlp = torch.nn.Sequential(\n                                       convolutional_block_attention.Flatten(),\n                                       torch.nn.Linear(gate_channels, hidden_channels),\n                                       torch.nn.ReLU(),\n                                       torch.nn.Linear(hidden_channels, gate_channels)\n                                      )\n        self.pool_types = pool_types\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward pass of the ChannelGate module.\n\n        Applies channel-wise attention to the input tensor.\n\n        Parameters\n        ----------\n        x            : torch.tensor\n                       Input tensor to the ChannelGate module.\n\n        Returns\n        -------\n        output       : torch.tensor\n                       Output tensor after applying channel attention.\n        \"\"\"\n        channel_att_sum = None\n        for pool_type in self.pool_types:\n            if pool_type == 'avg':\n                pool = torch.nn.functional.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))\n            elif pool_type == 'max':\n                pool = torch.nn.functional.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))\n            channel_att_raw = self.mlp(pool)\n            channel_att_sum = channel_att_raw if channel_att_sum is None else channel_att_sum + channel_att_raw\n        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)\n        output = x * scale\n        return output\n
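A short usage sketch (assumes the class is importable as odak.learn.models.components.channel_gate; shapes are placeholders).

import torch
from odak.learn.models.components import channel_gate  # assumed import path

gate = channel_gate(gate_channels = 32, reduction_ratio = 16, pool_types = ['avg', 'max'])
x = torch.randn(1, 32, 64, 64)
y = gate(x)     # channel-wise attention rescales x
print(y.shape)  # expected: torch.Size([1, 32, 64, 64])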
"},{"location":"odak/learn_models/#odak.learn.models.components.channel_gate.__init__","title":"__init__(gate_channels, reduction_ratio=16, pool_types=['avg', 'max'])","text":"

Initializes the channel gate module.

Parameters:

  • gate_channels \u2013
              Number of channels of the input feature map.\n
  • reduction_ratio (int, default: 16 ) \u2013
              Reduction ratio for the intermediate layer.\n
  • pool_types \u2013
              List of pooling operations to apply.\n
Source code in odak/learn/models/components.py
def __init__(\n             self, \n             gate_channels, \n             reduction_ratio = 16, \n             pool_types = ['avg', 'max']\n            ):\n    \"\"\"\n    Initializes the channel gate module.\n\n    Parameters\n    ----------\n    gate_channels   : int\n                      Number of channels of the input feature map.\n    reduction_ratio : int\n                      Reduction ratio for the intermediate layer.\n    pool_types      : list\n                      List of pooling operations to apply.\n    \"\"\"\n    super().__init__()\n    self.gate_channels = gate_channels\n    hidden_channels = gate_channels // reduction_ratio\n    if hidden_channels == 0:\n        hidden_channels = 1\n    self.mlp = torch.nn.Sequential(\n                                   convolutional_block_attention.Flatten(),\n                                   torch.nn.Linear(gate_channels, hidden_channels),\n                                   torch.nn.ReLU(),\n                                   torch.nn.Linear(hidden_channels, gate_channels)\n                                  )\n    self.pool_types = pool_types\n
"},{"location":"odak/learn_models/#odak.learn.models.components.channel_gate.forward","title":"forward(x)","text":"

Forward pass of the ChannelGate module.

Applies channel-wise attention to the input tensor.

Parameters:

  • x \u2013
           Input tensor to the ChannelGate module.\n

Returns:

  • output ( tensor ) \u2013

    Output tensor after applying channel attention.

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward pass of the ChannelGate module.\n\n    Applies channel-wise attention to the input tensor.\n\n    Parameters\n    ----------\n    x            : torch.tensor\n                   Input tensor to the ChannelGate module.\n\n    Returns\n    -------\n    output       : torch.tensor\n                   Output tensor after applying channel attention.\n    \"\"\"\n    channel_att_sum = None\n    for pool_type in self.pool_types:\n        if pool_type == 'avg':\n            pool = torch.nn.functional.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))\n        elif pool_type == 'max':\n            pool = torch.nn.functional.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))\n        channel_att_raw = self.mlp(pool)\n        channel_att_sum = channel_att_raw if channel_att_sum is None else channel_att_sum + channel_att_raw\n    scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)\n    output = x * scale\n    return output\n
"},{"location":"odak/learn_models/#odak.learn.models.components.convolution_layer","title":"convolution_layer","text":"

Bases: Module

A convolution layer.

Source code in odak/learn/models/components.py
class convolution_layer(torch.nn.Module):\n    \"\"\"\n    A convolution layer.\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels = 2,\n                 output_channels = 2,\n                 kernel_size = 3,\n                 bias = False,\n                 stride = 1,\n                 normalization = True,\n                 activation = torch.nn.ReLU()\n                ):\n        \"\"\"\n        A convolutional layer class.\n\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        output_channels : int\n                          Number of output channels.\n        kernel_size     : int\n                          Kernel size.\n        bias            : bool\n                          Set to True to let convolutional layers have bias term.\n        normalization   : bool\n                          If True, adds a Batch Normalization layer after the convolutional layer.\n        activation      : torch.nn\n                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n        \"\"\"\n        super().__init__()\n        layers = [\n            torch.nn.Conv2d(\n                            input_channels,\n                            output_channels,\n                            kernel_size = kernel_size,\n                            stride = stride,\n                            padding = kernel_size // 2,\n                            bias = bias\n                           )\n        ]\n        if normalization:\n            layers.append(torch.nn.BatchNorm2d(output_channels))\n        if activation:\n            layers.append(activation)\n        self.model = torch.nn.Sequential(*layers)\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x             : torch.tensor\n                        Input data.\n\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Estimated output.\n        \"\"\"\n        result = self.model(x)\n        return result\n
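A short usage sketch (assumes the class is importable as odak.learn.models.components.convolution_layer; channel counts and resolution are placeholders).

import torch
from odak.learn.models.components import convolution_layer  # assumed import path

layer = convolution_layer(
                          input_channels = 3,
                          output_channels = 8,
                          kernel_size = 3,
                          normalization = True,
                          activation = torch.nn.ReLU()
                         )
x = torch.randn(1, 3, 64, 64)
y = layer(x)    # convolution, batch normalization and activation; padding preserves H and W
print(y.shape)  # expected: torch.Size([1, 8, 64, 64])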
"},{"location":"odak/learn_models/#odak.learn.models.components.convolution_layer.__init__","title":"__init__(input_channels=2, output_channels=2, kernel_size=3, bias=False, stride=1, normalization=True, activation=torch.nn.ReLU())","text":"

A convolutional layer class.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • output_channels (int, default: 2 ) \u2013
              Number of output channels.\n
  • kernel_size \u2013
              Kernel size.\n
  • bias \u2013
              Set to True to let convolutional layers have bias term.\n
  • normalization \u2013
              If True, adds a Batch Normalization layer after the convolutional layer.\n
  • activation \u2013
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels = 2,\n             output_channels = 2,\n             kernel_size = 3,\n             bias = False,\n             stride = 1,\n             normalization = True,\n             activation = torch.nn.ReLU()\n            ):\n    \"\"\"\n    A convolutional layer class.\n\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    output_channels : int\n                      Number of output channels.\n    kernel_size     : int\n                      Kernel size.\n    bias            : bool\n                      Set to True to let convolutional layers have bias term.\n    normalization   : bool\n                      If True, adds a Batch Normalization layer after the convolutional layer.\n    activation      : torch.nn\n                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n    \"\"\"\n    super().__init__()\n    layers = [\n        torch.nn.Conv2d(\n                        input_channels,\n                        output_channels,\n                        kernel_size = kernel_size,\n                        stride = stride,\n                        padding = kernel_size // 2,\n                        bias = bias\n                       )\n    ]\n    if normalization:\n        layers.append(torch.nn.BatchNorm2d(output_channels))\n    if activation:\n        layers.append(activation)\n    self.model = torch.nn.Sequential(*layers)\n
"},{"location":"odak/learn_models/#odak.learn.models.components.convolution_layer.forward","title":"forward(x)","text":"

Forward model.

Parameters:

  • x \u2013
            Input data.\n

Returns:

  • result ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x             : torch.tensor\n                    Input data.\n\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Estimated output.\n    \"\"\"\n    result = self.model(x)\n    return result\n
"},{"location":"odak/learn_models/#odak.learn.models.components.convolutional_block_attention","title":"convolutional_block_attention","text":"

Bases: Module

Convolutional Block Attention Module (CBAM) class. This class is heavily inspired by https://github.com/Jongchan/attention-module/commit/e4ee180f1335c09db14d39a65d97c8ca3d1f7b16 (MIT License).

Source code in odak/learn/models/components.py
class convolutional_block_attention(torch.nn.Module):\n    \"\"\"\n    Convolutional Block Attention Module (CBAM) class. \n    This class is heavily inspired https://github.com/Jongchan/attention-module/commit/e4ee180f1335c09db14d39a65d97c8ca3d1f7b16 (MIT License).\n    \"\"\"\n    def __init__(\n                 self, \n                 gate_channels, \n                 reduction_ratio = 16, \n                 pool_types = ['avg', 'max'], \n                 no_spatial = False\n                ):\n        \"\"\"\n        Initializes the convolutional block attention module.\n\n        Parameters\n        ----------\n        gate_channels   : int\n                          Number of channels of the input feature map.\n        reduction_ratio : int\n                          Reduction ratio for the channel attention.\n        pool_types      : list\n                          List of pooling operations to apply for channel attention.\n        no_spatial      : bool\n                          If True, spatial attention is not applied.\n        \"\"\"\n        super(convolutional_block_attention, self).__init__()\n        self.channel_gate = channel_gate(gate_channels, reduction_ratio, pool_types)\n        self.no_spatial = no_spatial\n        if not no_spatial:\n            self.spatial_gate = spatial_gate()\n\n\n    class Flatten(torch.nn.Module):\n        \"\"\"\n        Flattens the input tensor to a 2D matrix.\n        \"\"\"\n        def forward(self, x):\n            return x.view(x.size(0), -1)\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward pass of the convolutional block attention module.\n\n        Parameters\n        ----------\n        x            : torch.tensor\n                       Input tensor to the CBAM module.\n\n        Returns\n        -------\n        x_out        : torch.tensor\n                       Output tensor after applying channel and spatial attention.\n        \"\"\"\n        x_out = self.channel_gate(x)\n        if not self.no_spatial:\n            x_out = self.spatial_gate(x_out)\n        return x_out\n
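A short, hedged usage example (import path assumed from the source location above); CBAM keeps the tensor shape while re-weighting it channel-wise and spatially:

import torch
from odak.learn.models.components import convolutional_block_attention

cbam = convolutional_block_attention(gate_channels = 64, reduction_ratio = 16)
x = torch.randn(2, 64, 32, 32)
y = cbam(x)                                  # channel attention, then spatial attention
print(y.shape)                               # torch.Size([2, 64, 32, 32])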
"},{"location":"odak/learn_models/#odak.learn.models.components.convolutional_block_attention.Flatten","title":"Flatten","text":"

Bases: Module

Flattens the input tensor to a 2D matrix.

Source code in odak/learn/models/components.py
class Flatten(torch.nn.Module):\n    \"\"\"\n    Flattens the input tensor to a 2D matrix.\n    \"\"\"\n    def forward(self, x):\n        return x.view(x.size(0), -1)\n
"},{"location":"odak/learn_models/#odak.learn.models.components.convolutional_block_attention.__init__","title":"__init__(gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False)","text":"

Initializes the convolutional block attention module.

Parameters:

  • gate_channels \u2013
              Number of channels of the input feature map.\n
  • reduction_ratio (int, default: 16 ) \u2013
              Reduction ratio for the channel attention.\n
  • pool_types \u2013
              List of pooling operations to apply for channel attention.\n
  • no_spatial \u2013
              If True, spatial attention is not applied.\n
Source code in odak/learn/models/components.py
def __init__(\n             self, \n             gate_channels, \n             reduction_ratio = 16, \n             pool_types = ['avg', 'max'], \n             no_spatial = False\n            ):\n    \"\"\"\n    Initializes the convolutional block attention module.\n\n    Parameters\n    ----------\n    gate_channels   : int\n                      Number of channels of the input feature map.\n    reduction_ratio : int\n                      Reduction ratio for the channel attention.\n    pool_types      : list\n                      List of pooling operations to apply for channel attention.\n    no_spatial      : bool\n                      If True, spatial attention is not applied.\n    \"\"\"\n    super(convolutional_block_attention, self).__init__()\n    self.channel_gate = channel_gate(gate_channels, reduction_ratio, pool_types)\n    self.no_spatial = no_spatial\n    if not no_spatial:\n        self.spatial_gate = spatial_gate()\n
"},{"location":"odak/learn_models/#odak.learn.models.components.convolutional_block_attention.forward","title":"forward(x)","text":"

Forward pass of the convolutional block attention module.

Parameters:

  • x \u2013
           Input tensor to the CBAM module.\n

Returns:

  • x_out ( tensor ) \u2013

    Output tensor after applying channel and spatial attention.

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward pass of the convolutional block attention module.\n\n    Parameters\n    ----------\n    x            : torch.tensor\n                   Input tensor to the CBAM module.\n\n    Returns\n    -------\n    x_out        : torch.tensor\n                   Output tensor after applying channel and spatial attention.\n    \"\"\"\n    x_out = self.channel_gate(x)\n    if not self.no_spatial:\n        x_out = self.spatial_gate(x_out)\n    return x_out\n
"},{"location":"odak/learn_models/#odak.learn.models.components.double_convolution","title":"double_convolution","text":"

Bases: Module

A double convolution layer.

Source code in odak/learn/models/components.py
class double_convolution(torch.nn.Module):\n    \"\"\"\n    A double convolution layer.\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels = 2,\n                 mid_channels = None,\n                 output_channels = 2,\n                 kernel_size = 3, \n                 bias = False,\n                 normalization = True,\n                 activation = torch.nn.ReLU()\n                ):\n        \"\"\"\n        Double convolution model.\n\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        mid_channels    : int\n                          Number of channels in the hidden layer between two convolutions.\n        output_channels : int\n                          Number of output channels.\n        kernel_size     : int\n                          Kernel size.\n        bias            : bool \n                          Set to True to let convolutional layers have bias term.\n        normalization   : bool\n                          If True, adds a Batch Normalization layer after the convolutional layer.\n        activation      : torch.nn\n                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n        \"\"\"\n        super().__init__()\n        if isinstance(mid_channels, type(None)):\n            mid_channels = output_channels\n        self.activation = activation\n        self.model = torch.nn.Sequential(\n                                         convolution_layer(\n                                                           input_channels = input_channels,\n                                                           output_channels = mid_channels,\n                                                           kernel_size = kernel_size,\n                                                           bias = bias,\n                                                           normalization = normalization,\n                                                           activation = self.activation\n                                                          ),\n                                         convolution_layer(\n                                                           input_channels = mid_channels,\n                                                           output_channels = output_channels,\n                                                           kernel_size = kernel_size,\n                                                           bias = bias,\n                                                           normalization = normalization,\n                                                           activation = self.activation\n                                                          )\n                                        )\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x             : torch.tensor\n                        Input data.\n\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Estimated output.      \n        \"\"\"\n        result = self.model(x)\n        return result\n
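A brief sketch of how the block might be used (import path assumed from the source location above); when mid_channels is None it defaults to output_channels:

import torch
from odak.learn.models.components import double_convolution

block = double_convolution(
                           input_channels = 8,
                           mid_channels = 16,
                           output_channels = 8,
                           kernel_size = 3
                          )
x = torch.randn(1, 8, 64, 64)
y = block(x)                                 # two convolution_layer blocks back to back
print(y.shape)                               # torch.Size([1, 8, 64, 64])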
"},{"location":"odak/learn_models/#odak.learn.models.components.double_convolution.__init__","title":"__init__(input_channels=2, mid_channels=None, output_channels=2, kernel_size=3, bias=False, normalization=True, activation=torch.nn.ReLU())","text":"

Double convolution model.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • mid_channels \u2013
              Number of channels in the hidden layer between two convolutions.\n
  • output_channels (int, default: 2 ) \u2013
              Number of output channels.\n
  • kernel_size \u2013
              Kernel size.\n
  • bias \u2013
              Set to True to let convolutional layers have bias term.\n
  • normalization \u2013
              If True, adds a Batch Normalization layer after the convolutional layer.\n
  • activation \u2013
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels = 2,\n             mid_channels = None,\n             output_channels = 2,\n             kernel_size = 3, \n             bias = False,\n             normalization = True,\n             activation = torch.nn.ReLU()\n            ):\n    \"\"\"\n    Double convolution model.\n\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    mid_channels    : int\n                      Number of channels in the hidden layer between two convolutions.\n    output_channels : int\n                      Number of output channels.\n    kernel_size     : int\n                      Kernel size.\n    bias            : bool \n                      Set to True to let convolutional layers have bias term.\n    normalization   : bool\n                      If True, adds a Batch Normalization layer after the convolutional layer.\n    activation      : torch.nn\n                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n    \"\"\"\n    super().__init__()\n    if isinstance(mid_channels, type(None)):\n        mid_channels = output_channels\n    self.activation = activation\n    self.model = torch.nn.Sequential(\n                                     convolution_layer(\n                                                       input_channels = input_channels,\n                                                       output_channels = mid_channels,\n                                                       kernel_size = kernel_size,\n                                                       bias = bias,\n                                                       normalization = normalization,\n                                                       activation = self.activation\n                                                      ),\n                                     convolution_layer(\n                                                       input_channels = mid_channels,\n                                                       output_channels = output_channels,\n                                                       kernel_size = kernel_size,\n                                                       bias = bias,\n                                                       normalization = normalization,\n                                                       activation = self.activation\n                                                      )\n                                    )\n
"},{"location":"odak/learn_models/#odak.learn.models.components.double_convolution.forward","title":"forward(x)","text":"

Forward model.

Parameters:

  • x \u2013
            Input data.\n

Returns:

  • result ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x             : torch.tensor\n                    Input data.\n\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Estimated output.      \n    \"\"\"\n    result = self.model(x)\n    return result\n
"},{"location":"odak/learn_models/#odak.learn.models.components.downsample_layer","title":"downsample_layer","text":"

Bases: Module

A downscaling component followed by a double convolution.

Source code in odak/learn/models/components.py
class downsample_layer(torch.nn.Module):\n    \"\"\"\n    A downscaling component followed by a double convolution.\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels,\n                 output_channels,\n                 kernel_size = 3,\n                 bias = False,\n                 activation = torch.nn.ReLU()\n                ):\n        \"\"\"\n        A downscaling component with a double convolution.\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        output_channels : int\n                          Number of output channels.\n        kernel_size     : int\n                          Kernel size.\n        bias            : bool \n                          Set to True to let convolutional layers have bias term.\n        activation      : torch.nn\n                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n        \"\"\"\n        super().__init__()\n        self.maxpool_conv = torch.nn.Sequential(\n                                                torch.nn.MaxPool2d(2),\n                                                double_convolution(\n                                                                   input_channels = input_channels,\n                                                                   mid_channels = output_channels,\n                                                                   output_channels = output_channels,\n                                                                   kernel_size = kernel_size,\n                                                                   bias = bias,\n                                                                   activation = activation\n                                                                  )\n                                               )\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x              : torch.tensor\n                         First input data.\n\n\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Estimated output.      \n        \"\"\"\n        result = self.maxpool_conv(x)\n        return result\n
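For orientation, a hedged example (import path assumed from the source location above); the MaxPool2d(2) halves the spatial resolution before the double convolution:

import torch
from odak.learn.models.components import downsample_layer

down = downsample_layer(input_channels = 16, output_channels = 32)
x = torch.randn(1, 16, 64, 64)
y = down(x)
print(y.shape)                               # torch.Size([1, 32, 32, 32])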
"},{"location":"odak/learn_models/#odak.learn.models.components.downsample_layer.__init__","title":"__init__(input_channels, output_channels, kernel_size=3, bias=False, activation=torch.nn.ReLU())","text":"

A downscaling component with a double convolution.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • output_channels (int) \u2013
              Number of output channels.\n
  • kernel_size \u2013
              Kernel size.\n
  • bias \u2013
              Set to True to let convolutional layers have bias term.\n
  • activation \u2013
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels,\n             output_channels,\n             kernel_size = 3,\n             bias = False,\n             activation = torch.nn.ReLU()\n            ):\n    \"\"\"\n    A downscaling component with a double convolution.\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    output_channels : int\n                      Number of output channels.\n    kernel_size     : int\n                      Kernel size.\n    bias            : bool \n                      Set to True to let convolutional layers have bias term.\n    activation      : torch.nn\n                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n    \"\"\"\n    super().__init__()\n    self.maxpool_conv = torch.nn.Sequential(\n                                            torch.nn.MaxPool2d(2),\n                                            double_convolution(\n                                                               input_channels = input_channels,\n                                                               mid_channels = output_channels,\n                                                               output_channels = output_channels,\n                                                               kernel_size = kernel_size,\n                                                               bias = bias,\n                                                               activation = activation\n                                                              )\n                                           )\n
"},{"location":"odak/learn_models/#odak.learn.models.components.downsample_layer.forward","title":"forward(x)","text":"

Forward model.

Parameters:

  • x \u2013
             First input data.\n

Returns:

  • result ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x              : torch.tensor\n                     First input data.\n\n\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Estimated output.      \n    \"\"\"\n    result = self.maxpool_conv(x)\n    return result\n
"},{"location":"odak/learn_models/#odak.learn.models.components.global_feature_module","title":"global_feature_module","text":"

Bases: Module

A global feature layer that processes global features from input channels and applies them to another input tensor via learned transformations.

Source code in odak/learn/models/components.py
class global_feature_module(torch.nn.Module):\n    \"\"\"\n    A global feature layer that processes global features from input channels and\n    applies them to another input tensor via learned transformations.\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels,\n                 mid_channels,\n                 output_channels,\n                 kernel_size,\n                 bias = False,\n                 normalization = False,\n                 activation = torch.nn.ReLU()\n                ):\n        \"\"\"\n        A global feature layer.\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        mid_channels  : int\n                          Number of mid channels.\n        output_channels : int\n                          Number of output channels.\n        kernel_size     : int\n                          Kernel size.\n        bias            : bool\n                          Set to True to let convolutional layers have bias term.\n        normalization   : bool\n                          If True, adds a Batch Normalization layer after the convolutional layer.\n        activation      : torch.nn\n                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n        \"\"\"\n        super().__init__()\n        self.transformations_1 = global_transformations(input_channels, output_channels)\n        self.global_features_1 = double_convolution(\n                                                    input_channels = input_channels,\n                                                    mid_channels = mid_channels,\n                                                    output_channels = output_channels,\n                                                    kernel_size = kernel_size,\n                                                    bias = bias,\n                                                    normalization = normalization,\n                                                    activation = activation\n                                                   )\n        self.global_features_2 = double_convolution(\n                                                    input_channels = input_channels,\n                                                    mid_channels = mid_channels,\n                                                    output_channels = output_channels,\n                                                    kernel_size = kernel_size,\n                                                    bias = bias,\n                                                    normalization = normalization,\n                                                    activation = activation\n                                                   )\n        self.transformations_2 = global_transformations(input_channels, output_channels)\n\n\n    def forward(self, x1, x2):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x1             : torch.tensor\n                         First input data.\n        x2             : torch.tensor\n                         Second input data.\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Estimated output.\n        \"\"\"\n        global_tensor_1 = self.transformations_1(x1, x2)\n        y1 = self.global_features_1(global_tensor_1)\n        y2 = self.global_features_2(y1)\n        global_tensor_2 = self.transformations_2(y1, y2)\n        return 
global_tensor_2\n
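A tentative usage sketch (import path assumed from the source location above). Note that the internal global_transformations multiplies its first input by a per-channel vector of length output_channels, so this example keeps the input, mid and output channel counts equal:

import torch
from odak.learn.models.components import global_feature_module

module = global_feature_module(
                               input_channels = 16,
                               mid_channels = 16,
                               output_channels = 16,
                               kernel_size = 3
                              )
x1 = torch.randn(1, 16, 32, 32)              # tensor to be modulated
x2 = torch.randn(1, 16, 32, 32)              # tensor providing global statistics
y = module(x1, x2)
print(y.shape)                               # torch.Size([1, 16, 32, 32])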
"},{"location":"odak/learn_models/#odak.learn.models.components.global_feature_module.__init__","title":"__init__(input_channels, mid_channels, output_channels, kernel_size, bias=False, normalization=False, activation=torch.nn.ReLU())","text":"

A global feature layer.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • mid_channels \u2013
              Number of mid channels.\n
  • output_channels (int) \u2013
              Number of output channels.\n
  • kernel_size \u2013
              Kernel size.\n
  • bias \u2013
              Set to True to let convolutional layers have bias term.\n
  • normalization \u2013
              If True, adds a Batch Normalization layer after the convolutional layer.\n
  • activation \u2013
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels,\n             mid_channels,\n             output_channels,\n             kernel_size,\n             bias = False,\n             normalization = False,\n             activation = torch.nn.ReLU()\n            ):\n    \"\"\"\n    A global feature layer.\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    mid_channels  : int\n                      Number of mid channels.\n    output_channels : int\n                      Number of output channels.\n    kernel_size     : int\n                      Kernel size.\n    bias            : bool\n                      Set to True to let convolutional layers have bias term.\n    normalization   : bool\n                      If True, adds a Batch Normalization layer after the convolutional layer.\n    activation      : torch.nn\n                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n    \"\"\"\n    super().__init__()\n    self.transformations_1 = global_transformations(input_channels, output_channels)\n    self.global_features_1 = double_convolution(\n                                                input_channels = input_channels,\n                                                mid_channels = mid_channels,\n                                                output_channels = output_channels,\n                                                kernel_size = kernel_size,\n                                                bias = bias,\n                                                normalization = normalization,\n                                                activation = activation\n                                               )\n    self.global_features_2 = double_convolution(\n                                                input_channels = input_channels,\n                                                mid_channels = mid_channels,\n                                                output_channels = output_channels,\n                                                kernel_size = kernel_size,\n                                                bias = bias,\n                                                normalization = normalization,\n                                                activation = activation\n                                               )\n    self.transformations_2 = global_transformations(input_channels, output_channels)\n
"},{"location":"odak/learn_models/#odak.learn.models.components.global_feature_module.forward","title":"forward(x1, x2)","text":"

Forward model.

Parameters:

  • x1 \u2013
             First input data.\n
  • x2 \u2013
             Second input data.\n

Returns:

  • result ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x1, x2):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x1             : torch.tensor\n                     First input data.\n    x2             : torch.tensor\n                     Second input data.\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Estimated output.\n    \"\"\"\n    global_tensor_1 = self.transformations_1(x1, x2)\n    y1 = self.global_features_1(global_tensor_1)\n    y2 = self.global_features_2(y1)\n    global_tensor_2 = self.transformations_2(y1, y2)\n    return global_tensor_2\n
"},{"location":"odak/learn_models/#odak.learn.models.components.global_transformations","title":"global_transformations","text":"

Bases: Module

A global feature layer that processes global features from input channels and applies learned transformations to another input tensor.

This implementation is adapted from RSGUnet: https://github.com/MTLab/rsgunet_image_enhance.

Reference: J. Huang, P. Zhu, M. Geng et al. \"Range Scaling Global U-Net for Perceptual Image Enhancement on Mobile Devices.\"

Source code in odak/learn/models/components.py
class global_transformations(torch.nn.Module):\n    \"\"\"\n    A global feature layer that processes global features from input channels and\n    applies learned transformations to another input tensor.\n\n    This implementation is adapted from RSGUnet:\n    https://github.com/MTLab/rsgunet_image_enhance.\n\n    Reference:\n    J. Huang, P. Zhu, M. Geng et al. \"Range Scaling Global U-Net for Perceptual Image Enhancement on Mobile Devices.\"\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels,\n                 output_channels\n                ):\n        \"\"\"\n        A global feature layer.\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        output_channels : int\n                          Number of output channels.\n        \"\"\"\n        super().__init__()\n        self.global_feature_1 = torch.nn.Sequential(\n            torch.nn.Linear(input_channels, output_channels),\n            torch.nn.LeakyReLU(0.2, inplace = True),\n        )\n        self.global_feature_2 = torch.nn.Sequential(\n            torch.nn.Linear(output_channels, output_channels),\n            torch.nn.LeakyReLU(0.2, inplace = True)\n        )\n\n\n    def forward(self, x1, x2):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x1             : torch.tensor\n                         First input data.\n        x2             : torch.tensor\n                         Second input data.\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Estimated output.\n        \"\"\"\n        y = torch.mean(x2, dim = (2, 3))\n        y1 = self.global_feature_1(y)\n        y2 = self.global_feature_2(y1)\n        y1 = y1.unsqueeze(2).unsqueeze(3)\n        y2 = y2.unsqueeze(2).unsqueeze(3)\n        result = x1 * y1 + y2\n        return result\n
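A minimal, hedged example (import path assumed from the source location above); x2 is reduced to per-channel means that produce a scale and a shift applied to x1, so matching channel counts are used here:

import torch
from odak.learn.models.components import global_transformations

transform = global_transformations(input_channels = 16, output_channels = 16)
x1 = torch.randn(1, 16, 32, 32)              # tensor to be scaled and shifted
x2 = torch.randn(1, 16, 32, 32)              # tensor supplying the global features
y = transform(x1, x2)                        # result = x1 * scale + shift
print(y.shape)                               # torch.Size([1, 16, 32, 32])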
"},{"location":"odak/learn_models/#odak.learn.models.components.global_transformations.__init__","title":"__init__(input_channels, output_channels)","text":"

A global feature layer.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • output_channels (int) \u2013
              Number of output channels.\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels,\n             output_channels\n            ):\n    \"\"\"\n    A global feature layer.\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    output_channels : int\n                      Number of output channels.\n    \"\"\"\n    super().__init__()\n    self.global_feature_1 = torch.nn.Sequential(\n        torch.nn.Linear(input_channels, output_channels),\n        torch.nn.LeakyReLU(0.2, inplace = True),\n    )\n    self.global_feature_2 = torch.nn.Sequential(\n        torch.nn.Linear(output_channels, output_channels),\n        torch.nn.LeakyReLU(0.2, inplace = True)\n    )\n
"},{"location":"odak/learn_models/#odak.learn.models.components.global_transformations.forward","title":"forward(x1, x2)","text":"

Forward model.

Parameters:

  • x1 \u2013
             First input data.\n
  • x2 \u2013
             Second input data.\n

Returns:

  • result ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x1, x2):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x1             : torch.tensor\n                     First input data.\n    x2             : torch.tensor\n                     Second input data.\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Estimated output.\n    \"\"\"\n    y = torch.mean(x2, dim = (2, 3))\n    y1 = self.global_feature_1(y)\n    y2 = self.global_feature_2(y1)\n    y1 = y1.unsqueeze(2).unsqueeze(3)\n    y2 = y2.unsqueeze(2).unsqueeze(3)\n    result = x1 * y1 + y2\n    return result\n
"},{"location":"odak/learn_models/#odak.learn.models.components.non_local_layer","title":"non_local_layer","text":"

Bases: Module

Self-attention layer [z_i = W_z y_i + x_i] (non-local block; see https://arxiv.org/abs/1711.07971)

Source code in odak/learn/models/components.py
class non_local_layer(torch.nn.Module):\n    \"\"\"\n    Self-Attention Layer [zi = Wzyi + xi] (non-local block : ref https://arxiv.org/abs/1711.07971)\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels = 1024,\n                 bottleneck_channels = 512,\n                 kernel_size = 1,\n                 bias = False,\n                ):\n        \"\"\"\n\n        Parameters\n        ----------\n        input_channels      : int\n                              Number of input channels.\n        bottleneck_channels : int\n                              Number of middle channels.\n        kernel_size         : int\n                              Kernel size.\n        bias                : bool \n                              Set to True to let convolutional layers have bias term.\n        \"\"\"\n        super(non_local_layer, self).__init__()\n        self.input_channels = input_channels\n        self.bottleneck_channels = bottleneck_channels\n        self.g = torch.nn.Conv2d(\n                                 self.input_channels, \n                                 self.bottleneck_channels,\n                                 kernel_size = kernel_size,\n                                 padding = kernel_size // 2,\n                                 bias = bias\n                                )\n        self.W_z = torch.nn.Sequential(\n                                       torch.nn.Conv2d(\n                                                       self.bottleneck_channels,\n                                                       self.input_channels, \n                                                       kernel_size = kernel_size,\n                                                       bias = bias,\n                                                       padding = kernel_size // 2\n                                                      ),\n                                       torch.nn.BatchNorm2d(self.input_channels)\n                                      )\n        torch.nn.init.constant_(self.W_z[1].weight, 0)   \n        torch.nn.init.constant_(self.W_z[1].bias, 0)\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward model [zi = Wzyi + xi]\n\n        Parameters\n        ----------\n        x               : torch.tensor\n                          First input data.                       \n\n\n        Returns\n        ----------\n        z               : torch.tensor\n                          Estimated output.\n        \"\"\"\n        batch_size, channels, height, width = x.size()\n        theta = x.view(batch_size, channels, -1).permute(0, 2, 1)\n        phi = x.view(batch_size, channels, -1).permute(0, 2, 1)\n        g = self.g(x).view(batch_size, self.bottleneck_channels, -1).permute(0, 2, 1)\n        attn = torch.bmm(theta, phi.transpose(1, 2)) / (height * width)\n        attn = torch.nn.functional.softmax(attn, dim=-1)\n        y = torch.bmm(attn, g).permute(0, 2, 1).contiguous().view(batch_size, self.bottleneck_channels, height, width)\n        W_y = self.W_z(y)\n        z = W_y + x\n        return z\n
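A small, hedged example (import path assumed from the source location above); the attention matrix is (H * W) x (H * W), so modest spatial sizes are used, and the zero-initialized batch norm weight makes the block start out as an identity mapping:

import torch
from odak.learn.models.components import non_local_layer

layer = non_local_layer(input_channels = 64, bottleneck_channels = 32)
x = torch.randn(1, 64, 16, 16)
z = layer(x)                                 # self-attention output added back to x
print(z.shape)                               # torch.Size([1, 64, 16, 16])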
"},{"location":"odak/learn_models/#odak.learn.models.components.non_local_layer.__init__","title":"__init__(input_channels=1024, bottleneck_channels=512, kernel_size=1, bias=False)","text":"

Parameters:

  • input_channels \u2013
                  Number of input channels.\n
  • bottleneck_channels (int, default: 512 ) \u2013
                  Number of middle channels.\n
  • kernel_size \u2013
                  Kernel size.\n
  • bias \u2013
                  Set to True to let convolutional layers have bias term.\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels = 1024,\n             bottleneck_channels = 512,\n             kernel_size = 1,\n             bias = False,\n            ):\n    \"\"\"\n\n    Parameters\n    ----------\n    input_channels      : int\n                          Number of input channels.\n    bottleneck_channels : int\n                          Number of middle channels.\n    kernel_size         : int\n                          Kernel size.\n    bias                : bool \n                          Set to True to let convolutional layers have bias term.\n    \"\"\"\n    super(non_local_layer, self).__init__()\n    self.input_channels = input_channels\n    self.bottleneck_channels = bottleneck_channels\n    self.g = torch.nn.Conv2d(\n                             self.input_channels, \n                             self.bottleneck_channels,\n                             kernel_size = kernel_size,\n                             padding = kernel_size // 2,\n                             bias = bias\n                            )\n    self.W_z = torch.nn.Sequential(\n                                   torch.nn.Conv2d(\n                                                   self.bottleneck_channels,\n                                                   self.input_channels, \n                                                   kernel_size = kernel_size,\n                                                   bias = bias,\n                                                   padding = kernel_size // 2\n                                                  ),\n                                   torch.nn.BatchNorm2d(self.input_channels)\n                                  )\n    torch.nn.init.constant_(self.W_z[1].weight, 0)   \n    torch.nn.init.constant_(self.W_z[1].bias, 0)\n
"},{"location":"odak/learn_models/#odak.learn.models.components.non_local_layer.forward","title":"forward(x)","text":"

Forward model [z_i = W_z y_i + x_i]

Parameters:

  • x \u2013
              First input data.\n

Returns:

  • z ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward model [zi = Wzyi + xi]\n\n    Parameters\n    ----------\n    x               : torch.tensor\n                      First input data.                       \n\n\n    Returns\n    ----------\n    z               : torch.tensor\n                      Estimated output.\n    \"\"\"\n    batch_size, channels, height, width = x.size()\n    theta = x.view(batch_size, channels, -1).permute(0, 2, 1)\n    phi = x.view(batch_size, channels, -1).permute(0, 2, 1)\n    g = self.g(x).view(batch_size, self.bottleneck_channels, -1).permute(0, 2, 1)\n    attn = torch.bmm(theta, phi.transpose(1, 2)) / (height * width)\n    attn = torch.nn.functional.softmax(attn, dim=-1)\n    y = torch.bmm(attn, g).permute(0, 2, 1).contiguous().view(batch_size, self.bottleneck_channels, height, width)\n    W_y = self.W_z(y)\n    z = W_y + x\n    return z\n
"},{"location":"odak/learn_models/#odak.learn.models.components.normalization","title":"normalization","text":"

Bases: Module

A normalization layer.

Source code in odak/learn/models/components.py
class normalization(torch.nn.Module):\n    \"\"\"\n    A normalization layer.\n    \"\"\"\n    def __init__(\n                 self,\n                 dim = 1,\n                ):\n        \"\"\"\n        Normalization layer.\n\n\n        Parameters\n        ----------\n        dim             : int\n                          Dimension (axis) to normalize.\n        \"\"\"\n        super().__init__()\n        self.k = torch.nn.Parameter(torch.ones(1, dim, 1, 1))\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x             : torch.tensor\n                        Input data.\n\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Estimated output.      \n        \"\"\"\n        eps = 1e-5 if x.dtype == torch.float32 else 1e-3\n        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)\n        mean = torch.mean(x, dim = 1, keepdim = True)\n        result =  (x - mean) * (var + eps).rsqrt() * self.k\n        return result \n
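A quick, hedged example (import path assumed from the source location above); dim should match the number of channels so the learnable scale broadcasts correctly:

import torch
from odak.learn.models.components import normalization

norm = normalization(dim = 8)
x = torch.randn(2, 8, 16, 16)
y = norm(x)                                  # zero mean, unit variance along the channel axis
print(y.shape)                               # torch.Size([2, 8, 16, 16])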
"},{"location":"odak/learn_models/#odak.learn.models.components.normalization.__init__","title":"__init__(dim=1)","text":"

Normalization layer.

Parameters:

  • dim \u2013
              Dimension (axis) to normalize.\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             dim = 1,\n            ):\n    \"\"\"\n    Normalization layer.\n\n\n    Parameters\n    ----------\n    dim             : int\n                      Dimension (axis) to normalize.\n    \"\"\"\n    super().__init__()\n    self.k = torch.nn.Parameter(torch.ones(1, dim, 1, 1))\n
"},{"location":"odak/learn_models/#odak.learn.models.components.normalization.forward","title":"forward(x)","text":"

Forward model.

Parameters:

  • x \u2013
            Input data.\n

Returns:

  • result ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x             : torch.tensor\n                    Input data.\n\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Estimated output.      \n    \"\"\"\n    eps = 1e-5 if x.dtype == torch.float32 else 1e-3\n    var = torch.var(x, dim = 1, unbiased = False, keepdim = True)\n    mean = torch.mean(x, dim = 1, keepdim = True)\n    result =  (x - mean) * (var + eps).rsqrt() * self.k\n    return result \n
"},{"location":"odak/learn_models/#odak.learn.models.components.positional_encoder","title":"positional_encoder","text":"

Bases: Module

A positional encoder module.

Source code in odak/learn/models/components.py
class positional_encoder(torch.nn.Module):\n    \"\"\"\n    A positional encoder module.\n    \"\"\"\n\n    def __init__(self, L):\n        \"\"\"\n        A positional encoder module.\n\n        Parameters\n        ----------\n        L                   : int\n                              Positional encoding level.\n        \"\"\"\n        super(positional_encoder, self).__init__()\n        self.L = L\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x               : torch.tensor\n                          Input data.\n\n        Returns\n        ----------\n        result          : torch.tensor\n                          Result of the forward operation\n        \"\"\"\n        B, C = x.shape\n        x = x.view(B, C, 1)\n        results = [x]\n        for i in range(1, self.L + 1):\n            freq = (2 ** i) * math.pi\n            cos_x = torch.cos(freq * x)\n            sin_x = torch.sin(freq * x)\n            results.append(cos_x)\n            results.append(sin_x)\n        results = torch.cat(results, dim=2)\n        results = results.permute(0, 2, 1)\n        results = results.reshape(B, -1)\n        return results\n
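A hedged usage sketch (import path assumed from the source location above); each of the C input features is expanded to 2 * L + 1 values (the raw value plus L sine/cosine pairs):

import torch
from odak.learn.models.components import positional_encoder

encoder = positional_encoder(L = 4)
x = torch.rand(8, 3)                         # e.g., normalized 3D coordinates in [0, 1]
y = encoder(x)
print(y.shape)                               # torch.Size([8, 27]) since 3 * (2 * 4 + 1) = 27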
"},{"location":"odak/learn_models/#odak.learn.models.components.positional_encoder.__init__","title":"__init__(L)","text":"

A positional encoder module.

Parameters:

  • L \u2013
                  Positional encoding level.\n
Source code in odak/learn/models/components.py
def __init__(self, L):\n    \"\"\"\n    A positional encoder module.\n\n    Parameters\n    ----------\n    L                   : int\n                          Positional encoding level.\n    \"\"\"\n    super(positional_encoder, self).__init__()\n    self.L = L\n
"},{"location":"odak/learn_models/#odak.learn.models.components.positional_encoder.forward","title":"forward(x)","text":"

Forward model.

Parameters:

  • x \u2013
              Input data.\n

Returns:

  • result ( tensor ) \u2013

    Result of the forward operation

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x               : torch.tensor\n                      Input data.\n\n    Returns\n    ----------\n    result          : torch.tensor\n                      Result of the forward operation\n    \"\"\"\n    B, C = x.shape\n    x = x.view(B, C, 1)\n    results = [x]\n    for i in range(1, self.L + 1):\n        freq = (2 ** i) * math.pi\n        cos_x = torch.cos(freq * x)\n        sin_x = torch.sin(freq * x)\n        results.append(cos_x)\n        results.append(sin_x)\n    results = torch.cat(results, dim=2)\n    results = results.permute(0, 2, 1)\n    results = results.reshape(B, -1)\n    return results\n
"},{"location":"odak/learn_models/#odak.learn.models.components.residual_attention_layer","title":"residual_attention_layer","text":"

Bases: Module

A residual block with an attention layer.

Source code in odak/learn/models/components.py
class residual_attention_layer(torch.nn.Module):\n    \"\"\"\n    A residual block with an attention layer.\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels = 2,\n                 output_channels = 2,\n                 kernel_size = 1,\n                 bias = False,\n                 activation = torch.nn.ReLU()\n                ):\n        \"\"\"\n        An attention layer class.\n\n\n        Parameters\n        ----------\n        input_channels  : int or optioal\n                          Number of input channels.\n        output_channels : int or optional\n                          Number of middle channels.\n        kernel_size     : int or optional\n                          Kernel size.\n        bias            : bool or optional\n                          Set to True to let convolutional layers have bias term.\n        activation      : torch.nn or optional\n                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n        \"\"\"\n        super().__init__()\n        self.activation = activation\n        self.convolution0 = torch.nn.Sequential(\n                                                torch.nn.Conv2d(\n                                                                input_channels,\n                                                                output_channels,\n                                                                kernel_size = kernel_size,\n                                                                padding = kernel_size // 2,\n                                                                bias = bias\n                                                               ),\n                                                torch.nn.BatchNorm2d(output_channels)\n                                               )\n        self.convolution1 = torch.nn.Sequential(\n                                                torch.nn.Conv2d(\n                                                                input_channels,\n                                                                output_channels,\n                                                                kernel_size = kernel_size,\n                                                                padding = kernel_size // 2,\n                                                                bias = bias\n                                                               ),\n                                                torch.nn.BatchNorm2d(output_channels)\n                                               )\n        self.final_layer = torch.nn.Sequential(\n                                               self.activation,\n                                               torch.nn.Conv2d(\n                                                               output_channels,\n                                                               output_channels,\n                                                               kernel_size = kernel_size,\n                                                               padding = kernel_size // 2,\n                                                               bias = bias\n                                                              )\n                                              )\n\n\n    def forward(self, x0, x1):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x0             : torch.tensor\n                         First input data.\n\n        x1             : torch.tensor\n          
               Seconnd input data.\n\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Estimated output.      \n        \"\"\"\n        y0 = self.convolution0(x0)\n        y1 = self.convolution1(x1)\n        y2 = torch.add(y0, y1)\n        result = self.final_layer(y2) * x0\n        return result\n
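An illustrative, hedged example (import path assumed from the source location above); the attention map computed from both inputs is multiplied back onto x0, so input and output channel counts match here:

import torch
from odak.learn.models.components import residual_attention_layer

layer = residual_attention_layer(input_channels = 8, output_channels = 8, kernel_size = 1)
x0 = torch.randn(1, 8, 32, 32)               # features to be gated
x1 = torch.randn(1, 8, 32, 32)               # features driving the attention
y = layer(x0, x1)
print(y.shape)                               # torch.Size([1, 8, 32, 32])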
"},{"location":"odak/learn_models/#odak.learn.models.components.residual_attention_layer.__init__","title":"__init__(input_channels=2, output_channels=2, kernel_size=1, bias=False, activation=torch.nn.ReLU())","text":"

An attention layer class.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • output_channels (int or optional, default: 2 ) \u2013
              Number of middle channels.\n
  • kernel_size \u2013
              Kernel size.\n
  • bias \u2013
              Set to True to let convolutional layers have bias term.\n
  • activation \u2013
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels = 2,\n             output_channels = 2,\n             kernel_size = 1,\n             bias = False,\n             activation = torch.nn.ReLU()\n            ):\n    \"\"\"\n    An attention layer class.\n\n\n    Parameters\n    ----------\n    input_channels  : int or optioal\n                      Number of input channels.\n    output_channels : int or optional\n                      Number of middle channels.\n    kernel_size     : int or optional\n                      Kernel size.\n    bias            : bool or optional\n                      Set to True to let convolutional layers have bias term.\n    activation      : torch.nn or optional\n                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n    \"\"\"\n    super().__init__()\n    self.activation = activation\n    self.convolution0 = torch.nn.Sequential(\n                                            torch.nn.Conv2d(\n                                                            input_channels,\n                                                            output_channels,\n                                                            kernel_size = kernel_size,\n                                                            padding = kernel_size // 2,\n                                                            bias = bias\n                                                           ),\n                                            torch.nn.BatchNorm2d(output_channels)\n                                           )\n    self.convolution1 = torch.nn.Sequential(\n                                            torch.nn.Conv2d(\n                                                            input_channels,\n                                                            output_channels,\n                                                            kernel_size = kernel_size,\n                                                            padding = kernel_size // 2,\n                                                            bias = bias\n                                                           ),\n                                            torch.nn.BatchNorm2d(output_channels)\n                                           )\n    self.final_layer = torch.nn.Sequential(\n                                           self.activation,\n                                           torch.nn.Conv2d(\n                                                           output_channels,\n                                                           output_channels,\n                                                           kernel_size = kernel_size,\n                                                           padding = kernel_size // 2,\n                                                           bias = bias\n                                                          )\n                                          )\n
"},{"location":"odak/learn_models/#odak.learn.models.components.residual_attention_layer.forward","title":"forward(x0, x1)","text":"

Forward model.

Parameters:

  • x0 \u2013
             First input data.\n
  • x1 \u2013
             Second input data.\n

Returns:

  • result ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x0, x1):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x0             : torch.tensor\n                     First input data.\n\n    x1             : torch.tensor\n                     Seconnd input data.\n\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Estimated output.      \n    \"\"\"\n    y0 = self.convolution0(x0)\n    y1 = self.convolution1(x1)\n    y2 = torch.add(y0, y1)\n    result = self.final_layer(y2) * x0\n    return result\n
"},{"location":"odak/learn_models/#odak.learn.models.components.residual_layer","title":"residual_layer","text":"

Bases: Module

A residual layer.

Source code in odak/learn/models/components.py
class residual_layer(torch.nn.Module):\n    \"\"\"\n    A residual layer.\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels = 2,\n                 mid_channels = 16,\n                 kernel_size = 3,\n                 bias = False,\n                 activation = torch.nn.ReLU()\n                ):\n        \"\"\"\n        A convolutional layer class.\n\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        mid_channels    : int\n                          Number of middle channels.\n        kernel_size     : int\n                          Kernel size.\n        bias            : bool \n                          Set to True to let convolutional layers have bias term.\n        activation      : torch.nn\n                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n        \"\"\"\n        super().__init__()\n        self.activation = activation\n        self.convolution = double_convolution(\n                                              input_channels,\n                                              mid_channels = mid_channels,\n                                              output_channels = input_channels,\n                                              kernel_size = kernel_size,\n                                              bias = bias,\n                                              activation = activation\n                                             )\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x             : torch.tensor\n                        Input data.\n\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Estimated output.      \n        \"\"\"\n        x0 = self.convolution(x)\n        return x + x0\n
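A short, hedged example (import path assumed from the source location above); the output keeps the input shape because the double convolution maps back to input_channels before the skip connection:

import torch
from odak.learn.models.components import residual_layer

layer = residual_layer(input_channels = 8, mid_channels = 16, kernel_size = 3)
x = torch.randn(1, 8, 32, 32)
y = layer(x)                                 # x + double_convolution(x)
print(y.shape)                               # torch.Size([1, 8, 32, 32])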
"},{"location":"odak/learn_models/#odak.learn.models.components.residual_layer.__init__","title":"__init__(input_channels=2, mid_channels=16, kernel_size=3, bias=False, activation=torch.nn.ReLU())","text":"

A convolutional layer class.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • mid_channels \u2013
              Number of middle channels.\n
  • kernel_size \u2013
              Kernel size.\n
  • bias \u2013
              Set to True to let convolutional layers have bias term.\n
  • activation \u2013
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels = 2,\n             mid_channels = 16,\n             kernel_size = 3,\n             bias = False,\n             activation = torch.nn.ReLU()\n            ):\n    \"\"\"\n    A convolutional layer class.\n\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    mid_channels    : int\n                      Number of middle channels.\n    kernel_size     : int\n                      Kernel size.\n    bias            : bool \n                      Set to True to let convolutional layers have bias term.\n    activation      : torch.nn\n                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n    \"\"\"\n    super().__init__()\n    self.activation = activation\n    self.convolution = double_convolution(\n                                          input_channels,\n                                          mid_channels = mid_channels,\n                                          output_channels = input_channels,\n                                          kernel_size = kernel_size,\n                                          bias = bias,\n                                          activation = activation\n                                         )\n
"},{"location":"odak/learn_models/#odak.learn.models.components.residual_layer.forward","title":"forward(x)","text":"

Forward model.

Parameters:

  • x \u2013
            Input data.\n

Returns:

  • result ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x             : torch.tensor\n                    Input data.\n\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Estimated output.      \n    \"\"\"\n    x0 = self.convolution(x)\n    return x + x0\n
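A minimal usage sketch for residual_layer, added for illustration (not part of the generated reference); the import path is assumed from the location above and the shapes are illustrative.

import torch
from odak.learn.models.components import residual_layer  # import path assumed

layer = residual_layer(input_channels = 2, mid_channels = 16, kernel_size = 3)
x = torch.randn(1, 2, 64, 64)
y = layer(x)  # x + convolution(x), same shape as x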
"},{"location":"odak/learn_models/#odak.learn.models.components.spatial_gate","title":"spatial_gate","text":"

Bases: Module

Spatial attention module that applies a convolution layer after channel pooling. This class is heavily inspired by https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py.

Source code in odak/learn/models/components.py
class spatial_gate(torch.nn.Module):\n    \"\"\"\n    Spatial attention module that applies a convolution layer after channel pooling.\n    This class is heavily inspired by https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py.\n    \"\"\"\n    def __init__(self):\n        \"\"\"\n        Initializes the spatial gate module.\n        \"\"\"\n        super().__init__()\n        kernel_size = 7\n        self.spatial = convolution_layer(2, 1, kernel_size, bias = False, activation = torch.nn.Identity())\n\n\n    def channel_pool(self, x):\n        \"\"\"\n        Applies max and average pooling on the channels.\n\n        Parameters\n        ----------\n        x             : torch.tensor\n                        Input tensor.\n\n        Returns\n        -------\n        output        : torch.tensor\n                        Output tensor.\n        \"\"\"\n        max_pool = torch.max(x, 1)[0].unsqueeze(1)\n        avg_pool = torch.mean(x, 1).unsqueeze(1)\n        output = torch.cat((max_pool, avg_pool), dim=1)\n        return output\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward pass of the SpatialGate module.\n\n        Applies spatial attention to the input tensor.\n\n        Parameters\n        ----------\n        x            : torch.tensor\n                       Input tensor to the SpatialGate module.\n\n        Returns\n        -------\n        scaled_x     : torch.tensor\n                       Output tensor after applying spatial attention.\n        \"\"\"\n        x_compress = self.channel_pool(x)\n        x_out = self.spatial(x_compress)\n        scale = torch.sigmoid(x_out)\n        scaled_x = x * scale\n        return scaled_x\n
"},{"location":"odak/learn_models/#odak.learn.models.components.spatial_gate.__init__","title":"__init__()","text":"

Initializes the spatial gate module.

Source code in odak/learn/models/components.py
def __init__(self):\n    \"\"\"\n    Initializes the spatial gate module.\n    \"\"\"\n    super().__init__()\n    kernel_size = 7\n    self.spatial = convolution_layer(2, 1, kernel_size, bias = False, activation = torch.nn.Identity())\n
"},{"location":"odak/learn_models/#odak.learn.models.components.spatial_gate.channel_pool","title":"channel_pool(x)","text":"

Applies max and average pooling on the channels.

Parameters:

  • x \u2013
            Input tensor.\n

Returns:

  • output ( tensor ) \u2013

    Output tensor.

Source code in odak/learn/models/components.py
def channel_pool(self, x):\n    \"\"\"\n    Applies max and average pooling on the channels.\n\n    Parameters\n    ----------\n    x             : torch.tensor\n                    Input tensor.\n\n    Returns\n    -------\n    output        : torch.tensor\n                    Output tensor.\n    \"\"\"\n    max_pool = torch.max(x, 1)[0].unsqueeze(1)\n    avg_pool = torch.mean(x, 1).unsqueeze(1)\n    output = torch.cat((max_pool, avg_pool), dim=1)\n    return output\n
"},{"location":"odak/learn_models/#odak.learn.models.components.spatial_gate.forward","title":"forward(x)","text":"

Forward pass of the SpatialGate module.

Applies spatial attention to the input tensor.

Parameters:

  • x \u2013
           Input tensor to the SpatialGate module.\n

Returns:

  • scaled_x ( tensor ) \u2013

    Output tensor after applying spatial attention.

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward pass of the SpatialGate module.\n\n    Applies spatial attention to the input tensor.\n\n    Parameters\n    ----------\n    x            : torch.tensor\n                   Input tensor to the SpatialGate module.\n\n    Returns\n    -------\n    scaled_x     : torch.tensor\n                   Output tensor after applying spatial attention.\n    \"\"\"\n    x_compress = self.channel_pool(x)\n    x_out = self.spatial(x_compress)\n    scale = torch.sigmoid(x_out)\n    scaled_x = x * scale\n    return scaled_x\n
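A minimal usage sketch for spatial_gate, added for illustration; the import path is assumed from the location above. The module works for any channel count, since it pools the channels down to two before the spatial convolution.

import torch
from odak.learn.models.components import spatial_gate  # import path assumed

gate = spatial_gate()
x = torch.randn(1, 8, 32, 32)
y = gate(x)  # same shape as x; every channel is scaled by a single-channel spatial attention map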
"},{"location":"odak/learn_models/#odak.learn.models.components.spatially_adaptive_convolution","title":"spatially_adaptive_convolution","text":"

Bases: Module

A spatially adaptive convolution layer.

References

C. Zheng et al. \"Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions.\" C. Xu et al. \"Squeezesegv3: Spatially-adaptive Convolution for Efficient Point-Cloud Segmentation.\" C. Zheng et al. \"Windowing Decomposition Convolutional Neural Network for Image Enhancement.\"

Source code in odak/learn/models/components.py
class spatially_adaptive_convolution(torch.nn.Module):\n    \"\"\"\n    A spatially adaptive convolution layer.\n\n    References\n    ----------\n\n    C. Zheng et al. \"Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions.\"\n    C. Xu et al. \"Squeezesegv3: Spatially-adaptive Convolution for Efficient Point-Cloud Segmentation.\"\n    C. Zheng et al. \"Windowing Decomposition Convolutional Neural Network for Image Enhancement.\"\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels = 2,\n                 output_channels = 2,\n                 kernel_size = 3,\n                 stride = 1,\n                 padding = 1,\n                 bias = False,\n                 activation = torch.nn.LeakyReLU(0.2, inplace = True)\n                ):\n        \"\"\"\n        Initializes a spatially adaptive convolution layer.\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        output_channels : int\n                          Number of output channels.\n        kernel_size     : int\n                          Size of the convolution kernel.\n        stride          : int\n                          Stride of the convolution.\n        padding         : int\n                          Padding added to both sides of the input.\n        bias            : bool\n                          If True, includes a bias term in the convolution.\n        activation      : torch.nn.Module\n                          Activation function to apply. If None, no activation is applied.\n        \"\"\"\n        super(spatially_adaptive_convolution, self).__init__()\n        self.kernel_size = kernel_size\n        self.input_channels = input_channels\n        self.output_channels = output_channels\n        self.stride = stride\n        self.padding = padding\n        self.standard_convolution = torch.nn.Conv2d(\n                                                    in_channels = input_channels,\n                                                    out_channels = self.output_channels,\n                                                    kernel_size = kernel_size,\n                                                    stride = stride,\n                                                    padding = padding,\n                                                    bias = bias\n                                                   )\n        self.weight = torch.nn.Parameter(data = self.standard_convolution.weight, requires_grad = True)\n        self.activation = activation\n\n\n    def forward(self, x, sv_kernel_feature):\n        \"\"\"\n        Forward pass for the spatially adaptive convolution layer.\n\n        Parameters\n        ----------\n        x                  : torch.tensor\n                            Input data tensor.\n                            Dimension: (1, C, H, W)\n        sv_kernel_feature   : torch.tensor\n                            Spatially varying kernel features.\n                            Dimension: (1, C_i * kernel_size * kernel_size, H, W)\n\n        Returns\n        -------\n        sa_output          : torch.tensor\n                            Estimated output tensor.\n                            Dimension: (1, output_channels, H_out, W_out)\n        \"\"\"\n        # Pad input and sv_kernel_feature if necessary\n        if sv_kernel_feature.size(-1) * self.stride != x.size(-1) or sv_kernel_feature.size(\n                -2) * self.stride != x.size(-2):\n            diffY = sv_kernel_feature.size(-2) % self.stride\n            diffX = sv_kernel_feature.size(-1) % self.stride\n            sv_kernel_feature = torch.nn.functional.pad(sv_kernel_feature, (diffX // 2, diffX - diffX // 2,\n                                                                            diffY // 2, diffY - diffY // 2))\n            diffY = x.size(-2) % self.stride\n            diffX = x.size(-1) % self.stride\n            x = torch.nn.functional.pad(x, (diffX // 2, diffX - diffX // 2,\n                                            diffY // 2, diffY - diffY // 2))\n\n        # Unfold the input tensor for matrix multiplication\n        input_feature = torch.nn.functional.unfold(\n                                                   x,\n                                                   kernel_size = (self.kernel_size, self.kernel_size),\n                                                   stride = self.stride,\n                                                   padding = self.padding\n                                                  )\n\n        # Resize sv_kernel_feature to match the input feature\n        sv_kernel = sv_kernel_feature.reshape(\n                                              1,\n                                              self.input_channels * self.kernel_size * self.kernel_size,\n                                              (x.size(-2) // self.stride) * (x.size(-1) // self.stride)\n                                             )\n\n        # Resize weight to match the input channels and kernel size\n        si_kernel = self.weight.reshape(\n                                        self.output_channels,\n                                        self.input_channels * self.kernel_size * self.kernel_size\n                                       )\n\n        # Apply spatially varying kernels\n        sv_feature = input_feature * sv_kernel\n\n        # Perform matrix multiplication\n        sa_output = torch.matmul(si_kernel, sv_feature).reshape(\n                                                                1, self.output_channels,\n                                                                (x.size(-2) // self.stride),\n                                                                (x.size(-1) // self.stride)\n                                                               )\n        return sa_output\n
"},{"location":"odak/learn_models/#odak.learn.models.components.spatially_adaptive_convolution.__init__","title":"__init__(input_channels=2, output_channels=2, kernel_size=3, stride=1, padding=1, bias=False, activation=torch.nn.LeakyReLU(0.2, inplace=True))","text":"

Initializes a spatially adaptive convolution layer.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • output_channels (int, default: 2 ) \u2013
              Number of output channels.\n
  • kernel_size \u2013
              Size of the convolution kernel.\n
  • stride \u2013
              Stride of the convolution.\n
  • padding \u2013
              Padding added to both sides of the input.\n
  • bias \u2013
              If True, includes a bias term in the convolution.\n
  • activation \u2013
              Activation function to apply. If None, no activation is applied.\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels = 2,\n             output_channels = 2,\n             kernel_size = 3,\n             stride = 1,\n             padding = 1,\n             bias = False,\n             activation = torch.nn.LeakyReLU(0.2, inplace = True)\n            ):\n    \"\"\"\n    Initializes a spatially adaptive convolution layer.\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    output_channels : int\n                      Number of output channels.\n    kernel_size     : int\n                      Size of the convolution kernel.\n    stride          : int\n                      Stride of the convolution.\n    padding         : int\n                      Padding added to both sides of the input.\n    bias            : bool\n                      If True, includes a bias term in the convolution.\n    activation      : torch.nn.Module\n                      Activation function to apply. If None, no activation is applied.\n    \"\"\"\n    super(spatially_adaptive_convolution, self).__init__()\n    self.kernel_size = kernel_size\n    self.input_channels = input_channels\n    self.output_channels = output_channels\n    self.stride = stride\n    self.padding = padding\n    self.standard_convolution = torch.nn.Conv2d(\n                                                in_channels = input_channels,\n                                                out_channels = self.output_channels,\n                                                kernel_size = kernel_size,\n                                                stride = stride,\n                                                padding = padding,\n                                                bias = bias\n                                               )\n    self.weight = torch.nn.Parameter(data = self.standard_convolution.weight, requires_grad = True)\n    self.activation = activation\n
"},{"location":"odak/learn_models/#odak.learn.models.components.spatially_adaptive_convolution.forward","title":"forward(x, sv_kernel_feature)","text":"

Forward pass for the spatially adaptive convolution layer.

Parameters:

  • x \u2013
                Input data tensor.\n            Dimension: (1, C, H, W)\n
  • sv_kernel_feature \u2013
                Spatially varying kernel features.\n            Dimension: (1, C_i * kernel_size * kernel_size, H, W)\n

Returns:

  • sa_output ( tensor ) \u2013

    Estimated output tensor. Dimension: (1, output_channels, H_out, W_out)

Source code in odak/learn/models/components.py
def forward(self, x, sv_kernel_feature):\n    \"\"\"\n    Forward pass for the spatially adaptive convolution layer.\n\n    Parameters\n    ----------\n    x                  : torch.tensor\n                        Input data tensor.\n                        Dimension: (1, C, H, W)\n    sv_kernel_feature   : torch.tensor\n                        Spatially varying kernel features.\n                        Dimension: (1, C_i * kernel_size * kernel_size, H, W)\n\n    Returns\n    -------\n    sa_output          : torch.tensor\n                        Estimated output tensor.\n                        Dimension: (1, output_channels, H_out, W_out)\n    \"\"\"\n    # Pad input and sv_kernel_feature if necessary\n    if sv_kernel_feature.size(-1) * self.stride != x.size(-1) or sv_kernel_feature.size(\n            -2) * self.stride != x.size(-2):\n        diffY = sv_kernel_feature.size(-2) % self.stride\n        diffX = sv_kernel_feature.size(-1) % self.stride\n        sv_kernel_feature = torch.nn.functional.pad(sv_kernel_feature, (diffX // 2, diffX - diffX // 2,\n                                                                        diffY // 2, diffY - diffY // 2))\n        diffY = x.size(-2) % self.stride\n        diffX = x.size(-1) % self.stride\n        x = torch.nn.functional.pad(x, (diffX // 2, diffX - diffX // 2,\n                                        diffY // 2, diffY - diffY // 2))\n\n    # Unfold the input tensor for matrix multiplication\n    input_feature = torch.nn.functional.unfold(\n                                               x,\n                                               kernel_size = (self.kernel_size, self.kernel_size),\n                                               stride = self.stride,\n                                               padding = self.padding\n                                              )\n\n    # Resize sv_kernel_feature to match the input feature\n    sv_kernel = sv_kernel_feature.reshape(\n                                          1,\n                                          self.input_channels * self.kernel_size * self.kernel_size,\n                                          (x.size(-2) // self.stride) * (x.size(-1) // self.stride)\n                                         )\n\n    # Resize weight to match the input channels and kernel size\n    si_kernel = self.weight.reshape(\n                                    self.output_channels,\n                                    self.input_channels * self.kernel_size * self.kernel_size\n                                   )\n\n    # Apply spatially varying kernels\n    sv_feature = input_feature * sv_kernel\n\n    # Perform matrix multiplication\n    sa_output = torch.matmul(si_kernel, sv_feature).reshape(\n                                                            1, self.output_channels,\n                                                            (x.size(-2) // self.stride),\n                                                            (x.size(-1) // self.stride)\n                                                           )\n    return sa_output\n
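A minimal usage sketch, added for illustration and not part of the generated reference. The import path is assumed from the location above, and sv_kernel_feature is random here; in practice it carries input_channels * kernel_size * kernel_size learned values per output location.

import torch
from odak.learn.models.components import spatially_adaptive_convolution  # import path assumed

conv = spatially_adaptive_convolution(input_channels = 2, output_channels = 2, kernel_size = 3, stride = 1, padding = 1)
x = torch.randn(1, 2, 32, 32)                          # (1, C, H, W)
sv_kernel_feature = torch.randn(1, 2 * 3 * 3, 32, 32)  # (1, C * kernel_size * kernel_size, H, W)
y = conv(x, sv_kernel_feature)                         # (1, output_channels, 32, 32)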
"},{"location":"odak/learn_models/#odak.learn.models.components.spatially_adaptive_module","title":"spatially_adaptive_module","text":"

Bases: Module

A spatially adaptive module that combines learned spatially adaptive convolutions.

References

Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Ak\u015fit, \"Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions,\" SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24), December, 2024.

Source code in odak/learn/models/components.py
class spatially_adaptive_module(torch.nn.Module):\n    \"\"\"\n    A spatially adaptive module that combines learned spatially adaptive convolutions.\n\n    References\n    ----------\n\n    Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Ak\u015fit, \"Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions,\" SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24), December, 2024.\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels = 2,\n                 output_channels = 2,\n                 kernel_size = 3,\n                 stride = 1,\n                 padding = 1,\n                 bias = False,\n                 activation = torch.nn.LeakyReLU(0.2, inplace = True)\n                ):\n        \"\"\"\n        Initializes a spatially adaptive module.\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        output_channels : int\n                          Number of output channels.\n        kernel_size     : int\n                          Size of the convolution kernel.\n        stride          : int\n                          Stride of the convolution.\n        padding         : int\n                          Padding added to both sides of the input.\n        bias            : bool\n                          If True, includes a bias term in the convolution.\n        activation      : torch.nn\n                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n        \"\"\"\n        super(spatially_adaptive_module, self).__init__()\n        self.kernel_size = kernel_size\n        self.input_channels = input_channels\n        self.output_channels = output_channels\n        self.stride = stride\n        self.padding = padding\n        self.weight_output_channels = self.output_channels - 1\n        self.standard_convolution = torch.nn.Conv2d(\n                                                    in_channels = input_channels,\n                                                    out_channels = self.weight_output_channels,\n                                                    kernel_size = kernel_size,\n                                                    stride = stride,\n                                                    padding = padding,\n                                                    bias = bias\n                                                   )\n        self.weight = torch.nn.Parameter(data = self.standard_convolution.weight, requires_grad = True)\n        self.activation = activation\n\n\n    def forward(self, x, sv_kernel_feature):\n        \"\"\"\n        Forward pass for the spatially adaptive module.\n\n        Parameters\n        ----------\n        x                  : torch.tensor\n                            Input data tensor.\n                            Dimension: (1, C, H, W)\n        sv_kernel_feature   : torch.tensor\n                            Spatially varying kernel features.\n                            Dimension: (1, C_i * kernel_size * kernel_size, H, W)\n\n        Returns\n        -------\n        output             : torch.tensor\n                            Combined output tensor from standard and spatially adaptive convolutions.\n                            Dimension: (1, output_channels, H_out, W_out)\n        \"\"\"\n        # Pad input and sv_kernel_feature if necessary\n        if sv_kernel_feature.size(-1) * self.stride != x.size(-1) or sv_kernel_feature.size(\n                -2) * self.stride != x.size(-2):\n            diffY = sv_kernel_feature.size(-2) % self.stride\n            diffX = sv_kernel_feature.size(-1) % self.stride\n            sv_kernel_feature = torch.nn.functional.pad(sv_kernel_feature, (diffX // 2, diffX - diffX // 2,\n                                                                            diffY // 2, diffY - diffY // 2))\n            diffY = x.size(-2) % self.stride\n            diffX = x.size(-1) % self.stride\n            x = torch.nn.functional.pad(x, (diffX // 2, diffX - diffX // 2,\n                                            diffY // 2, diffY - diffY // 2))\n\n        # Unfold the input tensor for matrix multiplication\n        input_feature = torch.nn.functional.unfold(\n                                                   x,\n                                                   kernel_size = (self.kernel_size, self.kernel_size),\n                                                   stride = self.stride,\n                                                   padding = self.padding\n                                                  )\n\n        # Resize sv_kernel_feature to match the input feature\n        sv_kernel = sv_kernel_feature.reshape(\n                                              1,\n                                              self.input_channels * self.kernel_size * self.kernel_size,\n                                              (x.size(-2) // self.stride) * (x.size(-1) // self.stride)\n                                             )\n\n        # Apply sv_kernel to the input_feature\n        sv_feature = input_feature * sv_kernel\n\n        # Original spatially varying convolution output\n        sv_output = torch.sum(sv_feature, dim = 1).reshape(\n                                                           1,\n                                                            1,\n                                                            (x.size(-2) // self.stride),\n                                                            (x.size(-1) // self.stride)\n                                                           )\n\n        # Reshape weight for spatially adaptive convolution\n        si_kernel = self.weight.reshape(\n                                        self.weight_output_channels,\n                                        self.input_channels * self.kernel_size * self.kernel_size\n                                       )\n\n        # Apply si_kernel on sv convolution output\n        sa_output = torch.matmul(si_kernel, sv_feature).reshape(\n                                                                1, self.weight_output_channels,\n                                                                (x.size(-2) // self.stride),\n                                                                (x.size(-1) // self.stride)\n                                                               )\n\n        # Combine the outputs and apply activation function\n        output = self.activation(torch.cat((sv_output, sa_output), dim = 1))\n        return output\n
"},{"location":"odak/learn_models/#odak.learn.models.components.spatially_adaptive_module.__init__","title":"__init__(input_channels=2, output_channels=2, kernel_size=3, stride=1, padding=1, bias=False, activation=torch.nn.LeakyReLU(0.2, inplace=True))","text":"

Initializes a spatially adaptive module.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • output_channels (int, default: 2 ) \u2013
              Number of output channels.\n
  • kernel_size \u2013
              Size of the convolution kernel.\n
  • stride \u2013
              Stride of the convolution.\n
  • padding \u2013
              Padding added to both sides of the input.\n
  • bias \u2013
              If True, includes a bias term in the convolution.\n
  • activation \u2013
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels = 2,\n             output_channels = 2,\n             kernel_size = 3,\n             stride = 1,\n             padding = 1,\n             bias = False,\n             activation = torch.nn.LeakyReLU(0.2, inplace = True)\n            ):\n    \"\"\"\n    Initializes a spatially adaptive module.\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    output_channels : int\n                      Number of output channels.\n    kernel_size     : int\n                      Size of the convolution kernel.\n    stride          : int\n                      Stride of the convolution.\n    padding         : int\n                      Padding added to both sides of the input.\n    bias            : bool\n                      If True, includes a bias term in the convolution.\n    activation      : torch.nn\n                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n    \"\"\"\n    super(spatially_adaptive_module, self).__init__()\n    self.kernel_size = kernel_size\n    self.input_channels = input_channels\n    self.output_channels = output_channels\n    self.stride = stride\n    self.padding = padding\n    self.weight_output_channels = self.output_channels - 1\n    self.standard_convolution = torch.nn.Conv2d(\n                                                in_channels = input_channels,\n                                                out_channels = self.weight_output_channels,\n                                                kernel_size = kernel_size,\n                                                stride = stride,\n                                                padding = padding,\n                                                bias = bias\n                                               )\n    self.weight = torch.nn.Parameter(data = self.standard_convolution.weight, requires_grad = True)\n    self.activation = activation\n
"},{"location":"odak/learn_models/#odak.learn.models.components.spatially_adaptive_module.forward","title":"forward(x, sv_kernel_feature)","text":"

Forward pass for the spatially adaptive module.

Parameters:

  • x \u2013
                Input data tensor.\n            Dimension: (1, C, H, W)\n
  • sv_kernel_feature \u2013
                Spatially varying kernel features.\n            Dimension: (1, C_i * kernel_size * kernel_size, H, W)\n

Returns:

  • output ( tensor ) \u2013

    Combined output tensor from standard and spatially adaptive convolutions. Dimension: (1, output_channels, H_out, W_out)

Source code in odak/learn/models/components.py
def forward(self, x, sv_kernel_feature):\n    \"\"\"\n    Forward pass for the spatially adaptive module.\n\n    Parameters\n    ----------\n    x                  : torch.tensor\n                        Input data tensor.\n                        Dimension: (1, C, H, W)\n    sv_kernel_feature   : torch.tensor\n                        Spatially varying kernel features.\n                        Dimension: (1, C_i * kernel_size * kernel_size, H, W)\n\n    Returns\n    -------\n    output             : torch.tensor\n                        Combined output tensor from standard and spatially adaptive convolutions.\n                        Dimension: (1, output_channels, H_out, W_out)\n    \"\"\"\n    # Pad input and sv_kernel_feature if necessary\n    if sv_kernel_feature.size(-1) * self.stride != x.size(-1) or sv_kernel_feature.size(\n            -2) * self.stride != x.size(-2):\n        diffY = sv_kernel_feature.size(-2) % self.stride\n        diffX = sv_kernel_feature.size(-1) % self.stride\n        sv_kernel_feature = torch.nn.functional.pad(sv_kernel_feature, (diffX // 2, diffX - diffX // 2,\n                                                                        diffY // 2, diffY - diffY // 2))\n        diffY = x.size(-2) % self.stride\n        diffX = x.size(-1) % self.stride\n        x = torch.nn.functional.pad(x, (diffX // 2, diffX - diffX // 2,\n                                        diffY // 2, diffY - diffY // 2))\n\n    # Unfold the input tensor for matrix multiplication\n    input_feature = torch.nn.functional.unfold(\n                                               x,\n                                               kernel_size = (self.kernel_size, self.kernel_size),\n                                               stride = self.stride,\n                                               padding = self.padding\n                                              )\n\n    # Resize sv_kernel_feature to match the input feature\n    sv_kernel = sv_kernel_feature.reshape(\n                                          1,\n                                          self.input_channels * self.kernel_size * self.kernel_size,\n                                          (x.size(-2) // self.stride) * (x.size(-1) // self.stride)\n                                         )\n\n    # Apply sv_kernel to the input_feature\n    sv_feature = input_feature * sv_kernel\n\n    # Original spatially varying convolution output\n    sv_output = torch.sum(sv_feature, dim = 1).reshape(\n                                                       1,\n                                                        1,\n                                                        (x.size(-2) // self.stride),\n                                                        (x.size(-1) // self.stride)\n                                                       )\n\n    # Reshape weight for spatially adaptive convolution\n    si_kernel = self.weight.reshape(\n                                    self.weight_output_channels,\n                                    self.input_channels * self.kernel_size * self.kernel_size\n                                   )\n\n    # Apply si_kernel on sv convolution output\n    sa_output = torch.matmul(si_kernel, sv_feature).reshape(\n                                                            1, self.weight_output_channels,\n                                                            (x.size(-2) // self.stride),\n                                                            (x.size(-1) // self.stride)\n                                                           )\n\n    # Combine the outputs and apply activation function\n    output = self.activation(torch.cat((sv_output, sa_output), dim = 1))\n    return output\n
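A minimal usage sketch, added for illustration and not part of the generated reference; the import path is assumed from the location above and the inputs are random placeholders.

import torch
from odak.learn.models.components import spatially_adaptive_module  # import path assumed

module = spatially_adaptive_module(input_channels = 2, output_channels = 2, kernel_size = 3, stride = 1, padding = 1)
x = torch.randn(1, 2, 32, 32)
sv_kernel_feature = torch.randn(1, 2 * 3 * 3, 32, 32)
y = module(x, sv_kernel_feature)  # (1, 2, 32, 32): one spatially varying channel + (output_channels - 1) adaptive channels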
"},{"location":"odak/learn_models/#odak.learn.models.components.upsample_convtranspose2d_layer","title":"upsample_convtranspose2d_layer","text":"

Bases: Module

An upsampling convtranspose2d layer.

Source code in odak/learn/models/components.py
class upsample_convtranspose2d_layer(torch.nn.Module):\n    \"\"\"\n    An upsampling convtranspose2d layer.\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels,\n                 output_channels,\n                 kernel_size = 2,\n                 stride = 2,\n                 bias = False,\n                ):\n        \"\"\"\n        A downscaling component with a double convolution.\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        output_channels : int\n                          Number of output channels.\n        kernel_size     : int\n                          Kernel size.\n        bias            : bool\n                          Set to True to let convolutional layers have bias term.\n        \"\"\"\n        super().__init__()\n        self.up = torch.nn.ConvTranspose2d(\n                                           in_channels = input_channels,\n                                           out_channels = output_channels,\n                                           bias = bias,\n                                           kernel_size = kernel_size,\n                                           stride = stride\n                                          )\n\n    def forward(self, x1, x2):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x1             : torch.tensor\n                         First input data.\n        x2             : torch.tensor\n                         Second input data.\n\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Result of the forward operation\n        \"\"\"\n        x1 = self.up(x1)\n        diffY = x2.size()[2] - x1.size()[2]\n        diffX = x2.size()[3] - x1.size()[3]\n        x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,\n                                          diffY // 2, diffY - diffY // 2])\n        result = x1 + x2\n        return result\n
"},{"location":"odak/learn_models/#odak.learn.models.components.upsample_convtranspose2d_layer.__init__","title":"__init__(input_channels, output_channels, kernel_size=2, stride=2, bias=False)","text":"

An upsampling component based on a transposed convolution.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • output_channels (int) \u2013
              Number of output channels.\n
  • kernel_size \u2013
              Kernel size.\n
  • bias \u2013
              Set to True to let convolutional layers have bias term.\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels,\n             output_channels,\n             kernel_size = 2,\n             stride = 2,\n             bias = False,\n            ):\n    \"\"\"\n    A downscaling component with a double convolution.\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    output_channels : int\n                      Number of output channels.\n    kernel_size     : int\n                      Kernel size.\n    bias            : bool\n                      Set to True to let convolutional layers have bias term.\n    \"\"\"\n    super().__init__()\n    self.up = torch.nn.ConvTranspose2d(\n                                       in_channels = input_channels,\n                                       out_channels = output_channels,\n                                       bias = bias,\n                                       kernel_size = kernel_size,\n                                       stride = stride\n                                      )\n
"},{"location":"odak/learn_models/#odak.learn.models.components.upsample_convtranspose2d_layer.forward","title":"forward(x1, x2)","text":"

Forward model.

Parameters:

  • x1 \u2013
             First input data.\n
  • x2 \u2013
             Second input data.\n

Returns:

  • result ( tensor ) \u2013

    Result of the forward operation

Source code in odak/learn/models/components.py
def forward(self, x1, x2):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x1             : torch.tensor\n                     First input data.\n    x2             : torch.tensor\n                     Second input data.\n\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Result of the forward operation\n    \"\"\"\n    x1 = self.up(x1)\n    diffY = x2.size()[2] - x1.size()[2]\n    diffX = x2.size()[3] - x1.size()[3]\n    x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,\n                                      diffY // 2, diffY - diffY // 2])\n    result = x1 + x2\n    return result\n
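A minimal usage sketch, added for illustration; the import path is assumed from the location above. Note that x2 acts as a skip connection and is expected to already have output_channels channels, since the upsampled x1 is added to it.

import torch
from odak.learn.models.components import upsample_convtranspose2d_layer  # import path assumed

up = upsample_convtranspose2d_layer(input_channels = 4, output_channels = 2)
x1 = torch.randn(1, 4, 16, 16)  # low-resolution features to upsample
x2 = torch.randn(1, 2, 32, 32)  # skip connection at the target resolution
y = up(x1, x2)                  # (1, 2, 32, 32)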
"},{"location":"odak/learn_models/#odak.learn.models.components.upsample_layer","title":"upsample_layer","text":"

Bases: Module

An upsampling convolutional layer.

Source code in odak/learn/models/components.py
class upsample_layer(torch.nn.Module):\n    \"\"\"\n    An upsampling convolutional layer.\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels,\n                 output_channels,\n                 kernel_size = 3,\n                 bias = False,\n                 activation = torch.nn.ReLU(),\n                 bilinear = True\n                ):\n        \"\"\"\n        A downscaling component with a double convolution.\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        output_channels : int\n                          Number of output channels.\n        kernel_size     : int\n                          Kernel size.\n        bias            : bool \n                          Set to True to let convolutional layers have bias term.\n        activation      : torch.nn\n                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n        bilinear        : bool\n                          If set to True, bilinear sampling is used.\n        \"\"\"\n        super(upsample_layer, self).__init__()\n        if bilinear:\n            self.up = torch.nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = True)\n            self.conv = double_convolution(\n                                           input_channels = input_channels + output_channels,\n                                           mid_channels = input_channels // 2,\n                                           output_channels = output_channels,\n                                           kernel_size = kernel_size,\n                                           bias = bias,\n                                           activation = activation\n                                          )\n        else:\n            self.up = torch.nn.ConvTranspose2d(input_channels , input_channels // 2, kernel_size = 2, stride = 2)\n            self.conv = double_convolution(\n                                           input_channels = input_channels,\n                                           mid_channels = output_channels,\n                                           output_channels = output_channels,\n                                           kernel_size = kernel_size,\n                                           bias = bias,\n                                           activation = activation\n                                          )\n\n\n    def forward(self, x1, x2):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x1             : torch.tensor\n                         First input data.\n        x2             : torch.tensor\n                         Second input data.\n\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Result of the forward operation\n        \"\"\" \n        x1 = self.up(x1)\n        diffY = x2.size()[2] - x1.size()[2]\n        diffX = x2.size()[3] - x1.size()[3]\n        x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,\n                                          diffY // 2, diffY - diffY // 2])\n        x = torch.cat([x2, x1], dim = 1)\n        result = self.conv(x)\n        return result\n
"},{"location":"odak/learn_models/#odak.learn.models.components.upsample_layer.__init__","title":"__init__(input_channels, output_channels, kernel_size=3, bias=False, activation=torch.nn.ReLU(), bilinear=True)","text":"

An upsampling component with a double convolution.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • output_channels (int) \u2013
              Number of output channels.\n
  • kernel_size \u2013
              Kernel size.\n
  • bias \u2013
              Set to True to let convolutional layers have bias term.\n
  • activation \u2013
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n
  • bilinear \u2013
              If set to True, bilinear sampling is used.\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels,\n             output_channels,\n             kernel_size = 3,\n             bias = False,\n             activation = torch.nn.ReLU(),\n             bilinear = True\n            ):\n    \"\"\"\n    A downscaling component with a double convolution.\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    output_channels : int\n                      Number of output channels.\n    kernel_size     : int\n                      Kernel size.\n    bias            : bool \n                      Set to True to let convolutional layers have bias term.\n    activation      : torch.nn\n                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n    bilinear        : bool\n                      If set to True, bilinear sampling is used.\n    \"\"\"\n    super(upsample_layer, self).__init__()\n    if bilinear:\n        self.up = torch.nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = True)\n        self.conv = double_convolution(\n                                       input_channels = input_channels + output_channels,\n                                       mid_channels = input_channels // 2,\n                                       output_channels = output_channels,\n                                       kernel_size = kernel_size,\n                                       bias = bias,\n                                       activation = activation\n                                      )\n    else:\n        self.up = torch.nn.ConvTranspose2d(input_channels , input_channels // 2, kernel_size = 2, stride = 2)\n        self.conv = double_convolution(\n                                       input_channels = input_channels,\n                                       mid_channels = output_channels,\n                                       output_channels = output_channels,\n                                       kernel_size = kernel_size,\n                                       bias = bias,\n                                       activation = activation\n                                      )\n
"},{"location":"odak/learn_models/#odak.learn.models.components.upsample_layer.forward","title":"forward(x1, x2)","text":"

Forward model.

Parameters:

  • x1 \u2013
             First input data.\n
  • x2 \u2013
             Second input data.\n

Returns:

  • result ( tensor ) \u2013

    Result of the forward operation

Source code in odak/learn/models/components.py
def forward(self, x1, x2):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x1             : torch.tensor\n                     First input data.\n    x2             : torch.tensor\n                     Second input data.\n\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Result of the forward operation\n    \"\"\" \n    x1 = self.up(x1)\n    diffY = x2.size()[2] - x1.size()[2]\n    diffX = x2.size()[3] - x1.size()[3]\n    x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,\n                                      diffY // 2, diffY - diffY // 2])\n    x = torch.cat([x2, x1], dim = 1)\n    result = self.conv(x)\n    return result\n
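A minimal usage sketch, added for illustration; the import path is assumed from the location above. With bilinear set to True, the skip tensor x2 is concatenated with the upsampled x1 before the double convolution, so it is expected to carry output_channels channels.

import torch
from odak.learn.models.components import upsample_layer  # import path assumed

up = upsample_layer(input_channels = 4, output_channels = 2, kernel_size = 3, bilinear = True)
x1 = torch.randn(1, 4, 16, 16)  # features to upsample
x2 = torch.randn(1, 2, 32, 32)  # skip connection at the target resolution
y = up(x1, x2)                  # (1, 2, 32, 32)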
"},{"location":"odak/learn_models/#odak.learn.models.components.gaussian","title":"gaussian(x, multiplier=1.0)","text":"

A Gaussian non-linear activation. For more details: Ramasinghe, Sameera, and Simon Lucey. \"Beyond periodicity: Towards a unifying framework for activations in coordinate-mlps.\" In European Conference on Computer Vision, pp. 142-158. Cham: Springer Nature Switzerland, 2022.

Parameters:

  • x \u2013
           Input data.\n
  • multiplier \u2013
           Multiplier.\n

Returns:

  • result ( float or tensor ) \u2013

    Output data.

Source code in odak/learn/models/components.py
def gaussian(x, multiplier = 1.):\n    \"\"\"\n    A Gaussian non-linear activation.\n    For more details: Ramasinghe, Sameera, and Simon Lucey. \"Beyond periodicity: Towards a unifying framework for activations in coordinate-mlps.\" In European Conference on Computer Vision, pp. 142-158. Cham: Springer Nature Switzerland, 2022.\n\n    Parameters\n    ----------\n    x            : float or torch.tensor\n                   Input data.\n    multiplier   : float or torch.tensor\n                   Multiplier.\n\n    Returns\n    -------\n    result       : float or torch.tensor\n                   Output data.\n    \"\"\"\n    result = torch.exp(- (multiplier * x) ** 2)\n    return result\n
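A minimal usage sketch for the gaussian activation, added for illustration; the import path is assumed from the location above.

import torch
from odak.learn.models.components import gaussian  # import path assumed

x = torch.linspace(-3.0, 3.0, 7)
y = gaussian(x, multiplier = 1.0)  # equals torch.exp(-(1.0 * x) ** 2), peaking at 1.0 for x = 0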
"},{"location":"odak/learn_models/#odak.learn.models.components.swish","title":"swish(x)","text":"

A swish non-linear activation. For more details: https://en.wikipedia.org/wiki/Swish_function

Parameters:

  • x \u2013
             Input.\n

Returns:

  • out ( float or tensor ) \u2013

    Output.

Source code in odak/learn/models/components.py
def swish(x):\n    \"\"\"\n    A swish non-linear activation.\n    For more details: https://en.wikipedia.org/wiki/Swish_function\n\n    Parameters\n    -----------\n    x              : float or torch.tensor\n                     Input.\n\n    Returns\n    -------\n    out            : float or torch.tensor\n                     Output.\n    \"\"\"\n    out = x * torch.sigmoid(x)\n    return out\n
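A minimal usage sketch for the swish activation, added for illustration; the import path is assumed from the location above.

import torch
from odak.learn.models.components import swish  # import path assumed

x = torch.tensor([-2.0, 0.0, 2.0])
y = swish(x)  # x * sigmoid(x), approximately tensor([-0.2384, 0.0000, 1.7616])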
"},{"location":"odak/learn_models/#odak.learn.models.models.channel_gate","title":"channel_gate","text":"

Bases: Module

Channel attention module with various pooling strategies. This class is heavily inspired by https://github.com/Jongchan/attention-module/commit/e4ee180f1335c09db14d39a65d97c8ca3d1f7b16 (MIT License).

Source code in odak/learn/models/components.py
class channel_gate(torch.nn.Module):\n    \"\"\"\n    Channel attention module with various pooling strategies.\n    This class is heavily inspired https://github.com/Jongchan/attention-module/commit/e4ee180f1335c09db14d39a65d97c8ca3d1f7b16 (MIT License).\n    \"\"\"\n    def __init__(\n                 self, \n                 gate_channels, \n                 reduction_ratio = 16, \n                 pool_types = ['avg', 'max']\n                ):\n        \"\"\"\n        Initializes the channel gate module.\n\n        Parameters\n        ----------\n        gate_channels   : int\n                          Number of channels of the input feature map.\n        reduction_ratio : int\n                          Reduction ratio for the intermediate layer.\n        pool_types      : list\n                          List of pooling operations to apply.\n        \"\"\"\n        super().__init__()\n        self.gate_channels = gate_channels\n        hidden_channels = gate_channels // reduction_ratio\n        if hidden_channels == 0:\n            hidden_channels = 1\n        self.mlp = torch.nn.Sequential(\n                                       convolutional_block_attention.Flatten(),\n                                       torch.nn.Linear(gate_channels, hidden_channels),\n                                       torch.nn.ReLU(),\n                                       torch.nn.Linear(hidden_channels, gate_channels)\n                                      )\n        self.pool_types = pool_types\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward pass of the ChannelGate module.\n\n        Applies channel-wise attention to the input tensor.\n\n        Parameters\n        ----------\n        x            : torch.tensor\n                       Input tensor to the ChannelGate module.\n\n        Returns\n        -------\n        output       : torch.tensor\n                       Output tensor after applying channel attention.\n        \"\"\"\n        channel_att_sum = None\n        for pool_type in self.pool_types:\n            if pool_type == 'avg':\n                pool = torch.nn.functional.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))\n            elif pool_type == 'max':\n                pool = torch.nn.functional.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))\n            channel_att_raw = self.mlp(pool)\n            channel_att_sum = channel_att_raw if channel_att_sum is None else channel_att_sum + channel_att_raw\n        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)\n        output = x * scale\n        return output\n
"},{"location":"odak/learn_models/#odak.learn.models.models.channel_gate.__init__","title":"__init__(gate_channels, reduction_ratio=16, pool_types=['avg', 'max'])","text":"

Initializes the channel gate module.

Parameters:

  • gate_channels \u2013
              Number of channels of the input feature map.\n
  • reduction_ratio (int, default: 16 ) \u2013
              Reduction ratio for the intermediate layer.\n
  • pool_types \u2013
              List of pooling operations to apply.\n
Source code in odak/learn/models/components.py
def __init__(\n             self, \n             gate_channels, \n             reduction_ratio = 16, \n             pool_types = ['avg', 'max']\n            ):\n    \"\"\"\n    Initializes the channel gate module.\n\n    Parameters\n    ----------\n    gate_channels   : int\n                      Number of channels of the input feature map.\n    reduction_ratio : int\n                      Reduction ratio for the intermediate layer.\n    pool_types      : list\n                      List of pooling operations to apply.\n    \"\"\"\n    super().__init__()\n    self.gate_channels = gate_channels\n    hidden_channels = gate_channels // reduction_ratio\n    if hidden_channels == 0:\n        hidden_channels = 1\n    self.mlp = torch.nn.Sequential(\n                                   convolutional_block_attention.Flatten(),\n                                   torch.nn.Linear(gate_channels, hidden_channels),\n                                   torch.nn.ReLU(),\n                                   torch.nn.Linear(hidden_channels, gate_channels)\n                                  )\n    self.pool_types = pool_types\n
"},{"location":"odak/learn_models/#odak.learn.models.models.channel_gate.forward","title":"forward(x)","text":"

Forward pass of the ChannelGate module.

Applies channel-wise attention to the input tensor.

Parameters:

  • x \u2013
           Input tensor to the ChannelGate module.\n

Returns:

  • output ( tensor ) \u2013

    Output tensor after applying channel attention.

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward pass of the ChannelGate module.\n\n    Applies channel-wise attention to the input tensor.\n\n    Parameters\n    ----------\n    x            : torch.tensor\n                   Input tensor to the ChannelGate module.\n\n    Returns\n    -------\n    output       : torch.tensor\n                   Output tensor after applying channel attention.\n    \"\"\"\n    channel_att_sum = None\n    for pool_type in self.pool_types:\n        if pool_type == 'avg':\n            pool = torch.nn.functional.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))\n        elif pool_type == 'max':\n            pool = torch.nn.functional.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))\n        channel_att_raw = self.mlp(pool)\n        channel_att_sum = channel_att_raw if channel_att_sum is None else channel_att_sum + channel_att_raw\n    scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)\n    output = x * scale\n    return output\n
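A minimal usage sketch, added for illustration and not part of the generated reference. The class is defined in odak/learn/models/components.py per the note above, so the import path below is an assumption; the shapes are illustrative.

import torch
from odak.learn.models.components import channel_gate  # import path assumed

gate = channel_gate(gate_channels = 8, reduction_ratio = 4, pool_types = ['avg', 'max'])
x = torch.randn(2, 8, 32, 32)
y = gate(x)  # (2, 8, 32, 32); each channel is rescaled by a learned attention weight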
"},{"location":"odak/learn_models/#odak.learn.models.models.convolution_layer","title":"convolution_layer","text":"

Bases: Module

A convolution layer.

Source code in odak/learn/models/components.py
class convolution_layer(torch.nn.Module):\n    \"\"\"\n    A convolution layer.\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels = 2,\n                 output_channels = 2,\n                 kernel_size = 3,\n                 bias = False,\n                 stride = 1,\n                 normalization = True,\n                 activation = torch.nn.ReLU()\n                ):\n        \"\"\"\n        A convolutional layer class.\n\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        output_channels : int\n                          Number of output channels.\n        kernel_size     : int\n                          Kernel size.\n        bias            : bool\n                          Set to True to let convolutional layers have bias term.\n        normalization   : bool\n                          If True, adds a Batch Normalization layer after the convolutional layer.\n        activation      : torch.nn\n                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n        \"\"\"\n        super().__init__()\n        layers = [\n            torch.nn.Conv2d(\n                            input_channels,\n                            output_channels,\n                            kernel_size = kernel_size,\n                            stride = stride,\n                            padding = kernel_size // 2,\n                            bias = bias\n                           )\n        ]\n        if normalization:\n            layers.append(torch.nn.BatchNorm2d(output_channels))\n        if activation:\n            layers.append(activation)\n        self.model = torch.nn.Sequential(*layers)\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x             : torch.tensor\n                        Input data.\n\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Estimated output.\n        \"\"\"\n        result = self.model(x)\n        return result\n
"},{"location":"odak/learn_models/#odak.learn.models.models.convolution_layer.__init__","title":"__init__(input_channels=2, output_channels=2, kernel_size=3, bias=False, stride=1, normalization=True, activation=torch.nn.ReLU())","text":"

A convolutional layer class.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • output_channels (int, default: 2 ) \u2013
              Number of output channels.\n
  • kernel_size \u2013
              Kernel size.\n
  • bias \u2013
              Set to True to let convolutional layers have a bias term.\n
  • normalization \u2013
              If True, adds a Batch Normalization layer after the convolutional layer.\n
  • activation \u2013
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels = 2,\n             output_channels = 2,\n             kernel_size = 3,\n             bias = False,\n             stride = 1,\n             normalization = True,\n             activation = torch.nn.ReLU()\n            ):\n    \"\"\"\n    A convolutional layer class.\n\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    output_channels : int\n                      Number of output channels.\n    kernel_size     : int\n                      Kernel size.\n    bias            : bool\n                      Set to True to let convolutional layers have bias term.\n    normalization   : bool\n                      If True, adds a Batch Normalization layer after the convolutional layer.\n    activation      : torch.nn\n                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n    \"\"\"\n    super().__init__()\n    layers = [\n        torch.nn.Conv2d(\n                        input_channels,\n                        output_channels,\n                        kernel_size = kernel_size,\n                        stride = stride,\n                        padding = kernel_size // 2,\n                        bias = bias\n                       )\n    ]\n    if normalization:\n        layers.append(torch.nn.BatchNorm2d(output_channels))\n    if activation:\n        layers.append(activation)\n    self.model = torch.nn.Sequential(*layers)\n
"},{"location":"odak/learn_models/#odak.learn.models.models.convolution_layer.forward","title":"forward(x)","text":"

Forward model.

Parameters:

  • x \u2013
            Input data.\n

Returns:

  • result ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x             : torch.tensor\n                    Input data.\n\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Estimated output.\n    \"\"\"\n    result = self.model(x)\n    return result\n
"},{"location":"odak/learn_models/#odak.learn.models.models.convolutional_block_attention","title":"convolutional_block_attention","text":"

Bases: Module

Convolutional Block Attention Module (CBAM) class. This class is heavily inspired by https://github.com/Jongchan/attention-module/commit/e4ee180f1335c09db14d39a65d97c8ca3d1f7b16 (MIT License).

Source code in odak/learn/models/components.py
class convolutional_block_attention(torch.nn.Module):\n    \"\"\"\n    Convolutional Block Attention Module (CBAM) class. \n    This class is heavily inspired https://github.com/Jongchan/attention-module/commit/e4ee180f1335c09db14d39a65d97c8ca3d1f7b16 (MIT License).\n    \"\"\"\n    def __init__(\n                 self, \n                 gate_channels, \n                 reduction_ratio = 16, \n                 pool_types = ['avg', 'max'], \n                 no_spatial = False\n                ):\n        \"\"\"\n        Initializes the convolutional block attention module.\n\n        Parameters\n        ----------\n        gate_channels   : int\n                          Number of channels of the input feature map.\n        reduction_ratio : int\n                          Reduction ratio for the channel attention.\n        pool_types      : list\n                          List of pooling operations to apply for channel attention.\n        no_spatial      : bool\n                          If True, spatial attention is not applied.\n        \"\"\"\n        super(convolutional_block_attention, self).__init__()\n        self.channel_gate = channel_gate(gate_channels, reduction_ratio, pool_types)\n        self.no_spatial = no_spatial\n        if not no_spatial:\n            self.spatial_gate = spatial_gate()\n\n\n    class Flatten(torch.nn.Module):\n        \"\"\"\n        Flattens the input tensor to a 2D matrix.\n        \"\"\"\n        def forward(self, x):\n            return x.view(x.size(0), -1)\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward pass of the convolutional block attention module.\n\n        Parameters\n        ----------\n        x            : torch.tensor\n                       Input tensor to the CBAM module.\n\n        Returns\n        -------\n        x_out        : torch.tensor\n                       Output tensor after applying channel and spatial attention.\n        \"\"\"\n        x_out = self.channel_gate(x)\n        if not self.no_spatial:\n            x_out = self.spatial_gate(x_out)\n        return x_out\n
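A minimal sketch of applying the CBAM block above to a feature map, assuming the class can be imported from odak.learn.models.components; the channel counts and batch size are illustrative only. Note that gate_channels must match the channel count of the input tensor.

import torch
from odak.learn.models.components import convolutional_block_attention  # import path assumed

cbam = convolutional_block_attention(
                                     gate_channels = 16,
                                     reduction_ratio = 16,
                                     pool_types = ['avg', 'max'],
                                     no_spatial = False
                                    )
x = torch.randn(2, 16, 32, 32)
y = cbam(x)               # channel attention first, then spatial attention
assert y.shape == x.shape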
"},{"location":"odak/learn_models/#odak.learn.models.models.convolutional_block_attention.Flatten","title":"Flatten","text":"

Bases: Module

Flattens the input tensor to a 2D matrix.

Source code in odak/learn/models/components.py
class Flatten(torch.nn.Module):\n    \"\"\"\n    Flattens the input tensor to a 2D matrix.\n    \"\"\"\n    def forward(self, x):\n        return x.view(x.size(0), -1)\n
"},{"location":"odak/learn_models/#odak.learn.models.models.convolutional_block_attention.__init__","title":"__init__(gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False)","text":"

Initializes the convolutional block attention module.

Parameters:

  • gate_channels \u2013
              Number of channels of the input feature map.\n
  • reduction_ratio (int, default: 16 ) \u2013
              Reduction ratio for the channel attention.\n
  • pool_types \u2013
              List of pooling operations to apply for channel attention.\n
  • no_spatial \u2013
              If True, spatial attention is not applied.\n
Source code in odak/learn/models/components.py
def __init__(\n             self, \n             gate_channels, \n             reduction_ratio = 16, \n             pool_types = ['avg', 'max'], \n             no_spatial = False\n            ):\n    \"\"\"\n    Initializes the convolutional block attention module.\n\n    Parameters\n    ----------\n    gate_channels   : int\n                      Number of channels of the input feature map.\n    reduction_ratio : int\n                      Reduction ratio for the channel attention.\n    pool_types      : list\n                      List of pooling operations to apply for channel attention.\n    no_spatial      : bool\n                      If True, spatial attention is not applied.\n    \"\"\"\n    super(convolutional_block_attention, self).__init__()\n    self.channel_gate = channel_gate(gate_channels, reduction_ratio, pool_types)\n    self.no_spatial = no_spatial\n    if not no_spatial:\n        self.spatial_gate = spatial_gate()\n
"},{"location":"odak/learn_models/#odak.learn.models.models.convolutional_block_attention.forward","title":"forward(x)","text":"

Forward pass of the convolutional block attention module.

Parameters:

  • x \u2013
           Input tensor to the CBAM module.\n

Returns:

  • x_out ( tensor ) \u2013

    Output tensor after applying channel and spatial attention.

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward pass of the convolutional block attention module.\n\n    Parameters\n    ----------\n    x            : torch.tensor\n                   Input tensor to the CBAM module.\n\n    Returns\n    -------\n    x_out        : torch.tensor\n                   Output tensor after applying channel and spatial attention.\n    \"\"\"\n    x_out = self.channel_gate(x)\n    if not self.no_spatial:\n        x_out = self.spatial_gate(x_out)\n    return x_out\n
"},{"location":"odak/learn_models/#odak.learn.models.models.double_convolution","title":"double_convolution","text":"

Bases: Module

A double convolution layer.

Source code in odak/learn/models/components.py
class double_convolution(torch.nn.Module):\n    \"\"\"\n    A double convolution layer.\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels = 2,\n                 mid_channels = None,\n                 output_channels = 2,\n                 kernel_size = 3, \n                 bias = False,\n                 normalization = True,\n                 activation = torch.nn.ReLU()\n                ):\n        \"\"\"\n        Double convolution model.\n\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        mid_channels    : int\n                          Number of channels in the hidden layer between two convolutions.\n        output_channels : int\n                          Number of output channels.\n        kernel_size     : int\n                          Kernel size.\n        bias            : bool \n                          Set to True to let convolutional layers have bias term.\n        normalization   : bool\n                          If True, adds a Batch Normalization layer after the convolutional layer.\n        activation      : torch.nn\n                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n        \"\"\"\n        super().__init__()\n        if isinstance(mid_channels, type(None)):\n            mid_channels = output_channels\n        self.activation = activation\n        self.model = torch.nn.Sequential(\n                                         convolution_layer(\n                                                           input_channels = input_channels,\n                                                           output_channels = mid_channels,\n                                                           kernel_size = kernel_size,\n                                                           bias = bias,\n                                                           normalization = normalization,\n                                                           activation = self.activation\n                                                          ),\n                                         convolution_layer(\n                                                           input_channels = mid_channels,\n                                                           output_channels = output_channels,\n                                                           kernel_size = kernel_size,\n                                                           bias = bias,\n                                                           normalization = normalization,\n                                                           activation = self.activation\n                                                          )\n                                        )\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x             : torch.tensor\n                        Input data.\n\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Estimated output.      \n        \"\"\"\n        result = self.model(x)\n        return result\n
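A minimal usage sketch for double_convolution under the same import assumption as above, with arbitrary channel counts:

import torch
from odak.learn.models.components import double_convolution  # import path assumed

block = double_convolution(
                           input_channels = 4,
                           mid_channels = 16,
                           output_channels = 8,
                           kernel_size = 3
                          )
x = torch.randn(1, 4, 64, 64)
y = block(x)
print(y.shape)    # torch.Size([1, 8, 64, 64]); leaving mid_channels = None would use the output channel count (8) in the hidden layer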
"},{"location":"odak/learn_models/#odak.learn.models.models.double_convolution.__init__","title":"__init__(input_channels=2, mid_channels=None, output_channels=2, kernel_size=3, bias=False, normalization=True, activation=torch.nn.ReLU())","text":"

Double convolution model.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • mid_channels \u2013
              Number of channels in the hidden layer between two convolutions.\n
  • output_channels (int, default: 2 ) \u2013
              Number of output channels.\n
  • kernel_size \u2013
              Kernel size.\n
  • bias \u2013
              Set to True to let convolutional layers have a bias term.\n
  • normalization \u2013
              If True, adds a Batch Normalization layer after the convolutional layer.\n
  • activation \u2013
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels = 2,\n             mid_channels = None,\n             output_channels = 2,\n             kernel_size = 3, \n             bias = False,\n             normalization = True,\n             activation = torch.nn.ReLU()\n            ):\n    \"\"\"\n    Double convolution model.\n\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    mid_channels    : int\n                      Number of channels in the hidden layer between two convolutions.\n    output_channels : int\n                      Number of output channels.\n    kernel_size     : int\n                      Kernel size.\n    bias            : bool \n                      Set to True to let convolutional layers have bias term.\n    normalization   : bool\n                      If True, adds a Batch Normalization layer after the convolutional layer.\n    activation      : torch.nn\n                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n    \"\"\"\n    super().__init__()\n    if isinstance(mid_channels, type(None)):\n        mid_channels = output_channels\n    self.activation = activation\n    self.model = torch.nn.Sequential(\n                                     convolution_layer(\n                                                       input_channels = input_channels,\n                                                       output_channels = mid_channels,\n                                                       kernel_size = kernel_size,\n                                                       bias = bias,\n                                                       normalization = normalization,\n                                                       activation = self.activation\n                                                      ),\n                                     convolution_layer(\n                                                       input_channels = mid_channels,\n                                                       output_channels = output_channels,\n                                                       kernel_size = kernel_size,\n                                                       bias = bias,\n                                                       normalization = normalization,\n                                                       activation = self.activation\n                                                      )\n                                    )\n
"},{"location":"odak/learn_models/#odak.learn.models.models.double_convolution.forward","title":"forward(x)","text":"

Forward model.

Parameters:

  • x \u2013
            Input data.\n

Returns:

  • result ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x             : torch.tensor\n                    Input data.\n\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Estimated output.      \n    \"\"\"\n    result = self.model(x)\n    return result\n
"},{"location":"odak/learn_models/#odak.learn.models.models.downsample_layer","title":"downsample_layer","text":"

Bases: Module

A downscaling component followed by a double convolution.

Source code in odak/learn/models/components.py
class downsample_layer(torch.nn.Module):\n    \"\"\"\n    A downscaling component followed by a double convolution.\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels,\n                 output_channels,\n                 kernel_size = 3,\n                 bias = False,\n                 activation = torch.nn.ReLU()\n                ):\n        \"\"\"\n        A downscaling component with a double convolution.\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        output_channels : int\n                          Number of output channels.\n        kernel_size     : int\n                          Kernel size.\n        bias            : bool \n                          Set to True to let convolutional layers have bias term.\n        activation      : torch.nn\n                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n        \"\"\"\n        super().__init__()\n        self.maxpool_conv = torch.nn.Sequential(\n                                                torch.nn.MaxPool2d(2),\n                                                double_convolution(\n                                                                   input_channels = input_channels,\n                                                                   mid_channels = output_channels,\n                                                                   output_channels = output_channels,\n                                                                   kernel_size = kernel_size,\n                                                                   bias = bias,\n                                                                   activation = activation\n                                                                  )\n                                               )\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x              : torch.tensor\n                         First input data.\n\n\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Estimated output.      \n        \"\"\"\n        result = self.maxpool_conv(x)\n        return result\n
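A minimal sketch of the downsampling block, assuming the same import path as the other components; it illustrates that the MaxPool2d(2) halves each spatial dimension before the double convolution changes the channel count.

import torch
from odak.learn.models.components import downsample_layer  # import path assumed

down = downsample_layer(input_channels = 8, output_channels = 16)
x = torch.randn(1, 8, 64, 64)
y = down(x)
print(y.shape)    # torch.Size([1, 16, 32, 32])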
"},{"location":"odak/learn_models/#odak.learn.models.models.downsample_layer.__init__","title":"__init__(input_channels, output_channels, kernel_size=3, bias=False, activation=torch.nn.ReLU())","text":"

A downscaling component with a double convolution.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • output_channels (int) \u2013
              Number of output channels.\n
  • kernel_size \u2013
              Kernel size.\n
  • bias \u2013
              Set to True to let convolutional layers have a bias term.\n
  • activation \u2013
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels,\n             output_channels,\n             kernel_size = 3,\n             bias = False,\n             activation = torch.nn.ReLU()\n            ):\n    \"\"\"\n    A downscaling component with a double convolution.\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    output_channels : int\n                      Number of output channels.\n    kernel_size     : int\n                      Kernel size.\n    bias            : bool \n                      Set to True to let convolutional layers have bias term.\n    activation      : torch.nn\n                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n    \"\"\"\n    super().__init__()\n    self.maxpool_conv = torch.nn.Sequential(\n                                            torch.nn.MaxPool2d(2),\n                                            double_convolution(\n                                                               input_channels = input_channels,\n                                                               mid_channels = output_channels,\n                                                               output_channels = output_channels,\n                                                               kernel_size = kernel_size,\n                                                               bias = bias,\n                                                               activation = activation\n                                                              )\n                                           )\n
"},{"location":"odak/learn_models/#odak.learn.models.models.downsample_layer.forward","title":"forward(x)","text":"

Forward model.

Parameters:

  • x \u2013
             Input data.\n

Returns:

  • result ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x              : torch.tensor\n                     First input data.\n\n\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Estimated output.      \n    \"\"\"\n    result = self.maxpool_conv(x)\n    return result\n
"},{"location":"odak/learn_models/#odak.learn.models.models.global_feature_module","title":"global_feature_module","text":"

Bases: Module

A global feature layer that processes global features from input channels and applies them to another input tensor via learned transformations.

Source code in odak/learn/models/components.py
class global_feature_module(torch.nn.Module):\n    \"\"\"\n    A global feature layer that processes global features from input channels and\n    applies them to another input tensor via learned transformations.\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels,\n                 mid_channels,\n                 output_channels,\n                 kernel_size,\n                 bias = False,\n                 normalization = False,\n                 activation = torch.nn.ReLU()\n                ):\n        \"\"\"\n        A global feature layer.\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        mid_channels  : int\n                          Number of mid channels.\n        output_channels : int\n                          Number of output channels.\n        kernel_size     : int\n                          Kernel size.\n        bias            : bool\n                          Set to True to let convolutional layers have bias term.\n        normalization   : bool\n                          If True, adds a Batch Normalization layer after the convolutional layer.\n        activation      : torch.nn\n                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n        \"\"\"\n        super().__init__()\n        self.transformations_1 = global_transformations(input_channels, output_channels)\n        self.global_features_1 = double_convolution(\n                                                    input_channels = input_channels,\n                                                    mid_channels = mid_channels,\n                                                    output_channels = output_channels,\n                                                    kernel_size = kernel_size,\n                                                    bias = bias,\n                                                    normalization = normalization,\n                                                    activation = activation\n                                                   )\n        self.global_features_2 = double_convolution(\n                                                    input_channels = input_channels,\n                                                    mid_channels = mid_channels,\n                                                    output_channels = output_channels,\n                                                    kernel_size = kernel_size,\n                                                    bias = bias,\n                                                    normalization = normalization,\n                                                    activation = activation\n                                                   )\n        self.transformations_2 = global_transformations(input_channels, output_channels)\n\n\n    def forward(self, x1, x2):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x1             : torch.tensor\n                         First input data.\n        x2             : torch.tensor\n                         Second input data.\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Estimated output.\n        \"\"\"\n        global_tensor_1 = self.transformations_1(x1, x2)\n        y1 = self.global_features_1(global_tensor_1)\n        y2 = self.global_features_2(y1)\n        global_tensor_2 = self.transformations_2(y1, y2)\n        return 
global_tensor_2\n
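A minimal sketch of global_feature_module, assuming the import path above. The channel counts are chosen equal here so that the broadcasted scale and shift produced by the inner global_transformations line up with the inputs of the double convolutions; mismatched input and output channel counts would require correspondingly shaped tensors.

import torch
from odak.learn.models.components import global_feature_module  # import path assumed

gfm = global_feature_module(
                            input_channels = 8,
                            mid_channels = 8,
                            output_channels = 8,
                            kernel_size = 3
                           )
x1 = torch.randn(1, 8, 32, 32)
x2 = torch.randn(1, 8, 32, 32)
y = gfm(x1, x2)
print(y.shape)    # torch.Size([1, 8, 32, 32])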
"},{"location":"odak/learn_models/#odak.learn.models.models.global_feature_module.__init__","title":"__init__(input_channels, mid_channels, output_channels, kernel_size, bias=False, normalization=False, activation=torch.nn.ReLU())","text":"

A global feature layer.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • mid_channels \u2013
              Number of mid channels.\n
  • output_channels (int) \u2013
              Number of output channels.\n
  • kernel_size \u2013
              Kernel size.\n
  • bias \u2013
              Set to True to let convolutional layers have a bias term.\n
  • normalization \u2013
              If True, adds a Batch Normalization layer after the convolutional layer.\n
  • activation \u2013
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels,\n             mid_channels,\n             output_channels,\n             kernel_size,\n             bias = False,\n             normalization = False,\n             activation = torch.nn.ReLU()\n            ):\n    \"\"\"\n    A global feature layer.\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    mid_channels  : int\n                      Number of mid channels.\n    output_channels : int\n                      Number of output channels.\n    kernel_size     : int\n                      Kernel size.\n    bias            : bool\n                      Set to True to let convolutional layers have bias term.\n    normalization   : bool\n                      If True, adds a Batch Normalization layer after the convolutional layer.\n    activation      : torch.nn\n                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n    \"\"\"\n    super().__init__()\n    self.transformations_1 = global_transformations(input_channels, output_channels)\n    self.global_features_1 = double_convolution(\n                                                input_channels = input_channels,\n                                                mid_channels = mid_channels,\n                                                output_channels = output_channels,\n                                                kernel_size = kernel_size,\n                                                bias = bias,\n                                                normalization = normalization,\n                                                activation = activation\n                                               )\n    self.global_features_2 = double_convolution(\n                                                input_channels = input_channels,\n                                                mid_channels = mid_channels,\n                                                output_channels = output_channels,\n                                                kernel_size = kernel_size,\n                                                bias = bias,\n                                                normalization = normalization,\n                                                activation = activation\n                                               )\n    self.transformations_2 = global_transformations(input_channels, output_channels)\n
"},{"location":"odak/learn_models/#odak.learn.models.models.global_feature_module.forward","title":"forward(x1, x2)","text":"

Forward model.

Parameters:

  • x1 \u2013
             First input data.\n
  • x2 \u2013
             Second input data.\n

Returns:

  • result ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x1, x2):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x1             : torch.tensor\n                     First input data.\n    x2             : torch.tensor\n                     Second input data.\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Estimated output.\n    \"\"\"\n    global_tensor_1 = self.transformations_1(x1, x2)\n    y1 = self.global_features_1(global_tensor_1)\n    y2 = self.global_features_2(y1)\n    global_tensor_2 = self.transformations_2(y1, y2)\n    return global_tensor_2\n
"},{"location":"odak/learn_models/#odak.learn.models.models.global_transformations","title":"global_transformations","text":"

Bases: Module

A global feature layer that processes global features from input channels and applies learned transformations to another input tensor.

This implementation is adapted from RSGUnet: https://github.com/MTLab/rsgunet_image_enhance.

Reference: J. Huang, P. Zhu, M. Geng et al. \"Range Scaling Global U-Net for Perceptual Image Enhancement on Mobile Devices.\"

Source code in odak/learn/models/components.py
class global_transformations(torch.nn.Module):\n    \"\"\"\n    A global feature layer that processes global features from input channels and\n    applies learned transformations to another input tensor.\n\n    This implementation is adapted from RSGUnet:\n    https://github.com/MTLab/rsgunet_image_enhance.\n\n    Reference:\n    J. Huang, P. Zhu, M. Geng et al. \"Range Scaling Global U-Net for Perceptual Image Enhancement on Mobile Devices.\"\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels,\n                 output_channels\n                ):\n        \"\"\"\n        A global feature layer.\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        output_channels : int\n                          Number of output channels.\n        \"\"\"\n        super().__init__()\n        self.global_feature_1 = torch.nn.Sequential(\n            torch.nn.Linear(input_channels, output_channels),\n            torch.nn.LeakyReLU(0.2, inplace = True),\n        )\n        self.global_feature_2 = torch.nn.Sequential(\n            torch.nn.Linear(output_channels, output_channels),\n            torch.nn.LeakyReLU(0.2, inplace = True)\n        )\n\n\n    def forward(self, x1, x2):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x1             : torch.tensor\n                         First input data.\n        x2             : torch.tensor\n                         Second input data.\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Estimated output.\n        \"\"\"\n        y = torch.mean(x2, dim = (2, 3))\n        y1 = self.global_feature_1(y)\n        y2 = self.global_feature_2(y1)\n        y1 = y1.unsqueeze(2).unsqueeze(3)\n        y2 = y2.unsqueeze(2).unsqueeze(3)\n        result = x1 * y1 + y2\n        return result\n
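A minimal sketch of global_transformations, assuming the import path above. The spatial mean of x2 is turned into a per-channel scale and shift that is broadcast over x1, so x1 must carry output_channels channels.

import torch
from odak.learn.models.components import global_transformations  # import path assumed

gt = global_transformations(input_channels = 8, output_channels = 8)
x1 = torch.randn(1, 8, 32, 32)   # tensor that receives the learned scale and shift
x2 = torch.randn(1, 8, 32, 32)   # tensor whose spatial mean provides the global statistics
y = gt(x1, x2)                   # y = x1 * scale + shift, broadcast over height and width
assert y.shape == x1.shape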
"},{"location":"odak/learn_models/#odak.learn.models.models.global_transformations.__init__","title":"__init__(input_channels, output_channels)","text":"

A global feature layer.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • output_channels (int) \u2013
              Number of output channels.\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels,\n             output_channels\n            ):\n    \"\"\"\n    A global feature layer.\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    output_channels : int\n                      Number of output channels.\n    \"\"\"\n    super().__init__()\n    self.global_feature_1 = torch.nn.Sequential(\n        torch.nn.Linear(input_channels, output_channels),\n        torch.nn.LeakyReLU(0.2, inplace = True),\n    )\n    self.global_feature_2 = torch.nn.Sequential(\n        torch.nn.Linear(output_channels, output_channels),\n        torch.nn.LeakyReLU(0.2, inplace = True)\n    )\n
"},{"location":"odak/learn_models/#odak.learn.models.models.global_transformations.forward","title":"forward(x1, x2)","text":"

Forward model.

Parameters:

  • x1 \u2013
             First input data.\n
  • x2 \u2013
             Second input data.\n

Returns:

  • result ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x1, x2):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x1             : torch.tensor\n                     First input data.\n    x2             : torch.tensor\n                     Second input data.\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Estimated output.\n    \"\"\"\n    y = torch.mean(x2, dim = (2, 3))\n    y1 = self.global_feature_1(y)\n    y2 = self.global_feature_2(y1)\n    y1 = y1.unsqueeze(2).unsqueeze(3)\n    y2 = y2.unsqueeze(2).unsqueeze(3)\n    result = x1 * y1 + y2\n    return result\n
"},{"location":"odak/learn_models/#odak.learn.models.models.multi_layer_perceptron","title":"multi_layer_perceptron","text":"

Bases: Module

A multi-layer perceptron model.

Source code in odak/learn/models/models.py
class multi_layer_perceptron(torch.nn.Module):\n    \"\"\"\n    A multi-layer perceptron model.\n    \"\"\"\n\n    def __init__(self,\n                 dimensions,\n                 activation = torch.nn.ReLU(),\n                 bias = False,\n                 model_type = 'conventional',\n                 siren_multiplier = 1.,\n                 input_multiplier = None\n                ):\n        \"\"\"\n        Parameters\n        ----------\n        dimensions        : list\n                            List of integers representing the dimensions of each layer (e.g., [2, 10, 1], where the first layer has two channels and last one has one channel.).\n        activation        : torch.nn\n                            Nonlinear activation function.\n                            Default is `torch.nn.ReLU()`.\n        bias              : bool\n                            If set to True, linear layers will include biases.\n        siren_multiplier  : float\n                            When using `SIREN` model type, this parameter functions as a hyperparameter.\n                            The original SIREN work uses 30.\n                            You can bypass this parameter by providing input that are not normalized and larger then one.\n        input_multiplier  : float\n                            Initial value of the input multiplier before the very first layer.\n        model_type        : str\n                            Model type: `conventional`, `swish`, `SIREN`, `FILM SIREN`, `Gaussian`.\n                            `conventional` refers to a standard multi layer perceptron.\n                            For `SIREN,` see: Sitzmann, Vincent, et al. \"Implicit neural representations with periodic activation functions.\" Advances in neural information processing systems 33 (2020): 7462-7473.\n                            For `Swish,` see: Ramachandran, Prajit, Barret Zoph, and Quoc V. Le. \"Searching for activation functions.\" arXiv preprint arXiv:1710.05941 (2017). \n                            For `FILM SIREN,` see: Chan, Eric R., et al. \"pi-gan: Periodic implicit generative adversarial networks for 3d-aware image synthesis.\" Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2021.\n                            For `Gaussian,` see: Ramasinghe, Sameera, and Simon Lucey. \"Beyond periodicity: Towards a unifying framework for activations in coordinate-mlps.\" In European Conference on Computer Vision, pp. 142-158. 
Cham: Springer Nature Switzerland, 2022.\n        \"\"\"\n        super(multi_layer_perceptron, self).__init__()\n        self.activation = activation\n        self.bias = bias\n        self.model_type = model_type\n        self.layers = torch.nn.ModuleList()\n        self.siren_multiplier = siren_multiplier\n        self.dimensions = dimensions\n        for i in range(len(self.dimensions) - 1):\n            self.layers.append(torch.nn.Linear(self.dimensions[i], self.dimensions[i + 1], bias = self.bias))\n        if not isinstance(input_multiplier, type(None)):\n            self.input_multiplier = torch.nn.ParameterList()\n            self.input_multiplier.append(torch.nn.Parameter(torch.ones(1, self.dimensions[0]) * input_multiplier))\n        if self.model_type == 'FILM SIREN':\n            self.alpha = torch.nn.ParameterList()\n            for j in self.dimensions[1:-1]:\n                self.alpha.append(torch.nn.Parameter(torch.randn(2, 1, j)))\n        if self.model_type == 'Gaussian':\n            self.alpha = torch.nn.ParameterList()\n            for j in self.dimensions[1:-1]:\n                self.alpha.append(torch.nn.Parameter(torch.randn(1, 1, j)))\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x             : torch.tensor\n                        Input data.\n\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Estimated output.      \n        \"\"\"\n        if hasattr(self, 'input_multiplier'):\n            result = x * self.input_multiplier[0]\n        else:\n            result = x\n        for layer_id, layer in enumerate(self.layers[:-1]):\n            result = layer(result)\n            if self.model_type == 'conventional':\n                result = self.activation(result)\n            elif self.model_type == 'swish':\n                resutl = swish(result)\n            elif self.model_type == 'SIREN':\n                result = torch.sin(result * self.siren_multiplier)\n            elif self.model_type == 'FILM SIREN':\n                result = torch.sin(self.alpha[layer_id][0] * result + self.alpha[layer_id][1])\n            elif self.model_type == 'Gaussian': \n                result = gaussian(result, self.alpha[layer_id][0])\n        result = self.layers[-1](result)\n        return result\n
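A minimal sketch of a SIREN-style coordinate network built with multi_layer_perceptron, assuming the class is importable as odak.learn.models.models.multi_layer_perceptron (the dotted path used in this reference); the layer sizes are illustrative.

import torch
from odak.learn.models.models import multi_layer_perceptron  # import path assumed

mlp = multi_layer_perceptron(
                             dimensions = [2, 64, 64, 1],
                             model_type = 'SIREN',
                             siren_multiplier = 30.,
                             bias = True
                            )
coordinates = torch.rand(128, 2)   # e.g. normalized (x, y) sample locations
values = mlp(coordinates)          # sin(30 * linear(...)) in the hidden layers, linear output layer
print(values.shape)                # torch.Size([128, 1])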
"},{"location":"odak/learn_models/#odak.learn.models.models.multi_layer_perceptron.__init__","title":"__init__(dimensions, activation=torch.nn.ReLU(), bias=False, model_type='conventional', siren_multiplier=1.0, input_multiplier=None)","text":"

Parameters:

  • dimensions \u2013
                List of integers representing the dimensions of each layer (e.g., [2, 10, 1], where the first layer has two channels and the last one has one channel).\n
  • activation \u2013
                Nonlinear activation function.\n            Default is `torch.nn.ReLU()`.\n
  • bias \u2013
                If set to True, linear layers will include biases.\n
  • siren_multiplier \u2013
                When using `SIREN` model type, this parameter functions as a hyperparameter.\n            The original SIREN work uses 30.\n            You can bypass this parameter by providing inputs that are not normalized and larger than one.\n
  • input_multiplier \u2013
                Initial value of the input multiplier before the very first layer.\n
  • model_type \u2013
                Model type: `conventional`, `swish`, `SIREN`, `FILM SIREN`, `Gaussian`.\n            `conventional` refers to a standard multi layer perceptron.\n            For `SIREN,` see: Sitzmann, Vincent, et al. \"Implicit neural representations with periodic activation functions.\" Advances in neural information processing systems 33 (2020): 7462-7473.\n            For `Swish,` see: Ramachandran, Prajit, Barret Zoph, and Quoc V. Le. \"Searching for activation functions.\" arXiv preprint arXiv:1710.05941 (2017). \n            For `FILM SIREN,` see: Chan, Eric R., et al. \"pi-gan: Periodic implicit generative adversarial networks for 3d-aware image synthesis.\" Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2021.\n            For `Gaussian,` see: Ramasinghe, Sameera, and Simon Lucey. \"Beyond periodicity: Towards a unifying framework for activations in coordinate-mlps.\" In European Conference on Computer Vision, pp. 142-158. Cham: Springer Nature Switzerland, 2022.\n
Source code in odak/learn/models/models.py
def __init__(self,\n             dimensions,\n             activation = torch.nn.ReLU(),\n             bias = False,\n             model_type = 'conventional',\n             siren_multiplier = 1.,\n             input_multiplier = None\n            ):\n    \"\"\"\n    Parameters\n    ----------\n    dimensions        : list\n                        List of integers representing the dimensions of each layer (e.g., [2, 10, 1], where the first layer has two channels and last one has one channel.).\n    activation        : torch.nn\n                        Nonlinear activation function.\n                        Default is `torch.nn.ReLU()`.\n    bias              : bool\n                        If set to True, linear layers will include biases.\n    siren_multiplier  : float\n                        When using `SIREN` model type, this parameter functions as a hyperparameter.\n                        The original SIREN work uses 30.\n                        You can bypass this parameter by providing input that are not normalized and larger then one.\n    input_multiplier  : float\n                        Initial value of the input multiplier before the very first layer.\n    model_type        : str\n                        Model type: `conventional`, `swish`, `SIREN`, `FILM SIREN`, `Gaussian`.\n                        `conventional` refers to a standard multi layer perceptron.\n                        For `SIREN,` see: Sitzmann, Vincent, et al. \"Implicit neural representations with periodic activation functions.\" Advances in neural information processing systems 33 (2020): 7462-7473.\n                        For `Swish,` see: Ramachandran, Prajit, Barret Zoph, and Quoc V. Le. \"Searching for activation functions.\" arXiv preprint arXiv:1710.05941 (2017). \n                        For `FILM SIREN,` see: Chan, Eric R., et al. \"pi-gan: Periodic implicit generative adversarial networks for 3d-aware image synthesis.\" Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2021.\n                        For `Gaussian,` see: Ramasinghe, Sameera, and Simon Lucey. \"Beyond periodicity: Towards a unifying framework for activations in coordinate-mlps.\" In European Conference on Computer Vision, pp. 142-158. Cham: Springer Nature Switzerland, 2022.\n    \"\"\"\n    super(multi_layer_perceptron, self).__init__()\n    self.activation = activation\n    self.bias = bias\n    self.model_type = model_type\n    self.layers = torch.nn.ModuleList()\n    self.siren_multiplier = siren_multiplier\n    self.dimensions = dimensions\n    for i in range(len(self.dimensions) - 1):\n        self.layers.append(torch.nn.Linear(self.dimensions[i], self.dimensions[i + 1], bias = self.bias))\n    if not isinstance(input_multiplier, type(None)):\n        self.input_multiplier = torch.nn.ParameterList()\n        self.input_multiplier.append(torch.nn.Parameter(torch.ones(1, self.dimensions[0]) * input_multiplier))\n    if self.model_type == 'FILM SIREN':\n        self.alpha = torch.nn.ParameterList()\n        for j in self.dimensions[1:-1]:\n            self.alpha.append(torch.nn.Parameter(torch.randn(2, 1, j)))\n    if self.model_type == 'Gaussian':\n        self.alpha = torch.nn.ParameterList()\n        for j in self.dimensions[1:-1]:\n            self.alpha.append(torch.nn.Parameter(torch.randn(1, 1, j)))\n
"},{"location":"odak/learn_models/#odak.learn.models.models.multi_layer_perceptron.forward","title":"forward(x)","text":"

Forward model.

Parameters:

  • x \u2013
            Input data.\n

Returns:

  • result ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/models.py
def forward(self, x):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x             : torch.tensor\n                    Input data.\n\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Estimated output.      \n    \"\"\"\n    if hasattr(self, 'input_multiplier'):\n        result = x * self.input_multiplier[0]\n    else:\n        result = x\n    for layer_id, layer in enumerate(self.layers[:-1]):\n        result = layer(result)\n        if self.model_type == 'conventional':\n            result = self.activation(result)\n        elif self.model_type == 'swish':\n            resutl = swish(result)\n        elif self.model_type == 'SIREN':\n            result = torch.sin(result * self.siren_multiplier)\n        elif self.model_type == 'FILM SIREN':\n            result = torch.sin(self.alpha[layer_id][0] * result + self.alpha[layer_id][1])\n        elif self.model_type == 'Gaussian': \n            result = gaussian(result, self.alpha[layer_id][0])\n    result = self.layers[-1](result)\n    return result\n
"},{"location":"odak/learn_models/#odak.learn.models.models.non_local_layer","title":"non_local_layer","text":"

Bases: Module

Self-attention layer [z_i = W_z * y_i + x_i] (non-local block; see https://arxiv.org/abs/1711.07971)

Source code in odak/learn/models/components.py
class non_local_layer(torch.nn.Module):\n    \"\"\"\n    Self-Attention Layer [zi = Wzyi + xi] (non-local block : ref https://arxiv.org/abs/1711.07971)\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels = 1024,\n                 bottleneck_channels = 512,\n                 kernel_size = 1,\n                 bias = False,\n                ):\n        \"\"\"\n\n        Parameters\n        ----------\n        input_channels      : int\n                              Number of input channels.\n        bottleneck_channels : int\n                              Number of middle channels.\n        kernel_size         : int\n                              Kernel size.\n        bias                : bool \n                              Set to True to let convolutional layers have bias term.\n        \"\"\"\n        super(non_local_layer, self).__init__()\n        self.input_channels = input_channels\n        self.bottleneck_channels = bottleneck_channels\n        self.g = torch.nn.Conv2d(\n                                 self.input_channels, \n                                 self.bottleneck_channels,\n                                 kernel_size = kernel_size,\n                                 padding = kernel_size // 2,\n                                 bias = bias\n                                )\n        self.W_z = torch.nn.Sequential(\n                                       torch.nn.Conv2d(\n                                                       self.bottleneck_channels,\n                                                       self.input_channels, \n                                                       kernel_size = kernel_size,\n                                                       bias = bias,\n                                                       padding = kernel_size // 2\n                                                      ),\n                                       torch.nn.BatchNorm2d(self.input_channels)\n                                      )\n        torch.nn.init.constant_(self.W_z[1].weight, 0)   \n        torch.nn.init.constant_(self.W_z[1].bias, 0)\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward model [zi = Wzyi + xi]\n\n        Parameters\n        ----------\n        x               : torch.tensor\n                          First input data.                       \n\n\n        Returns\n        ----------\n        z               : torch.tensor\n                          Estimated output.\n        \"\"\"\n        batch_size, channels, height, width = x.size()\n        theta = x.view(batch_size, channels, -1).permute(0, 2, 1)\n        phi = x.view(batch_size, channels, -1).permute(0, 2, 1)\n        g = self.g(x).view(batch_size, self.bottleneck_channels, -1).permute(0, 2, 1)\n        attn = torch.bmm(theta, phi.transpose(1, 2)) / (height * width)\n        attn = torch.nn.functional.softmax(attn, dim=-1)\n        y = torch.bmm(attn, g).permute(0, 2, 1).contiguous().view(batch_size, self.bottleneck_channels, height, width)\n        W_y = self.W_z(y)\n        z = W_y + x\n        return z\n
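A minimal sketch of the non-local block, assuming the import path above and using smaller channel counts than the defaults for a quick test; the residual connection z = W_z(y) + x keeps the output shape equal to the input shape.

import torch
from odak.learn.models.components import non_local_layer  # import path assumed

attention = non_local_layer(input_channels = 64, bottleneck_channels = 32, kernel_size = 1)
x = torch.randn(1, 64, 16, 16)
z = attention(x)           # attention over all spatial positions, then a residual addition
assert z.shape == x.shape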
"},{"location":"odak/learn_models/#odak.learn.models.models.non_local_layer.__init__","title":"__init__(input_channels=1024, bottleneck_channels=512, kernel_size=1, bias=False)","text":"

Parameters:

  • input_channels \u2013
                  Number of input channels.\n
  • bottleneck_channels (int, default: 512 ) \u2013
                  Number of middle channels.\n
  • kernel_size \u2013
                  Kernel size.\n
  • bias \u2013
                  Set to True to let convolutional layers have a bias term.\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels = 1024,\n             bottleneck_channels = 512,\n             kernel_size = 1,\n             bias = False,\n            ):\n    \"\"\"\n\n    Parameters\n    ----------\n    input_channels      : int\n                          Number of input channels.\n    bottleneck_channels : int\n                          Number of middle channels.\n    kernel_size         : int\n                          Kernel size.\n    bias                : bool \n                          Set to True to let convolutional layers have bias term.\n    \"\"\"\n    super(non_local_layer, self).__init__()\n    self.input_channels = input_channels\n    self.bottleneck_channels = bottleneck_channels\n    self.g = torch.nn.Conv2d(\n                             self.input_channels, \n                             self.bottleneck_channels,\n                             kernel_size = kernel_size,\n                             padding = kernel_size // 2,\n                             bias = bias\n                            )\n    self.W_z = torch.nn.Sequential(\n                                   torch.nn.Conv2d(\n                                                   self.bottleneck_channels,\n                                                   self.input_channels, \n                                                   kernel_size = kernel_size,\n                                                   bias = bias,\n                                                   padding = kernel_size // 2\n                                                  ),\n                                   torch.nn.BatchNorm2d(self.input_channels)\n                                  )\n    torch.nn.init.constant_(self.W_z[1].weight, 0)   \n    torch.nn.init.constant_(self.W_z[1].bias, 0)\n
"},{"location":"odak/learn_models/#odak.learn.models.models.non_local_layer.forward","title":"forward(x)","text":"

Forward model [z_i = W_z * y_i + x_i]

Parameters:

  • x \u2013
              Input data.\n

Returns:

  • z ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward model [zi = Wzyi + xi]\n\n    Parameters\n    ----------\n    x               : torch.tensor\n                      First input data.                       \n\n\n    Returns\n    ----------\n    z               : torch.tensor\n                      Estimated output.\n    \"\"\"\n    batch_size, channels, height, width = x.size()\n    theta = x.view(batch_size, channels, -1).permute(0, 2, 1)\n    phi = x.view(batch_size, channels, -1).permute(0, 2, 1)\n    g = self.g(x).view(batch_size, self.bottleneck_channels, -1).permute(0, 2, 1)\n    attn = torch.bmm(theta, phi.transpose(1, 2)) / (height * width)\n    attn = torch.nn.functional.softmax(attn, dim=-1)\n    y = torch.bmm(attn, g).permute(0, 2, 1).contiguous().view(batch_size, self.bottleneck_channels, height, width)\n    W_y = self.W_z(y)\n    z = W_y + x\n    return z\n
"},{"location":"odak/learn_models/#odak.learn.models.models.normalization","title":"normalization","text":"

Bases: Module

A normalization layer.

Source code in odak/learn/models/components.py
class normalization(torch.nn.Module):\n    \"\"\"\n    A normalization layer.\n    \"\"\"\n    def __init__(\n                 self,\n                 dim = 1,\n                ):\n        \"\"\"\n        Normalization layer.\n\n\n        Parameters\n        ----------\n        dim             : int\n                          Dimension (axis) to normalize.\n        \"\"\"\n        super().__init__()\n        self.k = torch.nn.Parameter(torch.ones(1, dim, 1, 1))\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x             : torch.tensor\n                        Input data.\n\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Estimated output.      \n        \"\"\"\n        eps = 1e-5 if x.dtype == torch.float32 else 1e-3\n        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)\n        mean = torch.mean(x, dim = 1, keepdim = True)\n        result =  (x - mean) * (var + eps).rsqrt() * self.k\n        return result \n
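A minimal sketch of the normalization layer, assuming the import path above; dim sets the channel count of the learned per-channel scale k, so it must match the input's channel dimension.

import torch
from odak.learn.models.components import normalization  # import path assumed

norm = normalization(dim = 8)      # dim must match the channel count of the input
x = torch.randn(2, 8, 16, 16)
y = norm(x)                        # zero mean, unit variance along the channel axis, scaled by the learned k
print(y.shape)                     # torch.Size([2, 8, 16, 16])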
"},{"location":"odak/learn_models/#odak.learn.models.models.normalization.__init__","title":"__init__(dim=1)","text":"

Normalization layer.

Parameters:

  • dim \u2013
              Dimension (axis) to normalize.\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             dim = 1,\n            ):\n    \"\"\"\n    Normalization layer.\n\n\n    Parameters\n    ----------\n    dim             : int\n                      Dimension (axis) to normalize.\n    \"\"\"\n    super().__init__()\n    self.k = torch.nn.Parameter(torch.ones(1, dim, 1, 1))\n
"},{"location":"odak/learn_models/#odak.learn.models.models.normalization.forward","title":"forward(x)","text":"

Forward model.

Parameters:

  • x \u2013
            Input data.\n

Returns:

  • result ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x             : torch.tensor\n                    Input data.\n\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Estimated output.      \n    \"\"\"\n    eps = 1e-5 if x.dtype == torch.float32 else 1e-3\n    var = torch.var(x, dim = 1, unbiased = False, keepdim = True)\n    mean = torch.mean(x, dim = 1, keepdim = True)\n    result =  (x - mean) * (var + eps).rsqrt() * self.k\n    return result \n
"},{"location":"odak/learn_models/#odak.learn.models.models.positional_encoder","title":"positional_encoder","text":"

Bases: Module

A positional encoder module.

Source code in odak/learn/models/components.py
class positional_encoder(torch.nn.Module):\n    \"\"\"\n    A positional encoder module.\n    \"\"\"\n\n    def __init__(self, L):\n        \"\"\"\n        A positional encoder module.\n\n        Parameters\n        ----------\n        L                   : int\n                              Positional encoding level.\n        \"\"\"\n        super(positional_encoder, self).__init__()\n        self.L = L\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x               : torch.tensor\n                          Input data.\n\n        Returns\n        ----------\n        result          : torch.tensor\n                          Result of the forward operation\n        \"\"\"\n        B, C = x.shape\n        x = x.view(B, C, 1)\n        results = [x]\n        for i in range(1, self.L + 1):\n            freq = (2 ** i) * math.pi\n            cos_x = torch.cos(freq * x)\n            sin_x = torch.sin(freq * x)\n            results.append(cos_x)\n            results.append(sin_x)\n        results = torch.cat(results, dim=2)\n        results = results.permute(0, 2, 1)\n        results = results.reshape(B, -1)\n        return results\n
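A minimal sketch of the positional encoder, assuming the import path above. For an input with C features and encoding level L, the output concatenates the raw input with cosine and sine pairs at frequencies 2^1 pi through 2^L pi, giving C * (2L + 1) features per sample.

import torch
from odak.learn.models.components import positional_encoder  # import path assumed

encoder = positional_encoder(L = 4)
x = torch.rand(16, 2)              # e.g. two-dimensional coordinates in [0, 1]
features = encoder(x)
print(features.shape)              # torch.Size([16, 18]) == (16, 2 * (2 * 4 + 1))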
"},{"location":"odak/learn_models/#odak.learn.models.models.positional_encoder.__init__","title":"__init__(L)","text":"

A positional encoder module.

Parameters:

  • L \u2013
                  Positional encoding level.\n
Source code in odak/learn/models/components.py
def __init__(self, L):\n    \"\"\"\n    A positional encoder module.\n\n    Parameters\n    ----------\n    L                   : int\n                          Positional encoding level.\n    \"\"\"\n    super(positional_encoder, self).__init__()\n    self.L = L\n
"},{"location":"odak/learn_models/#odak.learn.models.models.positional_encoder.forward","title":"forward(x)","text":"

Forward model.

Parameters:

  • x \u2013
              Input data.\n

Returns:

  • result ( tensor ) \u2013

    Result of the forward operation

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x               : torch.tensor\n                      Input data.\n\n    Returns\n    ----------\n    result          : torch.tensor\n                      Result of the forward operation\n    \"\"\"\n    B, C = x.shape\n    x = x.view(B, C, 1)\n    results = [x]\n    for i in range(1, self.L + 1):\n        freq = (2 ** i) * math.pi\n        cos_x = torch.cos(freq * x)\n        sin_x = torch.sin(freq * x)\n        results.append(cos_x)\n        results.append(sin_x)\n    results = torch.cat(results, dim=2)\n    results = results.permute(0, 2, 1)\n    results = results.reshape(B, -1)\n    return results\n
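
A short usage sketch (import path assumed): every input coordinate is expanded into 1 + 2 * L values, so a (B, C) input becomes (B, C * (1 + 2 * L)).

import torch
from odak.learn.models import positional_encoder  # assumed import path

encoder = positional_encoder(L = 4)             # four frequency levels
x = torch.rand(16, 3)                           # e.g., 16 samples of 3D coordinates in [0, 1]
features = encoder(x)                           # each coordinate -> 1 + 2 * 4 = 9 values
print(features.shape)                           # torch.Size([16, 27])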
"},{"location":"odak/learn_models/#odak.learn.models.models.residual_attention_layer","title":"residual_attention_layer","text":"

Bases: Module

A residual block with an attention layer.

Source code in odak/learn/models/components.py
class residual_attention_layer(torch.nn.Module):\n    \"\"\"\n    A residual block with an attention layer.\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels = 2,\n                 output_channels = 2,\n                 kernel_size = 1,\n                 bias = False,\n                 activation = torch.nn.ReLU()\n                ):\n        \"\"\"\n        An attention layer class.\n\n\n        Parameters\n        ----------\n        input_channels  : int or optional\n                          Number of input channels.\n        output_channels : int or optional\n                          Number of middle channels.\n        kernel_size     : int or optional\n                          Kernel size.\n        bias            : bool or optional\n                          Set to True to let convolutional layers have bias term.\n        activation      : torch.nn or optional\n                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n        \"\"\"\n        super().__init__()\n        self.activation = activation\n        self.convolution0 = torch.nn.Sequential(\n                                                torch.nn.Conv2d(\n                                                                input_channels,\n                                                                output_channels,\n                                                                kernel_size = kernel_size,\n                                                                padding = kernel_size // 2,\n                                                                bias = bias\n                                                               ),\n                                                torch.nn.BatchNorm2d(output_channels)\n                                               )\n        self.convolution1 = torch.nn.Sequential(\n                                                torch.nn.Conv2d(\n                                                                input_channels,\n                                                                output_channels,\n                                                                kernel_size = kernel_size,\n                                                                padding = kernel_size // 2,\n                                                                bias = bias\n                                                               ),\n                                                torch.nn.BatchNorm2d(output_channels)\n                                               )\n        self.final_layer = torch.nn.Sequential(\n                                               self.activation,\n                                               torch.nn.Conv2d(\n                                                               output_channels,\n                                                               output_channels,\n                                                               kernel_size = kernel_size,\n                                                               padding = kernel_size // 2,\n                                                               bias = bias\n                                                              )\n                                              )\n\n\n    def forward(self, x0, x1):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x0             : torch.tensor\n                         First input data.\n\n        x1             : torch.tensor\n                         Second input data.\n\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Estimated output.      \n        \"\"\"\n        y0 = self.convolution0(x0)\n        y1 = self.convolution1(x1)\n        y2 = torch.add(y0, y1)\n        result = self.final_layer(y2) * x0\n        return result\n
"},{"location":"odak/learn_models/#odak.learn.models.models.residual_attention_layer.__init__","title":"__init__(input_channels=2, output_channels=2, kernel_size=1, bias=False, activation=torch.nn.ReLU())","text":"

An attention layer class.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • output_channels (int or optional, default: 2 ) \u2013
              Number of middle channels.\n
  • kernel_size \u2013
              Kernel size.\n
  • bias \u2013
              Set to True to let convolutional layers have bias term.\n
  • activation \u2013
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels = 2,\n             output_channels = 2,\n             kernel_size = 1,\n             bias = False,\n             activation = torch.nn.ReLU()\n            ):\n    \"\"\"\n    An attention layer class.\n\n\n    Parameters\n    ----------\n    input_channels  : int or optional\n                      Number of input channels.\n    output_channels : int or optional\n                      Number of middle channels.\n    kernel_size     : int or optional\n                      Kernel size.\n    bias            : bool or optional\n                      Set to True to let convolutional layers have bias term.\n    activation      : torch.nn or optional\n                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n    \"\"\"\n    super().__init__()\n    self.activation = activation\n    self.convolution0 = torch.nn.Sequential(\n                                            torch.nn.Conv2d(\n                                                            input_channels,\n                                                            output_channels,\n                                                            kernel_size = kernel_size,\n                                                            padding = kernel_size // 2,\n                                                            bias = bias\n                                                           ),\n                                            torch.nn.BatchNorm2d(output_channels)\n                                           )\n    self.convolution1 = torch.nn.Sequential(\n                                            torch.nn.Conv2d(\n                                                            input_channels,\n                                                            output_channels,\n                                                            kernel_size = kernel_size,\n                                                            padding = kernel_size // 2,\n                                                            bias = bias\n                                                           ),\n                                            torch.nn.BatchNorm2d(output_channels)\n                                           )\n    self.final_layer = torch.nn.Sequential(\n                                           self.activation,\n                                           torch.nn.Conv2d(\n                                                           output_channels,\n                                                           output_channels,\n                                                           kernel_size = kernel_size,\n                                                           padding = kernel_size // 2,\n                                                           bias = bias\n                                                          )\n                                          )\n
"},{"location":"odak/learn_models/#odak.learn.models.models.residual_attention_layer.forward","title":"forward(x0, x1)","text":"

Forward model.

Parameters:

  • x0 \u2013
             First input data.\n
  • x1 \u2013
             Second input data.\n

Returns:

  • result ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x0, x1):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x0             : torch.tensor\n                     First input data.\n\n    x1             : torch.tensor\n                     Second input data.\n\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Estimated output.      \n    \"\"\"\n    y0 = self.convolution0(x0)\n    y1 = self.convolution1(x1)\n    y2 = torch.add(y0, y1)\n    result = self.final_layer(y2) * x0\n    return result\n
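
A minimal sketch of how this block could be called (import path assumed, random tensors as stand-ins). Since forward multiplies the attention response with x0, input and output channel counts are kept equal here.

import torch
from odak.learn.models import residual_attention_layer  # assumed import path

layer = residual_attention_layer(input_channels = 2, output_channels = 2, kernel_size = 3)
x0 = torch.randn(1, 2, 64, 64)                  # features to be gated
x1 = torch.randn(1, 2, 64, 64)                  # second feature map feeding the attention branch
y = layer(x0, x1)                               # attention response multiplied with x0
print(y.shape)                                  # torch.Size([1, 2, 64, 64])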
"},{"location":"odak/learn_models/#odak.learn.models.models.residual_layer","title":"residual_layer","text":"

Bases: Module

A residual layer.

Source code in odak/learn/models/components.py
class residual_layer(torch.nn.Module):\n    \"\"\"\n    A residual layer.\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels = 2,\n                 mid_channels = 16,\n                 kernel_size = 3,\n                 bias = False,\n                 activation = torch.nn.ReLU()\n                ):\n        \"\"\"\n        A convolutional layer class.\n\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        mid_channels    : int\n                          Number of middle channels.\n        kernel_size     : int\n                          Kernel size.\n        bias            : bool \n                          Set to True to let convolutional layers have bias term.\n        activation      : torch.nn\n                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n        \"\"\"\n        super().__init__()\n        self.activation = activation\n        self.convolution = double_convolution(\n                                              input_channels,\n                                              mid_channels = mid_channels,\n                                              output_channels = input_channels,\n                                              kernel_size = kernel_size,\n                                              bias = bias,\n                                              activation = activation\n                                             )\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x             : torch.tensor\n                        Input data.\n\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Estimated output.      \n        \"\"\"\n        x0 = self.convolution(x)\n        return x + x0\n
"},{"location":"odak/learn_models/#odak.learn.models.models.residual_layer.__init__","title":"__init__(input_channels=2, mid_channels=16, kernel_size=3, bias=False, activation=torch.nn.ReLU())","text":"

A convolutional layer class.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • mid_channels \u2013
              Number of middle channels.\n
  • kernel_size \u2013
              Kernel size.\n
  • bias \u2013
              Set to True to let convolutional layers have bias term.\n
  • activation \u2013
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels = 2,\n             mid_channels = 16,\n             kernel_size = 3,\n             bias = False,\n             activation = torch.nn.ReLU()\n            ):\n    \"\"\"\n    A convolutional layer class.\n\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    mid_channels    : int\n                      Number of middle channels.\n    kernel_size     : int\n                      Kernel size.\n    bias            : bool \n                      Set to True to let convolutional layers have bias term.\n    activation      : torch.nn\n                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n    \"\"\"\n    super().__init__()\n    self.activation = activation\n    self.convolution = double_convolution(\n                                          input_channels,\n                                          mid_channels = mid_channels,\n                                          output_channels = input_channels,\n                                          kernel_size = kernel_size,\n                                          bias = bias,\n                                          activation = activation\n                                         )\n
"},{"location":"odak/learn_models/#odak.learn.models.models.residual_layer.forward","title":"forward(x)","text":"

Forward model.

Parameters:

  • x \u2013
            Input data.\n

Returns:

  • result ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x             : torch.tensor\n                    Input data.\n\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Estimated output.      \n    \"\"\"\n    x0 = self.convolution(x)\n    return x + x0\n
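
A brief usage sketch (import path assumed); the layer returns x plus a double convolution of x, so the output shape matches the input.

import torch
from odak.learn.models import residual_layer  # assumed import path

layer = residual_layer(input_channels = 2, mid_channels = 16, kernel_size = 3)
x = torch.randn(1, 2, 64, 64)                   # placeholder feature map
y = layer(x)                                    # residual connection: x + double_convolution(x)
print(y.shape)                                  # torch.Size([1, 2, 64, 64])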
"},{"location":"odak/learn_models/#odak.learn.models.models.spatial_gate","title":"spatial_gate","text":"

Bases: Module

Spatial attention module that applies a convolution layer after channel pooling. This class is heavily inspired by https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py.

Source code in odak/learn/models/components.py
class spatial_gate(torch.nn.Module):\n    \"\"\"\n    Spatial attention module that applies a convolution layer after channel pooling.\n    This class is heavily inspired by https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py.\n    \"\"\"\n    def __init__(self):\n        \"\"\"\n        Initializes the spatial gate module.\n        \"\"\"\n        super().__init__()\n        kernel_size = 7\n        self.spatial = convolution_layer(2, 1, kernel_size, bias = False, activation = torch.nn.Identity())\n\n\n    def channel_pool(self, x):\n        \"\"\"\n        Applies max and average pooling on the channels.\n\n        Parameters\n        ----------\n        x             : torch.tensor\n                        Input tensor.\n\n        Returns\n        -------\n        output        : torch.tensor\n                        Output tensor.\n        \"\"\"\n        max_pool = torch.max(x, 1)[0].unsqueeze(1)\n        avg_pool = torch.mean(x, 1).unsqueeze(1)\n        output = torch.cat((max_pool, avg_pool), dim=1)\n        return output\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward pass of the SpatialGate module.\n\n        Applies spatial attention to the input tensor.\n\n        Parameters\n        ----------\n        x            : torch.tensor\n                       Input tensor to the SpatialGate module.\n\n        Returns\n        -------\n        scaled_x     : torch.tensor\n                       Output tensor after applying spatial attention.\n        \"\"\"\n        x_compress = self.channel_pool(x)\n        x_out = self.spatial(x_compress)\n        scale = torch.sigmoid(x_out)\n        scaled_x = x * scale\n        return scaled_x\n
"},{"location":"odak/learn_models/#odak.learn.models.models.spatial_gate.__init__","title":"__init__()","text":"

Initializes the spatial gate module.

Source code in odak/learn/models/components.py
def __init__(self):\n    \"\"\"\n    Initializes the spatial gate module.\n    \"\"\"\n    super().__init__()\n    kernel_size = 7\n    self.spatial = convolution_layer(2, 1, kernel_size, bias = False, activation = torch.nn.Identity())\n
"},{"location":"odak/learn_models/#odak.learn.models.models.spatial_gate.channel_pool","title":"channel_pool(x)","text":"

Applies max and average pooling on the channels.

Parameters:

  • x \u2013
            Input tensor.\n

Returns:

  • output ( tensor ) \u2013

    Output tensor.

Source code in odak/learn/models/components.py
def channel_pool(self, x):\n    \"\"\"\n    Applies max and average pooling on the channels.\n\n    Parameters\n    ----------\n    x             : torch.tensor\n                    Input tensor.\n\n    Returns\n    -------\n    output        : torch.tensor\n                    Output tensor.\n    \"\"\"\n    max_pool = torch.max(x, 1)[0].unsqueeze(1)\n    avg_pool = torch.mean(x, 1).unsqueeze(1)\n    output = torch.cat((max_pool, avg_pool), dim=1)\n    return output\n
"},{"location":"odak/learn_models/#odak.learn.models.models.spatial_gate.forward","title":"forward(x)","text":"

Forward pass of the SpatialGate module.

Applies spatial attention to the input tensor.

Parameters:

  • x \u2013
           Input tensor to the SpatialGate module.\n

Returns:

  • scaled_x ( tensor ) \u2013

    Output tensor after applying spatial attention.

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward pass of the SpatialGate module.\n\n    Applies spatial attention to the input tensor.\n\n    Parameters\n    ----------\n    x            : torch.tensor\n                   Input tensor to the SpatialGate module.\n\n    Returns\n    -------\n    scaled_x     : torch.tensor\n                   Output tensor after applying spatial attention.\n    \"\"\"\n    x_compress = self.channel_pool(x)\n    x_out = self.spatial(x_compress)\n    scale = torch.sigmoid(x_out)\n    scaled_x = x * scale\n    return scaled_x\n
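
A minimal sketch (import path assumed): the gate accepts any channel count, since channel pooling reduces the input to two channels before the convolution.

import torch
from odak.learn.models import spatial_gate  # assumed import path

gate = spatial_gate()
x = torch.randn(1, 16, 64, 64)                  # placeholder feature map
y = gate(x)                                     # x scaled by a sigmoid spatial attention map
print(y.shape)                                  # torch.Size([1, 16, 64, 64])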
"},{"location":"odak/learn_models/#odak.learn.models.models.spatially_adaptive_convolution","title":"spatially_adaptive_convolution","text":"

Bases: Module

A spatially adaptive convolution layer.

References

C. Zheng et al. \"Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions.\" C. Xu et al. \"Squeezesegv3: Spatially-adaptive Convolution for Efficient Point-Cloud Segmentation.\" C. Zheng et al. \"Windowing Decomposition Convolutional Neural Network for Image Enhancement.\"

Source code in odak/learn/models/components.py
class spatially_adaptive_convolution(torch.nn.Module):\n    \"\"\"\n    A spatially adaptive convolution layer.\n\n    References\n    ----------\n\n    C. Zheng et al. \"Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions.\"\n    C. Xu et al. \"Squeezesegv3: Spatially-adaptive Convolution for Efficient Point-Cloud Segmentation.\"\n    C. Zheng et al. \"Windowing Decomposition Convolutional Neural Network for Image Enhancement.\"\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels = 2,\n                 output_channels = 2,\n                 kernel_size = 3,\n                 stride = 1,\n                 padding = 1,\n                 bias = False,\n                 activation = torch.nn.LeakyReLU(0.2, inplace = True)\n                ):\n        \"\"\"\n        Initializes a spatially adaptive convolution layer.\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        output_channels : int\n                          Number of output channels.\n        kernel_size     : int\n                          Size of the convolution kernel.\n        stride          : int\n                          Stride of the convolution.\n        padding         : int\n                          Padding added to both sides of the input.\n        bias            : bool\n                          If True, includes a bias term in the convolution.\n        activation      : torch.nn.Module\n                          Activation function to apply. If None, no activation is applied.\n        \"\"\"\n        super(spatially_adaptive_convolution, self).__init__()\n        self.kernel_size = kernel_size\n        self.input_channels = input_channels\n        self.output_channels = output_channels\n        self.stride = stride\n        self.padding = padding\n        self.standard_convolution = torch.nn.Conv2d(\n                                                    in_channels = input_channels,\n                                                    out_channels = self.output_channels,\n                                                    kernel_size = kernel_size,\n                                                    stride = stride,\n                                                    padding = padding,\n                                                    bias = bias\n                                                   )\n        self.weight = torch.nn.Parameter(data = self.standard_convolution.weight, requires_grad = True)\n        self.activation = activation\n\n\n    def forward(self, x, sv_kernel_feature):\n        \"\"\"\n        Forward pass for the spatially adaptive convolution layer.\n\n        Parameters\n        ----------\n        x                  : torch.tensor\n                            Input data tensor.\n                            Dimension: (1, C, H, W)\n        sv_kernel_feature   : torch.tensor\n                            Spatially varying kernel features.\n                            Dimension: (1, C_i * kernel_size * kernel_size, H, W)\n\n        Returns\n        -------\n        sa_output          : torch.tensor\n                            Estimated output tensor.\n                            Dimension: (1, output_channels, H_out, W_out)\n        \"\"\"\n        # Pad input and sv_kernel_feature if necessary\n        if sv_kernel_feature.size(-1) * self.stride != x.size(-1) or sv_kernel_feature.size(\n                -2) * self.stride != x.size(-2):\n    
        diffY = sv_kernel_feature.size(-2) % self.stride\n            diffX = sv_kernel_feature.size(-1) % self.stride\n            sv_kernel_feature = torch.nn.functional.pad(sv_kernel_feature, (diffX // 2, diffX - diffX // 2,\n                                                                            diffY // 2, diffY - diffY // 2))\n            diffY = x.size(-2) % self.stride\n            diffX = x.size(-1) % self.stride\n            x = torch.nn.functional.pad(x, (diffX // 2, diffX - diffX // 2,\n                                            diffY // 2, diffY - diffY // 2))\n\n        # Unfold the input tensor for matrix multiplication\n        input_feature = torch.nn.functional.unfold(\n                                                   x,\n                                                   kernel_size = (self.kernel_size, self.kernel_size),\n                                                   stride = self.stride,\n                                                   padding = self.padding\n                                                  )\n\n        # Resize sv_kernel_feature to match the input feature\n        sv_kernel = sv_kernel_feature.reshape(\n                                              1,\n                                              self.input_channels * self.kernel_size * self.kernel_size,\n                                              (x.size(-2) // self.stride) * (x.size(-1) // self.stride)\n                                             )\n\n        # Resize weight to match the input channels and kernel size\n        si_kernel = self.weight.reshape(\n                                        self.output_channels,\n                                        self.input_channels * self.kernel_size * self.kernel_size\n                                       )\n\n        # Apply spatially varying kernels\n        sv_feature = input_feature * sv_kernel\n\n        # Perform matrix multiplication\n        sa_output = torch.matmul(si_kernel, sv_feature).reshape(\n                                                                1, self.output_channels,\n                                                                (x.size(-2) // self.stride),\n                                                                (x.size(-1) // self.stride)\n                                                               )\n        return sa_output\n
"},{"location":"odak/learn_models/#odak.learn.models.models.spatially_adaptive_convolution.__init__","title":"__init__(input_channels=2, output_channels=2, kernel_size=3, stride=1, padding=1, bias=False, activation=torch.nn.LeakyReLU(0.2, inplace=True))","text":"

Initializes a spatially adaptive convolution layer.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • output_channels (int, default: 2 ) \u2013
              Number of output channels.\n
  • kernel_size \u2013
              Size of the convolution kernel.\n
  • stride \u2013
              Stride of the convolution.\n
  • padding \u2013
              Padding added to both sides of the input.\n
  • bias \u2013
              If True, includes a bias term in the convolution.\n
  • activation \u2013
              Activation function to apply. If None, no activation is applied.\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels = 2,\n             output_channels = 2,\n             kernel_size = 3,\n             stride = 1,\n             padding = 1,\n             bias = False,\n             activation = torch.nn.LeakyReLU(0.2, inplace = True)\n            ):\n    \"\"\"\n    Initializes a spatially adaptive convolution layer.\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    output_channels : int\n                      Number of output channels.\n    kernel_size     : int\n                      Size of the convolution kernel.\n    stride          : int\n                      Stride of the convolution.\n    padding         : int\n                      Padding added to both sides of the input.\n    bias            : bool\n                      If True, includes a bias term in the convolution.\n    activation      : torch.nn.Module\n                      Activation function to apply. If None, no activation is applied.\n    \"\"\"\n    super(spatially_adaptive_convolution, self).__init__()\n    self.kernel_size = kernel_size\n    self.input_channels = input_channels\n    self.output_channels = output_channels\n    self.stride = stride\n    self.padding = padding\n    self.standard_convolution = torch.nn.Conv2d(\n                                                in_channels = input_channels,\n                                                out_channels = self.output_channels,\n                                                kernel_size = kernel_size,\n                                                stride = stride,\n                                                padding = padding,\n                                                bias = bias\n                                               )\n    self.weight = torch.nn.Parameter(data = self.standard_convolution.weight, requires_grad = True)\n    self.activation = activation\n
"},{"location":"odak/learn_models/#odak.learn.models.models.spatially_adaptive_convolution.forward","title":"forward(x, sv_kernel_feature)","text":"

Forward pass for the spatially adaptive convolution layer.

Parameters:

  • x \u2013
                Input data tensor.\n            Dimension: (1, C, H, W)\n
  • sv_kernel_feature \u2013
                Spatially varying kernel features.\n            Dimension: (1, C_i * kernel_size * kernel_size, H, W)\n

Returns:

  • sa_output ( tensor ) \u2013

    Estimated output tensor. Dimension: (1, output_channels, H_out, W_out)

Source code in odak/learn/models/components.py
def forward(self, x, sv_kernel_feature):\n    \"\"\"\n    Forward pass for the spatially adaptive convolution layer.\n\n    Parameters\n    ----------\n    x                  : torch.tensor\n                        Input data tensor.\n                        Dimension: (1, C, H, W)\n    sv_kernel_feature   : torch.tensor\n                        Spatially varying kernel features.\n                        Dimension: (1, C_i * kernel_size * kernel_size, H, W)\n\n    Returns\n    -------\n    sa_output          : torch.tensor\n                        Estimated output tensor.\n                        Dimension: (1, output_channels, H_out, W_out)\n    \"\"\"\n    # Pad input and sv_kernel_feature if necessary\n    if sv_kernel_feature.size(-1) * self.stride != x.size(-1) or sv_kernel_feature.size(\n            -2) * self.stride != x.size(-2):\n        diffY = sv_kernel_feature.size(-2) % self.stride\n        diffX = sv_kernel_feature.size(-1) % self.stride\n        sv_kernel_feature = torch.nn.functional.pad(sv_kernel_feature, (diffX // 2, diffX - diffX // 2,\n                                                                        diffY // 2, diffY - diffY // 2))\n        diffY = x.size(-2) % self.stride\n        diffX = x.size(-1) % self.stride\n        x = torch.nn.functional.pad(x, (diffX // 2, diffX - diffX // 2,\n                                        diffY // 2, diffY - diffY // 2))\n\n    # Unfold the input tensor for matrix multiplication\n    input_feature = torch.nn.functional.unfold(\n                                               x,\n                                               kernel_size = (self.kernel_size, self.kernel_size),\n                                               stride = self.stride,\n                                               padding = self.padding\n                                              )\n\n    # Resize sv_kernel_feature to match the input feature\n    sv_kernel = sv_kernel_feature.reshape(\n                                          1,\n                                          self.input_channels * self.kernel_size * self.kernel_size,\n                                          (x.size(-2) // self.stride) * (x.size(-1) // self.stride)\n                                         )\n\n    # Resize weight to match the input channels and kernel size\n    si_kernel = self.weight.reshape(\n                                    self.output_channels,\n                                    self.input_channels * self.kernel_size * self.kernel_size\n                                   )\n\n    # Apply spatially varying kernels\n    sv_feature = input_feature * sv_kernel\n\n    # Perform matrix multiplication\n    sa_output = torch.matmul(si_kernel, sv_feature).reshape(\n                                                            1, self.output_channels,\n                                                            (x.size(-2) // self.stride),\n                                                            (x.size(-1) // self.stride)\n                                                           )\n    return sa_output\n
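
A usage sketch under the shapes stated in the docstring (import path assumed, random tensors as stand-ins): with stride 1, the spatially varying kernel features carry input_channels * kernel_size * kernel_size channels and share the input's spatial size.

import torch
from odak.learn.models import spatially_adaptive_convolution  # assumed import path

layer = spatially_adaptive_convolution(input_channels = 2, output_channels = 2, kernel_size = 3)
x = torch.randn(1, 2, 32, 32)                   # (1, C, H, W) placeholder field
sv_kernel = torch.randn(1, 2 * 3 * 3, 32, 32)   # (1, C * kernel_size * kernel_size, H, W)
y = layer(x, sv_kernel)                         # spatially varying convolution output
print(y.shape)                                  # torch.Size([1, 2, 32, 32])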
"},{"location":"odak/learn_models/#odak.learn.models.models.spatially_adaptive_module","title":"spatially_adaptive_module","text":"

Bases: Module

A spatially adaptive module that combines learned spatially adaptive convolutions.

References

Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Ak\u015fit, \"Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions,\" SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24), December, 2024.

Source code in odak/learn/models/components.py
class spatially_adaptive_module(torch.nn.Module):\n    \"\"\"\n    A spatially adaptive module that combines learned spatially adaptive convolutions.\n\n    References\n    ----------\n\n    Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Ak\u015fit, \"Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions,\" SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24), December, 2024.\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels = 2,\n                 output_channels = 2,\n                 kernel_size = 3,\n                 stride = 1,\n                 padding = 1,\n                 bias = False,\n                 activation = torch.nn.LeakyReLU(0.2, inplace = True)\n                ):\n        \"\"\"\n        Initializes a spatially adaptive module.\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        output_channels : int\n                          Number of output channels.\n        kernel_size     : int\n                          Size of the convolution kernel.\n        stride          : int\n                          Stride of the convolution.\n        padding         : int\n                          Padding added to both sides of the input.\n        bias            : bool\n                          If True, includes a bias term in the convolution.\n        activation      : torch.nn\n                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n        \"\"\"\n        super(spatially_adaptive_module, self).__init__()\n        self.kernel_size = kernel_size\n        self.input_channels = input_channels\n        self.output_channels = output_channels\n        self.stride = stride\n        self.padding = padding\n        self.weight_output_channels = self.output_channels - 1\n        self.standard_convolution = torch.nn.Conv2d(\n                                                    in_channels = input_channels,\n                                                    out_channels = self.weight_output_channels,\n                                                    kernel_size = kernel_size,\n                                                    stride = stride,\n                                                    padding = padding,\n                                                    bias = bias\n                                                   )\n        self.weight = torch.nn.Parameter(data = self.standard_convolution.weight, requires_grad = True)\n        self.activation = activation\n\n\n    def forward(self, x, sv_kernel_feature):\n        \"\"\"\n        Forward pass for the spatially adaptive module.\n\n        Parameters\n        ----------\n        x                  : torch.tensor\n                            Input data tensor.\n                            Dimension: (1, C, H, W)\n        sv_kernel_feature   : torch.tensor\n                            Spatially varying kernel features.\n                            Dimension: (1, C_i * kernel_size * kernel_size, H, W)\n\n        Returns\n        -------\n        output             : torch.tensor\n                            Combined output tensor from standard and spatially adaptive convolutions.\n                            Dimension: (1, output_channels, H_out, W_out)\n        \"\"\"\n        # Pad input and sv_kernel_feature if necessary\n        if sv_kernel_feature.size(-1) * self.stride != x.size(-1) or 
sv_kernel_feature.size(\n                -2) * self.stride != x.size(-2):\n            diffY = sv_kernel_feature.size(-2) % self.stride\n            diffX = sv_kernel_feature.size(-1) % self.stride\n            sv_kernel_feature = torch.nn.functional.pad(sv_kernel_feature, (diffX // 2, diffX - diffX // 2,\n                                                                            diffY // 2, diffY - diffY // 2))\n            diffY = x.size(-2) % self.stride\n            diffX = x.size(-1) % self.stride\n            x = torch.nn.functional.pad(x, (diffX // 2, diffX - diffX // 2,\n                                            diffY // 2, diffY - diffY // 2))\n\n        # Unfold the input tensor for matrix multiplication\n        input_feature = torch.nn.functional.unfold(\n                                                   x,\n                                                   kernel_size = (self.kernel_size, self.kernel_size),\n                                                   stride = self.stride,\n                                                   padding = self.padding\n                                                  )\n\n        # Resize sv_kernel_feature to match the input feature\n        sv_kernel = sv_kernel_feature.reshape(\n                                              1,\n                                              self.input_channels * self.kernel_size * self.kernel_size,\n                                              (x.size(-2) // self.stride) * (x.size(-1) // self.stride)\n                                             )\n\n        # Apply sv_kernel to the input_feature\n        sv_feature = input_feature * sv_kernel\n\n        # Original spatially varying convolution output\n        sv_output = torch.sum(sv_feature, dim = 1).reshape(\n                                                           1,\n                                                            1,\n                                                            (x.size(-2) // self.stride),\n                                                            (x.size(-1) // self.stride)\n                                                           )\n\n        # Reshape weight for spatially adaptive convolution\n        si_kernel = self.weight.reshape(\n                                        self.weight_output_channels,\n                                        self.input_channels * self.kernel_size * self.kernel_size\n                                       )\n\n        # Apply si_kernel on sv convolution output\n        sa_output = torch.matmul(si_kernel, sv_feature).reshape(\n                                                                1, self.weight_output_channels,\n                                                                (x.size(-2) // self.stride),\n                                                                (x.size(-1) // self.stride)\n                                                               )\n\n        # Combine the outputs and apply activation function\n        output = self.activation(torch.cat((sv_output, sa_output), dim = 1))\n        return output\n
"},{"location":"odak/learn_models/#odak.learn.models.models.spatially_adaptive_module.__init__","title":"__init__(input_channels=2, output_channels=2, kernel_size=3, stride=1, padding=1, bias=False, activation=torch.nn.LeakyReLU(0.2, inplace=True))","text":"

Initializes a spatially adaptive module.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • output_channels (int, default: 2 ) \u2013
              Number of output channels.\n
  • kernel_size \u2013
              Size of the convolution kernel.\n
  • stride \u2013
              Stride of the convolution.\n
  • padding \u2013
              Padding added to both sides of the input.\n
  • bias \u2013
              If True, includes a bias term in the convolution.\n
  • activation \u2013
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels = 2,\n             output_channels = 2,\n             kernel_size = 3,\n             stride = 1,\n             padding = 1,\n             bias = False,\n             activation = torch.nn.LeakyReLU(0.2, inplace = True)\n            ):\n    \"\"\"\n    Initializes a spatially adaptive module.\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    output_channels : int\n                      Number of output channels.\n    kernel_size     : int\n                      Size of the convolution kernel.\n    stride          : int\n                      Stride of the convolution.\n    padding         : int\n                      Padding added to both sides of the input.\n    bias            : bool\n                      If True, includes a bias term in the convolution.\n    activation      : torch.nn\n                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n    \"\"\"\n    super(spatially_adaptive_module, self).__init__()\n    self.kernel_size = kernel_size\n    self.input_channels = input_channels\n    self.output_channels = output_channels\n    self.stride = stride\n    self.padding = padding\n    self.weight_output_channels = self.output_channels - 1\n    self.standard_convolution = torch.nn.Conv2d(\n                                                in_channels = input_channels,\n                                                out_channels = self.weight_output_channels,\n                                                kernel_size = kernel_size,\n                                                stride = stride,\n                                                padding = padding,\n                                                bias = bias\n                                               )\n    self.weight = torch.nn.Parameter(data = self.standard_convolution.weight, requires_grad = True)\n    self.activation = activation\n
"},{"location":"odak/learn_models/#odak.learn.models.models.spatially_adaptive_module.forward","title":"forward(x, sv_kernel_feature)","text":"

Forward pass for the spatially adaptive module.

Parameters:

  • x \u2013
                Input data tensor.\n            Dimension: (1, C, H, W)\n
  • sv_kernel_feature \u2013
                Spatially varying kernel features.\n            Dimension: (1, C_i * kernel_size * kernel_size, H, W)\n

Returns:

  • output ( tensor ) \u2013

    Combined output tensor from standard and spatially adaptive convolutions. Dimension: (1, output_channels, H_out, W_out)

Source code in odak/learn/models/components.py
def forward(self, x, sv_kernel_feature):\n    \"\"\"\n    Forward pass for the spatially adaptive module.\n\n    Parameters\n    ----------\n    x                  : torch.tensor\n                        Input data tensor.\n                        Dimension: (1, C, H, W)\n    sv_kernel_feature   : torch.tensor\n                        Spatially varying kernel features.\n                        Dimension: (1, C_i * kernel_size * kernel_size, H, W)\n\n    Returns\n    -------\n    output             : torch.tensor\n                        Combined output tensor from standard and spatially adaptive convolutions.\n                        Dimension: (1, output_channels, H_out, W_out)\n    \"\"\"\n    # Pad input and sv_kernel_feature if necessary\n    if sv_kernel_feature.size(-1) * self.stride != x.size(-1) or sv_kernel_feature.size(\n            -2) * self.stride != x.size(-2):\n        diffY = sv_kernel_feature.size(-2) % self.stride\n        diffX = sv_kernel_feature.size(-1) % self.stride\n        sv_kernel_feature = torch.nn.functional.pad(sv_kernel_feature, (diffX // 2, diffX - diffX // 2,\n                                                                        diffY // 2, diffY - diffY // 2))\n        diffY = x.size(-2) % self.stride\n        diffX = x.size(-1) % self.stride\n        x = torch.nn.functional.pad(x, (diffX // 2, diffX - diffX // 2,\n                                        diffY // 2, diffY - diffY // 2))\n\n    # Unfold the input tensor for matrix multiplication\n    input_feature = torch.nn.functional.unfold(\n                                               x,\n                                               kernel_size = (self.kernel_size, self.kernel_size),\n                                               stride = self.stride,\n                                               padding = self.padding\n                                              )\n\n    # Resize sv_kernel_feature to match the input feature\n    sv_kernel = sv_kernel_feature.reshape(\n                                          1,\n                                          self.input_channels * self.kernel_size * self.kernel_size,\n                                          (x.size(-2) // self.stride) * (x.size(-1) // self.stride)\n                                         )\n\n    # Apply sv_kernel to the input_feature\n    sv_feature = input_feature * sv_kernel\n\n    # Original spatially varying convolution output\n    sv_output = torch.sum(sv_feature, dim = 1).reshape(\n                                                       1,\n                                                        1,\n                                                        (x.size(-2) // self.stride),\n                                                        (x.size(-1) // self.stride)\n                                                       )\n\n    # Reshape weight for spatially adaptive convolution\n    si_kernel = self.weight.reshape(\n                                    self.weight_output_channels,\n                                    self.input_channels * self.kernel_size * self.kernel_size\n                                   )\n\n    # Apply si_kernel on sv convolution output\n    sa_output = torch.matmul(si_kernel, sv_feature).reshape(\n                                                            1, self.weight_output_channels,\n                                                            (x.size(-2) // self.stride),\n                                                            (x.size(-1) // self.stride)\n                               
                            )\n\n    # Combine the outputs and apply activation function\n    output = self.activation(torch.cat((sv_output, sa_output), dim = 1))\n    return output\n
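
A usage sketch (import path assumed, random tensors as stand-ins): the module concatenates one purely spatially varying output channel with output_channels - 1 spatially adaptive channels before applying the activation.

import torch
from odak.learn.models import spatially_adaptive_module  # assumed import path

module = spatially_adaptive_module(input_channels = 2, output_channels = 2, kernel_size = 3)
x = torch.randn(1, 2, 32, 32)                   # (1, C, H, W) placeholder field
sv_kernel = torch.randn(1, 2 * 3 * 3, 32, 32)   # (1, C * kernel_size * kernel_size, H, W)
y = module(x, sv_kernel)                        # concatenated sv and sa outputs after activation
print(y.shape)                                  # torch.Size([1, 2, 32, 32])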
"},{"location":"odak/learn_models/#odak.learn.models.models.spatially_adaptive_unet","title":"spatially_adaptive_unet","text":"

Bases: Module

Spatially varying U-Net model based on spatially adaptive convolution.

References

Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Ak\u015fit, \"Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions,\" SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24), December, 2024.

Source code in odak/learn/models/models.py
class spatially_adaptive_unet(torch.nn.Module):\n    \"\"\"\n    Spatially varying U-Net model based on spatially adaptive convolution.\n\n    References\n    ----------\n\n    Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Ak\u015fit, \"Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions,\" SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24), December, 2024.\n    \"\"\"\n    def __init__(\n                 self,\n                 depth=3,\n                 dimensions=8,\n                 input_channels=6,\n                 out_channels=6,\n                 kernel_size=3,\n                 bias=True,\n                 normalization=False,\n                 activation=torch.nn.LeakyReLU(0.2, inplace=True)\n                ):\n        \"\"\"\n        U-Net model.\n\n        Parameters\n        ----------\n        depth          : int\n                         Number of upsampling and downsampling layers.\n        dimensions     : int\n                         Number of dimensions.\n        input_channels : int\n                         Number of input channels.\n        out_channels   : int\n                         Number of output channels.\n        bias           : bool\n                         Set to True to let convolutional layers learn a bias term.\n        normalization  : bool\n                         If True, adds a Batch Normalization layer after the convolutional layer.\n        activation     : torch.nn\n                         Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).\n        \"\"\"\n        super().__init__()\n        self.depth = depth\n        self.out_channels = out_channels\n        self.inc = convolution_layer(\n                                     input_channels=input_channels,\n                                     output_channels=dimensions,\n                                     kernel_size=kernel_size,\n                                     bias=bias,\n                                     normalization=normalization,\n                                     activation=activation\n                                    )\n\n        self.encoder = torch.nn.ModuleList()\n        for i in range(self.depth + 1):  # Downsampling layers\n            down_in_channels = dimensions * (2 ** i)\n            down_out_channels = 2 * down_in_channels\n            pooling_layer = torch.nn.AvgPool2d(2)\n            double_convolution_layer = double_convolution(\n                                                          input_channels=down_in_channels,\n                                                          mid_channels=down_in_channels,\n                                                          output_channels=down_in_channels,\n                                                          kernel_size=kernel_size,\n                                                          bias=bias,\n                                                          normalization=normalization,\n                                                          activation=activation\n                                                         )\n            sam = spatially_adaptive_module(\n                                            input_channels=down_in_channels,\n                                            output_channels=down_out_channels,\n                                            kernel_size=kernel_size,\n                                            bias=bias,\n                                            
activation=activation\n                                           )\n            self.encoder.append(torch.nn.ModuleList([pooling_layer, double_convolution_layer, sam]))\n        self.global_feature_module = torch.nn.ModuleList()\n        double_convolution_layer = double_convolution(\n                                                      input_channels=dimensions * (2 ** (depth + 1)),\n                                                      mid_channels=dimensions * (2 ** (depth + 1)),\n                                                      output_channels=dimensions * (2 ** (depth + 1)),\n                                                      kernel_size=kernel_size,\n                                                      bias=bias,\n                                                      normalization=normalization,\n                                                      activation=activation\n                                                     )\n        global_feature_layer = global_feature_module(\n                                                     input_channels=dimensions * (2 ** (depth + 1)),\n                                                     mid_channels=dimensions * (2 ** (depth + 1)),\n                                                     output_channels=dimensions * (2 ** (depth + 1)),\n                                                     kernel_size=kernel_size,\n                                                     bias=bias,\n                                                     activation=torch.nn.LeakyReLU(0.2, inplace=True)\n                                                    )\n        self.global_feature_module.append(torch.nn.ModuleList([double_convolution_layer, global_feature_layer]))\n        self.decoder = torch.nn.ModuleList()\n        for i in range(depth, -1, -1):\n            up_in_channels = dimensions * (2 ** (i + 1))\n            up_mid_channels = up_in_channels // 2\n            if i == 0:\n                up_out_channels = self.out_channels\n                upsample_layer = upsample_convtranspose2d_layer(\n                                                                input_channels=up_in_channels,\n                                                                output_channels=up_mid_channels,\n                                                                kernel_size=2,\n                                                                stride=2,\n                                                                bias=bias,\n                                                               )\n                conv_layer = torch.nn.Sequential(\n                    convolution_layer(\n                                      input_channels=up_mid_channels,\n                                      output_channels=up_mid_channels,\n                                      kernel_size=kernel_size,\n                                      bias=bias,\n                                      normalization=normalization,\n                                      activation=activation,\n                                     ),\n                    convolution_layer(\n                                      input_channels=up_mid_channels,\n                                      output_channels=up_out_channels,\n                                      kernel_size=1,\n                                      bias=bias,\n                                      normalization=normalization,\n                                      activation=None,\n                                     )\n                )\n   
             self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))\n            else:\n                up_out_channels = up_in_channels // 2\n                upsample_layer = upsample_convtranspose2d_layer(\n                                                                input_channels=up_in_channels,\n                                                                output_channels=up_mid_channels,\n                                                                kernel_size=2,\n                                                                stride=2,\n                                                                bias=bias,\n                                                               )\n                conv_layer = double_convolution(\n                                                input_channels=up_mid_channels,\n                                                mid_channels=up_mid_channels,\n                                                output_channels=up_out_channels,\n                                                kernel_size=kernel_size,\n                                                bias=bias,\n                                                normalization=normalization,\n                                                activation=activation,\n                                               )\n                self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))\n\n\n    def forward(self, sv_kernel, field):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        sv_kernel : list of torch.tensor\n                    Learned spatially varying kernels.\n                    Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i),\n                    where C_i, H_i, and W_i represent the channel, height, and width\n                    of each feature at a certain scale.\n\n        field     : torch.tensor\n                    Input field data.\n                    Dimension: (1, 6, H, W)\n\n        Returns\n        -------\n        target_field : torch.tensor\n                       Estimated output.\n                       Dimension: (1, 6, H, W)\n        \"\"\"\n        x = self.inc(field)\n        downsampling_outputs = [x]\n        for i, down_layer in enumerate(self.encoder):\n            x_down = down_layer[0](downsampling_outputs[-1])\n            downsampling_outputs.append(x_down)\n            sam_output = down_layer[2](x_down + down_layer[1](x_down), sv_kernel[self.depth - i])\n            downsampling_outputs.append(sam_output)\n        global_feature = self.global_feature_module[0][0](downsampling_outputs[-1])\n        global_feature = self.global_feature_module[0][1](downsampling_outputs[-1], global_feature)\n        downsampling_outputs.append(global_feature)\n        x_up = downsampling_outputs[-1]\n        for i, up_layer in enumerate(self.decoder):\n            x_up = up_layer[0](x_up, downsampling_outputs[2 * (self.depth - i)])\n            x_up = up_layer[1](x_up)\n        result = x_up\n        return result\n
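
A sketch of how the model could be driven (import path assumed). The spatially varying kernels are ordered coarsest scale first, matching the sv_kernel[depth - i] indexing in forward, and would normally be produced by a separate kernel-generating network rather than the random tensors used here.

import torch
from odak.learn.models import spatially_adaptive_unet  # assumed import path

depth, dimensions, kernel_size = 3, 8, 3
net = spatially_adaptive_unet(depth = depth, dimensions = dimensions, input_channels = 6, out_channels = 6, kernel_size = kernel_size)
field = torch.randn(1, 6, 64, 64)               # (1, 6, H, W) placeholder input field
sv_kernel = [
             torch.randn(
                         1,
                         dimensions * (2 ** (depth - i)) * kernel_size * kernel_size,
                         64 // (2 ** (depth - i + 1)),
                         64 // (2 ** (depth - i + 1))
                        )
             for i in range(depth + 1)
            ]                                   # coarsest scale first
output = net(sv_kernel, field)
print(output.shape)                             # torch.Size([1, 6, 64, 64])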
"},{"location":"odak/learn_models/#odak.learn.models.models.spatially_adaptive_unet.__init__","title":"__init__(depth=3, dimensions=8, input_channels=6, out_channels=6, kernel_size=3, bias=True, normalization=False, activation=torch.nn.LeakyReLU(0.2, inplace=True))","text":"

U-Net model.

Parameters:

  • depth \u2013
             Number of upsampling and downsampling layers.\n
  • dimensions \u2013
             Base number of feature channels produced by the first convolutional layer; deeper layers scale their channel counts from this value.\n
  • input_channels (int, default: 6 ) \u2013
             Number of input channels.\n
  • out_channels \u2013
             Number of output channels.\n
  • bias \u2013
             Set to True to let convolutional layers learn a bias term.\n
  • normalization \u2013
             If True, adds a Batch Normalization layer after the convolutional layer.\n
  • activation \u2013
             Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).\n
Source code in odak/learn/models/models.py
def __init__(\n             self,\n             depth=3,\n             dimensions=8,\n             input_channels=6,\n             out_channels=6,\n             kernel_size=3,\n             bias=True,\n             normalization=False,\n             activation=torch.nn.LeakyReLU(0.2, inplace=True)\n            ):\n    \"\"\"\n    U-Net model.\n\n    Parameters\n    ----------\n    depth          : int\n                     Number of upsampling and downsampling layers.\n    dimensions     : int\n                     Number of dimensions.\n    input_channels : int\n                     Number of input channels.\n    out_channels   : int\n                     Number of output channels.\n    bias           : bool\n                     Set to True to let convolutional layers learn a bias term.\n    normalization  : bool\n                     If True, adds a Batch Normalization layer after the convolutional layer.\n    activation     : torch.nn\n                     Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).\n    \"\"\"\n    super().__init__()\n    self.depth = depth\n    self.out_channels = out_channels\n    self.inc = convolution_layer(\n                                 input_channels=input_channels,\n                                 output_channels=dimensions,\n                                 kernel_size=kernel_size,\n                                 bias=bias,\n                                 normalization=normalization,\n                                 activation=activation\n                                )\n\n    self.encoder = torch.nn.ModuleList()\n    for i in range(self.depth + 1):  # Downsampling layers\n        down_in_channels = dimensions * (2 ** i)\n        down_out_channels = 2 * down_in_channels\n        pooling_layer = torch.nn.AvgPool2d(2)\n        double_convolution_layer = double_convolution(\n                                                      input_channels=down_in_channels,\n                                                      mid_channels=down_in_channels,\n                                                      output_channels=down_in_channels,\n                                                      kernel_size=kernel_size,\n                                                      bias=bias,\n                                                      normalization=normalization,\n                                                      activation=activation\n                                                     )\n        sam = spatially_adaptive_module(\n                                        input_channels=down_in_channels,\n                                        output_channels=down_out_channels,\n                                        kernel_size=kernel_size,\n                                        bias=bias,\n                                        activation=activation\n                                       )\n        self.encoder.append(torch.nn.ModuleList([pooling_layer, double_convolution_layer, sam]))\n    self.global_feature_module = torch.nn.ModuleList()\n    double_convolution_layer = double_convolution(\n                                                  input_channels=dimensions * (2 ** (depth + 1)),\n                                                  mid_channels=dimensions * (2 ** (depth + 1)),\n                                                  output_channels=dimensions * (2 ** (depth + 1)),\n                                                  kernel_size=kernel_size,\n                                                  
bias=bias,\n                                                  normalization=normalization,\n                                                  activation=activation\n                                                 )\n    global_feature_layer = global_feature_module(\n                                                 input_channels=dimensions * (2 ** (depth + 1)),\n                                                 mid_channels=dimensions * (2 ** (depth + 1)),\n                                                 output_channels=dimensions * (2 ** (depth + 1)),\n                                                 kernel_size=kernel_size,\n                                                 bias=bias,\n                                                 activation=torch.nn.LeakyReLU(0.2, inplace=True)\n                                                )\n    self.global_feature_module.append(torch.nn.ModuleList([double_convolution_layer, global_feature_layer]))\n    self.decoder = torch.nn.ModuleList()\n    for i in range(depth, -1, -1):\n        up_in_channels = dimensions * (2 ** (i + 1))\n        up_mid_channels = up_in_channels // 2\n        if i == 0:\n            up_out_channels = self.out_channels\n            upsample_layer = upsample_convtranspose2d_layer(\n                                                            input_channels=up_in_channels,\n                                                            output_channels=up_mid_channels,\n                                                            kernel_size=2,\n                                                            stride=2,\n                                                            bias=bias,\n                                                           )\n            conv_layer = torch.nn.Sequential(\n                convolution_layer(\n                                  input_channels=up_mid_channels,\n                                  output_channels=up_mid_channels,\n                                  kernel_size=kernel_size,\n                                  bias=bias,\n                                  normalization=normalization,\n                                  activation=activation,\n                                 ),\n                convolution_layer(\n                                  input_channels=up_mid_channels,\n                                  output_channels=up_out_channels,\n                                  kernel_size=1,\n                                  bias=bias,\n                                  normalization=normalization,\n                                  activation=None,\n                                 )\n            )\n            self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))\n        else:\n            up_out_channels = up_in_channels // 2\n            upsample_layer = upsample_convtranspose2d_layer(\n                                                            input_channels=up_in_channels,\n                                                            output_channels=up_mid_channels,\n                                                            kernel_size=2,\n                                                            stride=2,\n                                                            bias=bias,\n                                                           )\n            conv_layer = double_convolution(\n                                            input_channels=up_mid_channels,\n                                            mid_channels=up_mid_channels,\n                           
                 output_channels=up_out_channels,\n                                            kernel_size=kernel_size,\n                                            bias=bias,\n                                            normalization=normalization,\n                                            activation=activation,\n                                           )\n            self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))\n
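For orientation, here is a minimal construction sketch. It assumes spatially_adaptive_unet can be imported from odak.learn.models (otherwise from odak.learn.models.models) and simply mirrors the channel arithmetic of the constructor above; it is not part of the library documentation.

import torch
from odak.learn.models import spatially_adaptive_unet  # assumed export path

model = spatially_adaptive_unet(depth = 3, dimensions = 8)  # defaults shown above
# Encoder level i operates on dimensions * 2 ** i channels and its spatially
# adaptive module expands them to dimensions * 2 ** (i + 1) channels.
for i in range(3 + 1):
    print(f'level {i}: {8 * 2 ** i} -> {8 * 2 ** (i + 1)} channels')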
"},{"location":"odak/learn_models/#odak.learn.models.models.spatially_adaptive_unet.forward","title":"forward(sv_kernel, field)","text":"

Forward model.

Parameters:

  • sv_kernel (list of torch.tensor) \u2013
        Learned spatially varying kernels.\n    Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i),\n    where C_i, H_i, and W_i represent the channel, height, and width\n    of each feature at a certain scale.\n
  • field \u2013
        Input field data.\n    Dimension: (1, 6, H, W)\n

Returns:

  • target_field ( tensor ) \u2013

    Estimated output. Dimension: (1, 6, H, W)

Source code in odak/learn/models/models.py
def forward(self, sv_kernel, field):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    sv_kernel : list of torch.tensor\n                Learned spatially varying kernels.\n                Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i),\n                where C_i, H_i, and W_i represent the channel, height, and width\n                of each feature at a certain scale.\n\n    field     : torch.tensor\n                Input field data.\n                Dimension: (1, 6, H, W)\n\n    Returns\n    -------\n    target_field : torch.tensor\n                   Estimated output.\n                   Dimension: (1, 6, H, W)\n    \"\"\"\n    x = self.inc(field)\n    downsampling_outputs = [x]\n    for i, down_layer in enumerate(self.encoder):\n        x_down = down_layer[0](downsampling_outputs[-1])\n        downsampling_outputs.append(x_down)\n        sam_output = down_layer[2](x_down + down_layer[1](x_down), sv_kernel[self.depth - i])\n        downsampling_outputs.append(sam_output)\n    global_feature = self.global_feature_module[0][0](downsampling_outputs[-1])\n    global_feature = self.global_feature_module[0][1](downsampling_outputs[-1], global_feature)\n    downsampling_outputs.append(global_feature)\n    x_up = downsampling_outputs[-1]\n    for i, up_layer in enumerate(self.decoder):\n        x_up = up_layer[0](x_up, downsampling_outputs[2 * (self.depth - i)])\n        x_up = up_layer[1](x_up)\n    result = x_up\n    return result\n
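As a rough guide to the sv_kernel contract stated in the docstring above, the kernel consumed at encoder level i should carry C_i * kernel_size * kernel_size channels with C_i = dimensions * 2 ** i, and sv_kernel[0] pairs with the deepest (coarsest) level. The helper below is purely illustrative, derived from the listed source; it is not a library function.

def expected_sv_kernel_channels(depth = 3, dimensions = 8, kernel_size = 3):
    # sv_kernel[0] pairs with encoder level i = depth, sv_kernel[-1] with level i = 0.
    channels = []
    for i in range(depth, -1, -1):
        c_i = dimensions * (2 ** i)
        channels.append(c_i * kernel_size * kernel_size)
    return channels

print(expected_sv_kernel_channels())  # [576, 288, 144, 72] with the defaults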
"},{"location":"odak/learn_models/#odak.learn.models.models.spatially_varying_kernel_generation_model","title":"spatially_varying_kernel_generation_model","text":"

Bases: Module

A spatially varying kernel generation model, revised from RSGUnet: https://github.com/MTLab/rsgunet_image_enhance.

Refer to: J. Huang, P. Zhu, M. Geng et al. Range Scaling Global U-Net for Perceptual Image Enhancement on Mobile Devices.

Source code in odak/learn/models/models.py
class spatially_varying_kernel_generation_model(torch.nn.Module):\n    \"\"\"\n    Spatially_varying_kernel_generation_model revised from RSGUnet:\n    https://github.com/MTLab/rsgunet_image_enhance.\n\n    Refer to:\n    J. Huang, P. Zhu, M. Geng et al. Range Scaling Global U-Net for Perceptual Image Enhancement on Mobile Devices.\n    \"\"\"\n\n    def __init__(\n                 self,\n                 depth = 3,\n                 dimensions = 8,\n                 input_channels = 7,\n                 kernel_size = 3,\n                 bias = True,\n                 normalization = False,\n                 activation = torch.nn.LeakyReLU(0.2, inplace = True)\n                ):\n        \"\"\"\n        U-Net model.\n\n        Parameters\n        ----------\n        depth          : int\n                         Number of upsampling and downsampling layers.\n        dimensions     : int\n                         Number of dimensions.\n        input_channels : int\n                         Number of input channels.\n        bias           : bool\n                         Set to True to let convolutional layers learn a bias term.\n        normalization  : bool\n                         If True, adds a Batch Normalization layer after the convolutional layer.\n        activation     : torch.nn\n                         Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).\n        \"\"\"\n        super().__init__()\n        self.depth = depth\n        self.inc = convolution_layer(\n                                     input_channels = input_channels,\n                                     output_channels = dimensions,\n                                     kernel_size = kernel_size,\n                                     bias = bias,\n                                     normalization = normalization,\n                                     activation = activation\n                                    )\n        self.encoder = torch.nn.ModuleList()\n        for i in range(depth + 1):  # downsampling layers\n            if i == 0:\n                in_channels = dimensions * (2 ** i)\n                out_channels = dimensions * (2 ** i)\n            elif i == depth:\n                in_channels = dimensions * (2 ** (i - 1))\n                out_channels = dimensions * (2 ** (i - 1))\n            else:\n                in_channels = dimensions * (2 ** (i - 1))\n                out_channels = 2 * in_channels\n            pooling_layer = torch.nn.AvgPool2d(2)\n            double_convolution_layer = double_convolution(\n                                                          input_channels = in_channels,\n                                                          mid_channels = in_channels,\n                                                          output_channels = out_channels,\n                                                          kernel_size = kernel_size,\n                                                          bias = bias,\n                                                          normalization = normalization,\n                                                          activation = activation\n                                                         )\n            self.encoder.append(pooling_layer)\n            self.encoder.append(double_convolution_layer)\n        self.spatially_varying_feature = torch.nn.ModuleList()  # for kernel generation\n        for i in range(depth, -1, -1):\n            if i == 1:\n                svf_in_channels = dimensions + 2 ** (self.depth + 
i) + 1\n            else:\n                svf_in_channels = 2 ** (self.depth + i) + 1\n            svf_out_channels = (2 ** (self.depth + i)) * (kernel_size * kernel_size)\n            svf_mid_channels = dimensions * (2 ** (self.depth - 1))\n            spatially_varying_kernel_generation = torch.nn.ModuleList()\n            for j in range(i, -1, -1):\n                pooling_layer = torch.nn.AvgPool2d(2 ** (j + 1))\n                spatially_varying_kernel_generation.append(pooling_layer)\n            kernel_generation_block = torch.nn.Sequential(\n                torch.nn.Conv2d(\n                                in_channels = svf_in_channels,\n                                out_channels = svf_mid_channels,\n                                kernel_size = kernel_size,\n                                padding = kernel_size // 2,\n                                bias = bias\n                               ),\n                activation,\n                torch.nn.Conv2d(\n                                in_channels = svf_mid_channels,\n                                out_channels = svf_mid_channels,\n                                kernel_size = kernel_size,\n                                padding = kernel_size // 2,\n                                bias = bias\n                               ),\n                activation,\n                torch.nn.Conv2d(\n                                in_channels = svf_mid_channels,\n                                out_channels = svf_out_channels,\n                                kernel_size = kernel_size,\n                                padding = kernel_size // 2,\n                                bias = bias\n                               ),\n            )\n            spatially_varying_kernel_generation.append(kernel_generation_block)\n            self.spatially_varying_feature.append(spatially_varying_kernel_generation)\n        self.decoder = torch.nn.ModuleList()\n        global_feature_layer = global_feature_module(  # global feature layer\n                                                     input_channels = dimensions * (2 ** (depth - 1)),\n                                                     mid_channels = dimensions * (2 ** (depth - 1)),\n                                                     output_channels = dimensions * (2 ** (depth - 1)),\n                                                     kernel_size = kernel_size,\n                                                     bias = bias,\n                                                     activation = torch.nn.LeakyReLU(0.2, inplace = True)\n                                                    )\n        self.decoder.append(global_feature_layer)\n        for i in range(depth, 0, -1):\n            if i == 2:\n                up_in_channels = (dimensions // 2) * (2 ** i)\n                up_out_channels = up_in_channels\n                up_mid_channels = up_in_channels\n            elif i == 1:\n                up_in_channels = dimensions * 2\n                up_out_channels = dimensions\n                up_mid_channels = up_out_channels\n            else:\n                up_in_channels = (dimensions // 2) * (2 ** i)\n                up_out_channels = up_in_channels // 2\n                up_mid_channels = up_in_channels\n            upsample_layer = upsample_convtranspose2d_layer(\n                                                            input_channels = up_in_channels,\n                                                            output_channels = up_mid_channels,\n                         
                                   kernel_size = 2,\n                                                            stride = 2,\n                                                            bias = bias,\n                                                           )\n            conv_layer = double_convolution(\n                                            input_channels = up_mid_channels,\n                                            output_channels = up_out_channels,\n                                            kernel_size = kernel_size,\n                                            bias = bias,\n                                            normalization = normalization,\n                                            activation = activation,\n                                           )\n            self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))\n\n\n    def forward(self, focal_surface, field):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        focal_surface : torch.tensor\n                        Input focal surface data.\n                        Dimension: (1, 1, H, W)\n\n        field         : torch.tensor\n                        Input field data.\n                        Dimension: (1, 6, H, W)\n\n        Returns\n        -------\n        sv_kernel : list of torch.tensor\n                    Learned spatially varying kernels.\n                    Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i),\n                    where C_i, H_i, and W_i represent the channel, height, and width\n                    of each feature at a certain scale.\n        \"\"\"\n        x = self.inc(torch.cat((focal_surface, field), dim = 1))\n        downsampling_outputs = [focal_surface]\n        downsampling_outputs.append(x)\n        for i, down_layer in enumerate(self.encoder):\n            x_down = down_layer(downsampling_outputs[-1])\n            downsampling_outputs.append(x_down)\n        sv_kernels = []\n        for i, (up_layer, svf_layer) in enumerate(zip(self.decoder, self.spatially_varying_feature)):\n            if i == 0:\n                global_feature = up_layer(downsampling_outputs[-2], downsampling_outputs[-1])\n                downsampling_outputs[-1] = global_feature\n                sv_feature = [global_feature, downsampling_outputs[0]]\n                for j in range(self.depth - i + 1):\n                    sv_feature[1] = svf_layer[self.depth - i](sv_feature[1])\n                    if j > 0:\n                        sv_feature.append(svf_layer[j](downsampling_outputs[2 * j]))\n                sv_feature = [sv_feature[0], sv_feature[1], sv_feature[4], sv_feature[2],\n                              sv_feature[3]]\n                sv_kernel = svf_layer[-1](torch.cat(sv_feature, dim = 1))\n                sv_kernels.append(sv_kernel)\n            else:\n                x_up = up_layer[0](downsampling_outputs[-1],\n                                   downsampling_outputs[2 * (self.depth + 1 - i) + 1])\n                x_up = up_layer[1](x_up)\n                downsampling_outputs[-1] = x_up\n                sv_feature = [x_up, downsampling_outputs[0]]\n                for j in range(self.depth - i + 1):\n                    sv_feature[1] = svf_layer[self.depth - i](sv_feature[1])\n                    if j > 0:\n                        sv_feature.append(svf_layer[j](downsampling_outputs[2 * j]))\n                if i == 1:\n                    sv_feature = [sv_feature[0], sv_feature[1], sv_feature[3], 
sv_feature[2]]\n                sv_kernel = svf_layer[-1](torch.cat(sv_feature, dim = 1))\n                sv_kernels.append(sv_kernel)\n        return sv_kernels\n
"},{"location":"odak/learn_models/#odak.learn.models.models.spatially_varying_kernel_generation_model.__init__","title":"__init__(depth=3, dimensions=8, input_channels=7, kernel_size=3, bias=True, normalization=False, activation=torch.nn.LeakyReLU(0.2, inplace=True))","text":"

U-Net model.

Parameters:

  • depth \u2013
             Number of upsampling and downsampling layers.\n
  • dimensions \u2013
             Base number of feature channels produced by the first convolutional layer; deeper layers scale their channel counts from this value.\n
  • input_channels (int, default: 7 ) \u2013
             Number of input channels.\n
  • bias \u2013
             Set to True to let convolutional layers learn a bias term.\n
  • normalization \u2013
             If True, adds a Batch Normalization layer after the convolutional layer.\n
  • activation \u2013
             Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).\n
Source code in odak/learn/models/models.py
def __init__(\n             self,\n             depth = 3,\n             dimensions = 8,\n             input_channels = 7,\n             kernel_size = 3,\n             bias = True,\n             normalization = False,\n             activation = torch.nn.LeakyReLU(0.2, inplace = True)\n            ):\n    \"\"\"\n    U-Net model.\n\n    Parameters\n    ----------\n    depth          : int\n                     Number of upsampling and downsampling layers.\n    dimensions     : int\n                     Number of dimensions.\n    input_channels : int\n                     Number of input channels.\n    bias           : bool\n                     Set to True to let convolutional layers learn a bias term.\n    normalization  : bool\n                     If True, adds a Batch Normalization layer after the convolutional layer.\n    activation     : torch.nn\n                     Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).\n    \"\"\"\n    super().__init__()\n    self.depth = depth\n    self.inc = convolution_layer(\n                                 input_channels = input_channels,\n                                 output_channels = dimensions,\n                                 kernel_size = kernel_size,\n                                 bias = bias,\n                                 normalization = normalization,\n                                 activation = activation\n                                )\n    self.encoder = torch.nn.ModuleList()\n    for i in range(depth + 1):  # downsampling layers\n        if i == 0:\n            in_channels = dimensions * (2 ** i)\n            out_channels = dimensions * (2 ** i)\n        elif i == depth:\n            in_channels = dimensions * (2 ** (i - 1))\n            out_channels = dimensions * (2 ** (i - 1))\n        else:\n            in_channels = dimensions * (2 ** (i - 1))\n            out_channels = 2 * in_channels\n        pooling_layer = torch.nn.AvgPool2d(2)\n        double_convolution_layer = double_convolution(\n                                                      input_channels = in_channels,\n                                                      mid_channels = in_channels,\n                                                      output_channels = out_channels,\n                                                      kernel_size = kernel_size,\n                                                      bias = bias,\n                                                      normalization = normalization,\n                                                      activation = activation\n                                                     )\n        self.encoder.append(pooling_layer)\n        self.encoder.append(double_convolution_layer)\n    self.spatially_varying_feature = torch.nn.ModuleList()  # for kernel generation\n    for i in range(depth, -1, -1):\n        if i == 1:\n            svf_in_channels = dimensions + 2 ** (self.depth + i) + 1\n        else:\n            svf_in_channels = 2 ** (self.depth + i) + 1\n        svf_out_channels = (2 ** (self.depth + i)) * (kernel_size * kernel_size)\n        svf_mid_channels = dimensions * (2 ** (self.depth - 1))\n        spatially_varying_kernel_generation = torch.nn.ModuleList()\n        for j in range(i, -1, -1):\n            pooling_layer = torch.nn.AvgPool2d(2 ** (j + 1))\n            spatially_varying_kernel_generation.append(pooling_layer)\n        kernel_generation_block = torch.nn.Sequential(\n            torch.nn.Conv2d(\n                            in_channels = 
svf_in_channels,\n                            out_channels = svf_mid_channels,\n                            kernel_size = kernel_size,\n                            padding = kernel_size // 2,\n                            bias = bias\n                           ),\n            activation,\n            torch.nn.Conv2d(\n                            in_channels = svf_mid_channels,\n                            out_channels = svf_mid_channels,\n                            kernel_size = kernel_size,\n                            padding = kernel_size // 2,\n                            bias = bias\n                           ),\n            activation,\n            torch.nn.Conv2d(\n                            in_channels = svf_mid_channels,\n                            out_channels = svf_out_channels,\n                            kernel_size = kernel_size,\n                            padding = kernel_size // 2,\n                            bias = bias\n                           ),\n        )\n        spatially_varying_kernel_generation.append(kernel_generation_block)\n        self.spatially_varying_feature.append(spatially_varying_kernel_generation)\n    self.decoder = torch.nn.ModuleList()\n    global_feature_layer = global_feature_module(  # global feature layer\n                                                 input_channels = dimensions * (2 ** (depth - 1)),\n                                                 mid_channels = dimensions * (2 ** (depth - 1)),\n                                                 output_channels = dimensions * (2 ** (depth - 1)),\n                                                 kernel_size = kernel_size,\n                                                 bias = bias,\n                                                 activation = torch.nn.LeakyReLU(0.2, inplace = True)\n                                                )\n    self.decoder.append(global_feature_layer)\n    for i in range(depth, 0, -1):\n        if i == 2:\n            up_in_channels = (dimensions // 2) * (2 ** i)\n            up_out_channels = up_in_channels\n            up_mid_channels = up_in_channels\n        elif i == 1:\n            up_in_channels = dimensions * 2\n            up_out_channels = dimensions\n            up_mid_channels = up_out_channels\n        else:\n            up_in_channels = (dimensions // 2) * (2 ** i)\n            up_out_channels = up_in_channels // 2\n            up_mid_channels = up_in_channels\n        upsample_layer = upsample_convtranspose2d_layer(\n                                                        input_channels = up_in_channels,\n                                                        output_channels = up_mid_channels,\n                                                        kernel_size = 2,\n                                                        stride = 2,\n                                                        bias = bias,\n                                                       )\n        conv_layer = double_convolution(\n                                        input_channels = up_mid_channels,\n                                        output_channels = up_out_channels,\n                                        kernel_size = kernel_size,\n                                        bias = bias,\n                                        normalization = normalization,\n                                        activation = activation,\n                                       )\n        self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))\n
"},{"location":"odak/learn_models/#odak.learn.models.models.spatially_varying_kernel_generation_model.forward","title":"forward(focal_surface, field)","text":"

Forward model.

Parameters:

  • focal_surface (tensor) \u2013
            Input focal surface data.\n        Dimension: (1, 1, H, W)\n
  • field \u2013
            Input field data.\n        Dimension: (1, 6, H, W)\n

Returns:

  • sv_kernel ( list of torch.tensor ) \u2013

    Learned spatially varying kernels. Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i), where C_i, H_i, and W_i represent the channel, height, and width of each feature at a certain scale.

Source code in odak/learn/models/models.py
def forward(self, focal_surface, field):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    focal_surface : torch.tensor\n                    Input focal surface data.\n                    Dimension: (1, 1, H, W)\n\n    field         : torch.tensor\n                    Input field data.\n                    Dimension: (1, 6, H, W)\n\n    Returns\n    -------\n    sv_kernel : list of torch.tensor\n                Learned spatially varying kernels.\n                Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i),\n                where C_i, H_i, and W_i represent the channel, height, and width\n                of each feature at a certain scale.\n    \"\"\"\n    x = self.inc(torch.cat((focal_surface, field), dim = 1))\n    downsampling_outputs = [focal_surface]\n    downsampling_outputs.append(x)\n    for i, down_layer in enumerate(self.encoder):\n        x_down = down_layer(downsampling_outputs[-1])\n        downsampling_outputs.append(x_down)\n    sv_kernels = []\n    for i, (up_layer, svf_layer) in enumerate(zip(self.decoder, self.spatially_varying_feature)):\n        if i == 0:\n            global_feature = up_layer(downsampling_outputs[-2], downsampling_outputs[-1])\n            downsampling_outputs[-1] = global_feature\n            sv_feature = [global_feature, downsampling_outputs[0]]\n            for j in range(self.depth - i + 1):\n                sv_feature[1] = svf_layer[self.depth - i](sv_feature[1])\n                if j > 0:\n                    sv_feature.append(svf_layer[j](downsampling_outputs[2 * j]))\n            sv_feature = [sv_feature[0], sv_feature[1], sv_feature[4], sv_feature[2],\n                          sv_feature[3]]\n            sv_kernel = svf_layer[-1](torch.cat(sv_feature, dim = 1))\n            sv_kernels.append(sv_kernel)\n        else:\n            x_up = up_layer[0](downsampling_outputs[-1],\n                               downsampling_outputs[2 * (self.depth + 1 - i) + 1])\n            x_up = up_layer[1](x_up)\n            downsampling_outputs[-1] = x_up\n            sv_feature = [x_up, downsampling_outputs[0]]\n            for j in range(self.depth - i + 1):\n                sv_feature[1] = svf_layer[self.depth - i](sv_feature[1])\n                if j > 0:\n                    sv_feature.append(svf_layer[j](downsampling_outputs[2 * j]))\n            if i == 1:\n                sv_feature = [sv_feature[0], sv_feature[1], sv_feature[3], sv_feature[2]]\n            sv_kernel = svf_layer[-1](torch.cat(sv_feature, dim = 1))\n            sv_kernels.append(sv_kernel)\n    return sv_kernels\n
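Putting the two models together, a minimal end-to-end sketch could look as follows. It assumes both classes are exported from odak.learn.models, that the default depth and dimensions of both models are left untouched (so their kernel shapes agree), and that the spatial resolution is divisible by 2 ** (depth + 1); it is a sketch, not an official example.

import torch
from odak.learn.models import spatially_varying_kernel_generation_model, spatially_adaptive_unet  # assumed export paths

kernel_model = spatially_varying_kernel_generation_model()  # expects 1 + 6 = 7 input channels
field_model = spatially_adaptive_unet()                     # expects 6-channel fields

focal_surface = torch.rand(1, 1, 256, 256)  # (1, 1, H, W)
field = torch.rand(1, 6, 256, 256)          # (1, 6, H, W)

sv_kernel = kernel_model(focal_surface, field)  # list of per-scale kernel tensors
target_field = field_model(sv_kernel, field)    # (1, 6, H, W)
print([k.shape for k in sv_kernel], target_field.shape)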
"},{"location":"odak/learn_models/#odak.learn.models.models.unet","title":"unet","text":"

Bases: Module

A U-Net model, heavily inspired by https://github.com/milesial/Pytorch-UNet/tree/master/unet. For more, see Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. \"U-net: Convolutional networks for biomedical image segmentation.\" Medical Image Computing and Computer-Assisted Intervention\u2013MICCAI 2015: 18th International Conference, Munich, Germany, October 5-9, 2015, Proceedings, Part III 18. Springer International Publishing, 2015.

Source code in odak/learn/models/models.py
class unet(torch.nn.Module):\n    \"\"\"\n    A U-Net model, heavily inspired from `https://github.com/milesial/Pytorch-UNet/tree/master/unet` and more can be read from Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. \"U-net: Convolutional networks for biomedical image segmentation.\" Medical Image Computing and Computer-Assisted Intervention\u2013MICCAI 2015: 18th International Conference, Munich, Germany, October 5-9, 2015, Proceedings, Part III 18. Springer International Publishing, 2015.\n    \"\"\"\n\n    def __init__(\n                 self, \n                 depth = 4,\n                 dimensions = 64, \n                 input_channels = 2, \n                 output_channels = 1, \n                 bilinear = False,\n                 kernel_size = 3,\n                 bias = False,\n                 activation = torch.nn.ReLU(inplace = True),\n                ):\n        \"\"\"\n        U-Net model.\n\n        Parameters\n        ----------\n        depth             : int\n                            Number of upsampling and downsampling\n        dimensions        : int\n                            Number of dimensions.\n        input_channels    : int\n                            Number of input channels.\n        output_channels   : int\n                            Number of output channels.\n        bilinear          : bool\n                            Uses bilinear upsampling in upsampling layers when set True.\n        bias              : bool\n                            Set True to let convolutional layers learn a bias term.\n        activation        : torch.nn\n                            Non-linear activation layer to be used (e.g., torch.nn.ReLU(), torch.nn.Sigmoid().\n        \"\"\"\n        super(unet, self).__init__()\n        self.inc = double_convolution(\n                                      input_channels = input_channels,\n                                      mid_channels = dimensions,\n                                      output_channels = dimensions,\n                                      kernel_size = kernel_size,\n                                      bias = bias,\n                                      activation = activation\n                                     )      \n\n        self.downsampling_layers = torch.nn.ModuleList()\n        self.upsampling_layers = torch.nn.ModuleList()\n        for i in range(depth): # downsampling layers\n            in_channels = dimensions * (2 ** i)\n            out_channels = dimensions * (2 ** (i + 1))\n            down_layer = downsample_layer(in_channels,\n                                            out_channels,\n                                            kernel_size=kernel_size,\n                                            bias=bias,\n                                            activation=activation\n                                            )\n            self.downsampling_layers.append(down_layer)      \n\n        for i in range(depth - 1, -1, -1):  # upsampling layers\n            up_in_channels = dimensions * (2 ** (i + 1))  \n            up_out_channels = dimensions * (2 ** i) \n            up_layer = upsample_layer(up_in_channels, up_out_channels, kernel_size=kernel_size, bias=bias, activation=activation, bilinear=bilinear)\n            self.upsampling_layers.append(up_layer)\n        self.outc = torch.nn.Conv2d(\n                                    dimensions, \n                                    output_channels,\n                                    kernel_size = kernel_size,\n                    
                padding = kernel_size // 2,\n                                    bias = bias\n                                   )\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x             : torch.tensor\n                        Input data.\n\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Estimated output.      \n        \"\"\"\n        downsampling_outputs = [self.inc(x)]\n        for down_layer in self.downsampling_layers:\n            x_down = down_layer(downsampling_outputs[-1])\n            downsampling_outputs.append(x_down)\n        x_up = downsampling_outputs[-1]\n        for i, up_layer in enumerate((self.upsampling_layers)):\n            x_up = up_layer(x_up, downsampling_outputs[-(i + 2)])       \n        result = self.outc(x_up)\n        return result\n
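A minimal usage sketch for this class follows; the import path (odak.learn.models) is an assumption, and the input resolution only needs to be divisible by 2 ** depth.

import torch
from odak.learn.models import unet  # assumed export path

model = unet(depth = 4, dimensions = 64, input_channels = 2, output_channels = 1)
x = torch.rand(1, 2, 128, 128)  # H and W divisible by 2 ** depth = 16
y = model(x)
print(y.shape)  # expected: torch.Size([1, 1, 128, 128])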
"},{"location":"odak/learn_models/#odak.learn.models.models.unet.__init__","title":"__init__(depth=4, dimensions=64, input_channels=2, output_channels=1, bilinear=False, kernel_size=3, bias=False, activation=torch.nn.ReLU(inplace=True))","text":"

U-Net model.

Parameters:

  • depth \u2013
                 Number of upsampling and downsampling layers.\n
  • dimensions \u2013
                 Base number of feature channels produced by the first convolutional layer; deeper layers scale their channel counts from this value.\n
  • input_channels \u2013
                Number of input channels.\n
  • output_channels \u2013
                Number of output channels.\n
  • bilinear \u2013
                Uses bilinear upsampling in upsampling layers when set True.\n
  • bias \u2013
                Set True to let convolutional layers learn a bias term.\n
  • activation \u2013
                 Non-linear activation layer to be used (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).\n
Source code in odak/learn/models/models.py
def __init__(\n             self, \n             depth = 4,\n             dimensions = 64, \n             input_channels = 2, \n             output_channels = 1, \n             bilinear = False,\n             kernel_size = 3,\n             bias = False,\n             activation = torch.nn.ReLU(inplace = True),\n            ):\n    \"\"\"\n    U-Net model.\n\n    Parameters\n    ----------\n    depth             : int\n                        Number of upsampling and downsampling\n    dimensions        : int\n                        Number of dimensions.\n    input_channels    : int\n                        Number of input channels.\n    output_channels   : int\n                        Number of output channels.\n    bilinear          : bool\n                        Uses bilinear upsampling in upsampling layers when set True.\n    bias              : bool\n                        Set True to let convolutional layers learn a bias term.\n    activation        : torch.nn\n                        Non-linear activation layer to be used (e.g., torch.nn.ReLU(), torch.nn.Sigmoid().\n    \"\"\"\n    super(unet, self).__init__()\n    self.inc = double_convolution(\n                                  input_channels = input_channels,\n                                  mid_channels = dimensions,\n                                  output_channels = dimensions,\n                                  kernel_size = kernel_size,\n                                  bias = bias,\n                                  activation = activation\n                                 )      \n\n    self.downsampling_layers = torch.nn.ModuleList()\n    self.upsampling_layers = torch.nn.ModuleList()\n    for i in range(depth): # downsampling layers\n        in_channels = dimensions * (2 ** i)\n        out_channels = dimensions * (2 ** (i + 1))\n        down_layer = downsample_layer(in_channels,\n                                        out_channels,\n                                        kernel_size=kernel_size,\n                                        bias=bias,\n                                        activation=activation\n                                        )\n        self.downsampling_layers.append(down_layer)      \n\n    for i in range(depth - 1, -1, -1):  # upsampling layers\n        up_in_channels = dimensions * (2 ** (i + 1))  \n        up_out_channels = dimensions * (2 ** i) \n        up_layer = upsample_layer(up_in_channels, up_out_channels, kernel_size=kernel_size, bias=bias, activation=activation, bilinear=bilinear)\n        self.upsampling_layers.append(up_layer)\n    self.outc = torch.nn.Conv2d(\n                                dimensions, \n                                output_channels,\n                                kernel_size = kernel_size,\n                                padding = kernel_size // 2,\n                                bias = bias\n                               )\n
"},{"location":"odak/learn_models/#odak.learn.models.models.unet.forward","title":"forward(x)","text":"

Forward model.

Parameters:

  • x \u2013
            Input data.\n

Returns:

  • result ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/models.py
def forward(self, x):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x             : torch.tensor\n                    Input data.\n\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Estimated output.      \n    \"\"\"\n    downsampling_outputs = [self.inc(x)]\n    for down_layer in self.downsampling_layers:\n        x_down = down_layer(downsampling_outputs[-1])\n        downsampling_outputs.append(x_down)\n    x_up = downsampling_outputs[-1]\n    for i, up_layer in enumerate((self.upsampling_layers)):\n        x_up = up_layer(x_up, downsampling_outputs[-(i + 2)])       \n    result = self.outc(x_up)\n    return result\n
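The skip connections are read straight from the stored downsampling outputs. The tiny loop below is purely illustrative (not library code) and only prints which list entry each upsampling layer consumes.

depth = 4  # matches the default above
# downsampling_outputs holds depth + 1 entries: the inc output followed by each downsampling output.
for i in range(depth):
    print(f'upsampling layer {i} is skip-connected to downsampling_outputs[{-(i + 2)}]')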
"},{"location":"odak/learn_models/#odak.learn.models.models.upsample_convtranspose2d_layer","title":"upsample_convtranspose2d_layer","text":"

Bases: Module

An upsampling convtranspose2d layer.

Source code in odak/learn/models/components.py
class upsample_convtranspose2d_layer(torch.nn.Module):\n    \"\"\"\n    An upsampling convtranspose2d layer.\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels,\n                 output_channels,\n                 kernel_size = 2,\n                 stride = 2,\n                 bias = False,\n                ):\n        \"\"\"\n        A downscaling component with a double convolution.\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        output_channels : int\n                          Number of output channels.\n        kernel_size     : int\n                          Kernel size.\n        bias            : bool\n                          Set to True to let convolutional layers have bias term.\n        \"\"\"\n        super().__init__()\n        self.up = torch.nn.ConvTranspose2d(\n                                           in_channels = input_channels,\n                                           out_channels = output_channels,\n                                           bias = bias,\n                                           kernel_size = kernel_size,\n                                           stride = stride\n                                          )\n\n    def forward(self, x1, x2):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x1             : torch.tensor\n                         First input data.\n        x2             : torch.tensor\n                         Second input data.\n\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Result of the forward operation\n        \"\"\"\n        x1 = self.up(x1)\n        diffY = x2.size()[2] - x1.size()[2]\n        diffX = x2.size()[3] - x1.size()[3]\n        x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,\n                                          diffY // 2, diffY - diffY // 2])\n        result = x1 + x2\n        return result\n
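The forward pass upsamples x1 with a transposed convolution, pads it to x2's spatial size and adds the two. A short sketch, assuming the class is exported from odak.learn.models:

import torch
from odak.learn.models import upsample_convtranspose2d_layer  # assumed export path

layer = upsample_convtranspose2d_layer(input_channels = 32, output_channels = 16)
x1 = torch.rand(1, 32, 16, 16)  # low-resolution feature map
x2 = torch.rand(1, 16, 33, 33)  # skip connection, deliberately one pixel larger than 32 x 32
out = layer(x1, x2)             # x1 -> 32 x 32 via ConvTranspose2d, padded to 33 x 33, then added to x2
print(out.shape)                # expected: torch.Size([1, 16, 33, 33])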
"},{"location":"odak/learn_models/#odak.learn.models.models.upsample_convtranspose2d_layer.__init__","title":"__init__(input_channels, output_channels, kernel_size=2, stride=2, bias=False)","text":"

An upscaling component using a transposed convolution.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • output_channels (int) \u2013
              Number of output channels.\n
  • kernel_size \u2013
              Kernel size.\n
  • bias \u2013
              Set to True to let convolutional layers have bias term.\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels,\n             output_channels,\n             kernel_size = 2,\n             stride = 2,\n             bias = False,\n            ):\n    \"\"\"\n    A downscaling component with a double convolution.\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    output_channels : int\n                      Number of output channels.\n    kernel_size     : int\n                      Kernel size.\n    bias            : bool\n                      Set to True to let convolutional layers have bias term.\n    \"\"\"\n    super().__init__()\n    self.up = torch.nn.ConvTranspose2d(\n                                       in_channels = input_channels,\n                                       out_channels = output_channels,\n                                       bias = bias,\n                                       kernel_size = kernel_size,\n                                       stride = stride\n                                      )\n
"},{"location":"odak/learn_models/#odak.learn.models.models.upsample_convtranspose2d_layer.forward","title":"forward(x1, x2)","text":"

Forward model.

Parameters:

  • x1 \u2013
             First input data.\n
  • x2 \u2013
             Second input data.\n

Returns:

  • result ( tensor ) \u2013

    Result of the forward operation

Source code in odak/learn/models/components.py
def forward(self, x1, x2):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x1             : torch.tensor\n                     First input data.\n    x2             : torch.tensor\n                     Second input data.\n\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Result of the forward operation\n    \"\"\"\n    x1 = self.up(x1)\n    diffY = x2.size()[2] - x1.size()[2]\n    diffX = x2.size()[3] - x1.size()[3]\n    x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,\n                                      diffY // 2, diffY - diffY // 2])\n    result = x1 + x2\n    return result\n
"},{"location":"odak/learn_models/#odak.learn.models.models.upsample_layer","title":"upsample_layer","text":"

Bases: Module

An upsampling convolutional layer.

Source code in odak/learn/models/components.py
class upsample_layer(torch.nn.Module):\n    \"\"\"\n    An upsampling convolutional layer.\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels,\n                 output_channels,\n                 kernel_size = 3,\n                 bias = False,\n                 activation = torch.nn.ReLU(),\n                 bilinear = True\n                ):\n        \"\"\"\n        A downscaling component with a double convolution.\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        output_channels : int\n                          Number of output channels.\n        kernel_size     : int\n                          Kernel size.\n        bias            : bool \n                          Set to True to let convolutional layers have bias term.\n        activation      : torch.nn\n                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n        bilinear        : bool\n                          If set to True, bilinear sampling is used.\n        \"\"\"\n        super(upsample_layer, self).__init__()\n        if bilinear:\n            self.up = torch.nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = True)\n            self.conv = double_convolution(\n                                           input_channels = input_channels + output_channels,\n                                           mid_channels = input_channels // 2,\n                                           output_channels = output_channels,\n                                           kernel_size = kernel_size,\n                                           bias = bias,\n                                           activation = activation\n                                          )\n        else:\n            self.up = torch.nn.ConvTranspose2d(input_channels , input_channels // 2, kernel_size = 2, stride = 2)\n            self.conv = double_convolution(\n                                           input_channels = input_channels,\n                                           mid_channels = output_channels,\n                                           output_channels = output_channels,\n                                           kernel_size = kernel_size,\n                                           bias = bias,\n                                           activation = activation\n                                          )\n\n\n    def forward(self, x1, x2):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x1             : torch.tensor\n                         First input data.\n        x2             : torch.tensor\n                         Second input data.\n\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Result of the forward operation\n        \"\"\" \n        x1 = self.up(x1)\n        diffY = x2.size()[2] - x1.size()[2]\n        diffX = x2.size()[3] - x1.size()[3]\n        x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,\n                                          diffY // 2, diffY - diffY // 2])\n        x = torch.cat([x2, x1], dim = 1)\n        result = self.conv(x)\n        return result\n
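In bilinear mode the layer upsamples x1, pads it to x2's size, concatenates the two and applies a double convolution; with bilinear set to False, a transposed convolution halves the channel count of x1 before the same concatenation. A short sketch, assuming the class is exported from odak.learn.models and that x1 carries input_channels and x2 carries output_channels channels:

import torch
from odak.learn.models import upsample_layer  # assumed export path

layer = upsample_layer(input_channels = 64, output_channels = 32, bilinear = True)
x1 = torch.rand(1, 64, 16, 16)  # feature map to be upsampled
x2 = torch.rand(1, 32, 32, 32)  # skip connection at the target resolution
out = layer(x1, x2)             # upsample, pad, concatenate, double convolution
print(out.shape)                # expected: torch.Size([1, 32, 32, 32])

The same shapes also work with bilinear = False, since the transposed convolution maps input_channels to input_channels // 2 before concatenation.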
"},{"location":"odak/learn_models/#odak.learn.models.models.upsample_layer.__init__","title":"__init__(input_channels, output_channels, kernel_size=3, bias=False, activation=torch.nn.ReLU(), bilinear=True)","text":"

An upscaling component with a double convolution.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • output_channels (int) \u2013
              Number of output channels.\n
  • kernel_size \u2013
              Kernel size.\n
  • bias \u2013
              Set to True to let convolutional layers have bias term.\n
  • activation \u2013
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n
  • bilinear \u2013
              If set to True, bilinear sampling is used.\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels,\n             output_channels,\n             kernel_size = 3,\n             bias = False,\n             activation = torch.nn.ReLU(),\n             bilinear = True\n            ):\n    \"\"\"\n    A downscaling component with a double convolution.\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    output_channels : int\n                      Number of output channels.\n    kernel_size     : int\n                      Kernel size.\n    bias            : bool \n                      Set to True to let convolutional layers have bias term.\n    activation      : torch.nn\n                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n    bilinear        : bool\n                      If set to True, bilinear sampling is used.\n    \"\"\"\n    super(upsample_layer, self).__init__()\n    if bilinear:\n        self.up = torch.nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = True)\n        self.conv = double_convolution(\n                                       input_channels = input_channels + output_channels,\n                                       mid_channels = input_channels // 2,\n                                       output_channels = output_channels,\n                                       kernel_size = kernel_size,\n                                       bias = bias,\n                                       activation = activation\n                                      )\n    else:\n        self.up = torch.nn.ConvTranspose2d(input_channels , input_channels // 2, kernel_size = 2, stride = 2)\n        self.conv = double_convolution(\n                                       input_channels = input_channels,\n                                       mid_channels = output_channels,\n                                       output_channels = output_channels,\n                                       kernel_size = kernel_size,\n                                       bias = bias,\n                                       activation = activation\n                                      )\n
"},{"location":"odak/learn_models/#odak.learn.models.models.upsample_layer.forward","title":"forward(x1, x2)","text":"

Forward model.

Parameters:

  • x1 \u2013
             First input data.\n
  • x2 \u2013
             Second input data.\n

Returns:

  • result ( tensor ) \u2013

    Result of the forward operation

Source code in odak/learn/models/components.py
def forward(self, x1, x2):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x1             : torch.tensor\n                     First input data.\n    x2             : torch.tensor\n                     Second input data.\n\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Result of the forward operation\n    \"\"\" \n    x1 = self.up(x1)\n    diffY = x2.size()[2] - x1.size()[2]\n    diffX = x2.size()[3] - x1.size()[3]\n    x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,\n                                      diffY // 2, diffY - diffY // 2])\n    x = torch.cat([x2, x1], dim = 1)\n    result = self.conv(x)\n    return result\n
"},{"location":"odak/learn_models/#odak.learn.models.models.gaussian","title":"gaussian(x, multiplier=1.0)","text":"

A Gaussian non-linear activation. For more details: Ramasinghe, Sameera, and Simon Lucey. \"Beyond periodicity: Towards a unifying framework for activations in coordinate-mlps.\" In European Conference on Computer Vision, pp. 142-158. Cham: Springer Nature Switzerland, 2022.

Parameters:

  • x \u2013
           Input data.\n
  • multiplier \u2013
           Multiplier scaling the input; larger values give a narrower Gaussian response.\n

Returns:

  • result ( float or tensor ) \u2013

    Output data.

Source code in odak/learn/models/components.py
def gaussian(x, multiplier = 1.):\n    \"\"\"\n    A Gaussian non-linear activation.\n    For more details: Ramasinghe, Sameera, and Simon Lucey. \"Beyond periodicity: Towards a unifying framework for activations in coordinate-mlps.\" In European Conference on Computer Vision, pp. 142-158. Cham: Springer Nature Switzerland, 2022.\n\n    Parameters\n    ----------\n    x            : float or torch.tensor\n                   Input data.\n    multiplier   : float or torch.tensor\n                   Multiplier.\n\n    Returns\n    -------\n    result       : float or torch.tensor\n                   Ouput data.\n    \"\"\"\n    result = torch.exp(- (multiplier * x) ** 2)\n    return result\n
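A quick sanity check of the activation, assuming gaussian is exported from odak.learn.models:

import torch
from odak.learn.models import gaussian  # assumed export path

x = torch.linspace(-3.0, 3.0, 7)
print(gaussian(x))                    # multiplier = 1: exp(-x ** 2)
print(gaussian(x, multiplier = 2.0))  # a larger multiplier narrows the bell curve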
"},{"location":"odak/learn_models/#odak.learn.models.models.swish","title":"swish(x)","text":"

A swish non-linear activation. For more details: https://en.wikipedia.org/wiki/Swish_function

Parameters:

  • x \u2013
             Input.\n

Returns:

  • out ( float or tensor ) \u2013

    Output.

Source code in odak/learn/models/components.py
def swish(x):\n    \"\"\"\n    A swish non-linear activation.\n    For more details: https://en.wikipedia.org/wiki/Swish_function\n\n    Parameters\n    -----------\n    x              : float or torch.tensor\n                     Input.\n\n    Returns\n    -------\n    out            : float or torch.tensor\n                     Output.\n    \"\"\"\n    out = x * torch.sigmoid(x)\n    return out\n
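Swish with a unit slope is the same function as SiLU, so the expression above can be cross-checked against torch.nn.functional.silu (assuming PyTorch 1.7 or newer):

import torch

x = torch.randn(4)
out = x * torch.sigmoid(x)  # swish as defined above
print(torch.allclose(out, torch.nn.functional.silu(x)))  # True; SiLU is the same function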
"},{"location":"odak/learn_perception/","title":"odak.learn.perception","text":"

odak.learn.perception

Defines a number of different perceptual loss functions, which can be used to optimise images where gaze location is known.

"},{"location":"odak/learn_perception/#odak.learn.perception.BlurLoss","title":"BlurLoss","text":"

BlurLoss implements two different blur losses. When blur_source is set to False, it implements blur_match, trying to match the input image to the blurred target image. This tries to match the source input image to a blurred version of the target.

When blur_source is set to True, it implements blur_lowpass, matching the blurred version of the input image to the blurred target image. This tries to match only the low frequencies of the source input image to the low frequencies of the target.

The interface is similar to other pytorch loss functions, but note that the gaze location must be provided in addition to the source and target images.

Source code in odak/learn/perception/blur_loss.py
class BlurLoss():\n    \"\"\" \n\n    `BlurLoss` implements two different blur losses. When `blur_source` is set to `False`, it implements blur_match, trying to match the input image to the blurred target image. This tries to match the source input image to a blurred version of the target.\n\n    When `blur_source` is set to `True`, it implements blur_lowpass, matching the blurred version of the input image to the blurred target image. This tries to match only the low frequencies of the source input image to the low frequencies of the target.\n\n    The interface is similar to other `pytorch` loss functions, but note that the gaze location must be provided in addition to the source and target images.\n    \"\"\"\n\n\n    def __init__(self, device=torch.device(\"cpu\"),\n                 alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode=\"quadratic\", blur_source=False, equi=False):\n        \"\"\"\n        Parameters\n        ----------\n\n        alpha                   : float\n                                    parameter controlling foveation - larger values mean bigger pooling regions.\n        real_image_width        : float \n                                    The real width of the image as displayed to the user.\n                                    Units don't matter as long as they are the same as for real_viewing_distance.\n        real_viewing_distance   : float \n                                    The real distance of the observer's eyes to the image plane.\n                                    Units don't matter as long as they are the same as for real_image_width.\n        mode                    : str \n                                    Foveation mode, either \"quadratic\" or \"linear\". Controls how pooling regions grow\n                                    as you move away from the fovea. We got best results with \"quadratic\".\n        blur_source             : bool\n                                    If true, blurs the source image as well as the target before computing the loss.\n        equi                    : bool\n                                    If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular\n                                    format 360 image. The settings real_image_width and real_viewing distance are ignored.\n                                    The gaze argument is instead interpreted as gaze angles, and should be in the range\n                                    [-pi,pi]x[-pi/2,pi]\n        \"\"\"\n        self.target = None\n        self.device = device\n        self.alpha = alpha\n        self.real_image_width = real_image_width\n        self.real_viewing_distance = real_viewing_distance\n        self.mode = mode\n        self.blur = None\n        self.loss_func = torch.nn.MSELoss()\n        self.blur_source = blur_source\n        self.equi = equi\n\n    def blur_image(self, image, gaze):\n        if self.blur is None:\n            self.blur = RadiallyVaryingBlur()\n        return self.blur.blur(image, self.alpha, self.real_image_width, self.real_viewing_distance, gaze, self.mode, self.equi)\n\n    def __call__(self, image, target, gaze=[0.5, 0.5]):\n        \"\"\" \n        Calculates the Blur Loss.\n\n        Parameters\n        ----------\n        image               : torch.tensor\n                                Image to compute loss for. 
Should be an RGB image in NCHW format (4 dimensions)\n        target              : torch.tensor\n                                Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)\n        gaze                : list\n                                Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.\n\n        Returns\n        -------\n\n        loss                : torch.tensor\n                                The computed loss.\n        \"\"\"\n        check_loss_inputs(\"BlurLoss\", image, target)\n        blurred_target = self.blur_image(target, gaze)\n        if self.blur_source:\n            blurred_image = self.blur_image(image, gaze)\n            return self.loss_func(blurred_image, blurred_target)\n        else:\n            return self.loss_func(image, blurred_target)\n\n    def to(self, device):\n        self.device = device\n        return self\n
"},{"location":"odak/learn_perception/#odak.learn.perception.BlurLoss.__call__","title":"__call__(image, target, gaze=[0.5, 0.5])","text":"

Calculates the Blur Loss.

Parameters:

  • image \u2013
                    Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)\n
  • target \u2013
                    Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)\n
  • gaze \u2013
                    Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.\n

Returns:

  • loss ( tensor ) \u2013

    The computed loss.

Source code in odak/learn/perception/blur_loss.py
def __call__(self, image, target, gaze=[0.5, 0.5]):\n    \"\"\" \n    Calculates the Blur Loss.\n\n    Parameters\n    ----------\n    image               : torch.tensor\n                            Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)\n    target              : torch.tensor\n                            Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)\n    gaze                : list\n                            Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.\n\n    Returns\n    -------\n\n    loss                : torch.tensor\n                            The computed loss.\n    \"\"\"\n    check_loss_inputs(\"BlurLoss\", image, target)\n    blurred_target = self.blur_image(target, gaze)\n    if self.blur_source:\n        blurred_image = self.blur_image(image, gaze)\n        return self.loss_func(blurred_image, blurred_target)\n    else:\n        return self.loss_func(image, blurred_target)\n
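A hedged usage sketch, assuming BlurLoss is importable from odak.learn.perception (as the module location above suggests) and that images are NCHW RGB tensors in [0, 1]:

import torch
from odak.learn.perception import BlurLoss  # import path assumed from the module location above

image = torch.rand(1, 3, 256, 256, requires_grad = True)  # source image being optimised
target = torch.rand(1, 3, 256, 256)                       # ground truth image
gaze = [0.3, 0.5]                                         # normalised gaze location

loss_fn = BlurLoss(alpha = 0.2, blur_source = False)      # blur_match variant; blur_source = True gives blur_lowpass
loss = loss_fn(image, target, gaze = gaze)
loss.backward()                                            # gradients flow back to the source image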
"},{"location":"odak/learn_perception/#odak.learn.perception.BlurLoss.__init__","title":"__init__(device=torch.device('cpu'), alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode='quadratic', blur_source=False, equi=False)","text":"

Parameters:

  • alpha \u2013
                        parameter controlling foveation - larger values mean bigger pooling regions.\n
  • real_image_width \u2013
                        The real width of the image as displayed to the user.\n                    Units don't matter as long as they are the same as for real_viewing_distance.\n
  • real_viewing_distance \u2013
                        The real distance of the observer's eyes to the image plane.\n                    Units don't matter as long as they are the same as for real_image_width.\n
  • mode \u2013
                        Foveation mode, either \"quadratic\" or \"linear\". Controls how pooling regions grow\n                    as you move away from the fovea. We got best results with \"quadratic\".\n
  • blur_source \u2013
                        If true, blurs the source image as well as the target before computing the loss.\n
  • equi \u2013
                        If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular\n                    format 360 image. The settings real_image_width and real_viewing distance are ignored.\n                    The gaze argument is instead interpreted as gaze angles, and should be in the range\n                    [-pi,pi]x[-pi/2,pi]\n
Source code in odak/learn/perception/blur_loss.py
def __init__(self, device=torch.device(\"cpu\"),\n             alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode=\"quadratic\", blur_source=False, equi=False):\n    \"\"\"\n    Parameters\n    ----------\n\n    alpha                   : float\n                                parameter controlling foveation - larger values mean bigger pooling regions.\n    real_image_width        : float \n                                The real width of the image as displayed to the user.\n                                Units don't matter as long as they are the same as for real_viewing_distance.\n    real_viewing_distance   : float \n                                The real distance of the observer's eyes to the image plane.\n                                Units don't matter as long as they are the same as for real_image_width.\n    mode                    : str \n                                Foveation mode, either \"quadratic\" or \"linear\". Controls how pooling regions grow\n                                as you move away from the fovea. We got best results with \"quadratic\".\n    blur_source             : bool\n                                If true, blurs the source image as well as the target before computing the loss.\n    equi                    : bool\n                                If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular\n                                format 360 image. The settings real_image_width and real_viewing distance are ignored.\n                                The gaze argument is instead interpreted as gaze angles, and should be in the range\n                                [-pi,pi]x[-pi/2,pi]\n    \"\"\"\n    self.target = None\n    self.device = device\n    self.alpha = alpha\n    self.real_image_width = real_image_width\n    self.real_viewing_distance = real_viewing_distance\n    self.mode = mode\n    self.blur = None\n    self.loss_func = torch.nn.MSELoss()\n    self.blur_source = blur_source\n    self.equi = equi\n
"},{"location":"odak/learn_perception/#odak.learn.perception.CVVDP","title":"CVVDP","text":"

Bases: Module

Source code in odak/learn/perception/learned_perceptual_losses.py
class CVVDP(nn.Module):\n    def __init__(self, device = torch.device('cpu')):\n        \"\"\"\n        Initializes the CVVDP model with a specified device.\n\n        Parameters\n        ----------\n        device   : torch.device\n                    The device (CPU/GPU) on which the computations will be performed. Defaults to CPU.\n        \"\"\"\n        super(CVVDP, self).__init__()\n        try:\n            import pycvvdp\n            self.cvvdp = pycvvdp.cvvdp(display_name = 'standard_4k', device = device)\n        except Exception as e:\n            logging.warning('ColorVideoVDP is missing, consider installing by running \"pip install -U git+https://github.com/gfxdisp/ColorVideoVDP\"')\n            logging.warning(e)\n\n\n    def forward(self, predictions, targets, dim_order = 'CHW'):\n        \"\"\"\n        Parameters\n        ----------\n        predictions   : torch.tensor\n                        The predicted images.\n        targets       : torch.tensor\n                        The ground truth images.\n        dim_order     : str\n                        The dimension order of the input images. Defaults to 'CHW' (channels, height, width).\n\n        Returns\n        -------\n        result        : torch.tensor\n                        The computed loss if successful, otherwise 0.0.\n        \"\"\"\n        try:\n            l_ColorVideoVDP = self.cvvdp.loss(predictions, targets, dim_order = dim_order)\n            return l_ColorVideoVDP\n        except Exception as e:\n            logging.warning('ColorVideoVDP failed to compute.')\n            logging.warning(e)\n            return torch.tensor(0.0)\n
"},{"location":"odak/learn_perception/#odak.learn.perception.CVVDP.__init__","title":"__init__(device=torch.device('cpu'))","text":"

Initializes the CVVDP model with a specified device.

Parameters:

  • device \u2013
        The device (CPU/GPU) on which the computations will be performed. Defaults to CPU.\n
Source code in odak/learn/perception/learned_perceptual_losses.py
def __init__(self, device = torch.device('cpu')):\n    \"\"\"\n    Initializes the CVVDP model with a specified device.\n\n    Parameters\n    ----------\n    device   : torch.device\n                The device (CPU/GPU) on which the computations will be performed. Defaults to CPU.\n    \"\"\"\n    super(CVVDP, self).__init__()\n    try:\n        import pycvvdp\n        self.cvvdp = pycvvdp.cvvdp(display_name = 'standard_4k', device = device)\n    except Exception as e:\n        logging.warning('ColorVideoVDP is missing, consider installing by running \"pip install -U git+https://github.com/gfxdisp/ColorVideoVDP\"')\n        logging.warning(e)\n
"},{"location":"odak/learn_perception/#odak.learn.perception.CVVDP.forward","title":"forward(predictions, targets, dim_order='CHW')","text":"

Parameters:

  • predictions \u2013
            The predicted images.\n
  • targets \u2013
            The ground truth images.\n
  • dim_order \u2013
            The dimension order of the input images. Defaults to 'CHW' (channels, height, width).\n

Returns:

  • result ( tensor ) \u2013

    The computed loss if successful, otherwise 0.0.

Source code in odak/learn/perception/learned_perceptual_losses.py
def forward(self, predictions, targets, dim_order = 'CHW'):\n    \"\"\"\n    Parameters\n    ----------\n    predictions   : torch.tensor\n                    The predicted images.\n    targets       : torch.tensor\n                    The ground truth images.\n    dim_order     : str\n                    The dimension order of the input images. Defaults to 'CHW' (channels, height, width).\n\n    Returns\n    -------\n    result        : torch.tensor\n                    The computed loss if successful, otherwise 0.0.\n    \"\"\"\n    try:\n        l_ColorVideoVDP = self.cvvdp.loss(predictions, targets, dim_order = dim_order)\n        return l_ColorVideoVDP\n    except Exception as e:\n        logging.warning('ColorVideoVDP failed to compute.')\n        logging.warning(e)\n        return torch.tensor(0.0)\n
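A hedged usage sketch, assuming CVVDP is importable from odak.learn.perception and that pycvvdp is installed; if it is not, the wrapper above logs a warning and returns 0.0:

import torch
from odak.learn.perception import CVVDP  # import path assumed from the module location above

device = torch.device('cpu')
predictions = torch.rand(3, 256, 256)  # CHW layout, matching the default dim_order
targets = torch.rand(3, 256, 256)

cvvdp_loss = CVVDP(device = device)
loss = cvvdp_loss(predictions, targets, dim_order = 'CHW')  # torch.tensor(0.0) if ColorVideoVDP is unavailable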
"},{"location":"odak/learn_perception/#odak.learn.perception.FVVDP","title":"FVVDP","text":"

Bases: Module

Source code in odak/learn/perception/learned_perceptual_losses.py
class FVVDP(nn.Module):\n    def __init__(self, device = torch.device('cpu')):\n        \"\"\"\n        Initializes the FVVDP model with a specified device.\n\n        Parameters\n        ----------\n        device   : torch.device\n                    The device (CPU/GPU) on which the computations will be performed. Defaults to CPU.\n        \"\"\"\n        super(FVVDP, self).__init__()\n        try:\n            import pyfvvdp\n            self.fvvdp = pyfvvdp.fvvdp(display_name = 'standard_4k', heatmap = 'none', device = device)\n        except Exception as e:\n            logging.warning('FovVideoVDP is missing, consider installing by running \"pip install pyfvvdp\"')\n            logging.warning(e)\n\n\n    def forward(self, predictions, targets, dim_order = 'CHW'):\n        \"\"\"\n        Parameters\n        ----------\n        predictions   : torch.tensor\n                        The predicted images.\n        targets       : torch.tensor\n                        The ground truth images.\n        dim_order     : str\n                        The dimension order of the input images. Defaults to 'CHW' (channels, height, width).\n\n        Returns\n        -------\n        result        : torch.tensor\n                          The computed loss if successful, otherwise 0.0.\n        \"\"\"\n        try:\n            l_FovVideoVDP = self.fvvdp.predict(predictions, targets, dim_order = dim_order)[0]\n            return l_FovVideoVDP\n        except Exception as e:\n            logging.warning('FovVideoVDP failed to compute.')\n            logging.warning(e)\n            return torch.tensor(0.0)\n
"},{"location":"odak/learn_perception/#odak.learn.perception.FVVDP.__init__","title":"__init__(device=torch.device('cpu'))","text":"

Initializes the FVVDP model with a specified device.

Parameters:

  • device \u2013
        The device (CPU/GPU) on which the computations will be performed. Defaults to CPU.\n
Source code in odak/learn/perception/learned_perceptual_losses.py
def __init__(self, device = torch.device('cpu')):\n    \"\"\"\n    Initializes the FVVDP model with a specified device.\n\n    Parameters\n    ----------\n    device   : torch.device\n                The device (CPU/GPU) on which the computations will be performed. Defaults to CPU.\n    \"\"\"\n    super(FVVDP, self).__init__()\n    try:\n        import pyfvvdp\n        self.fvvdp = pyfvvdp.fvvdp(display_name = 'standard_4k', heatmap = 'none', device = device)\n    except Exception as e:\n        logging.warning('FovVideoVDP is missing, consider installing by running \"pip install pyfvvdp\"')\n        logging.warning(e)\n
"},{"location":"odak/learn_perception/#odak.learn.perception.FVVDP.forward","title":"forward(predictions, targets, dim_order='CHW')","text":"

Parameters:

  • predictions \u2013
            The predicted images.\n
  • targets \u2013
            The ground truth images.\n
  • dim_order \u2013
            The dimension order of the input images. Defaults to 'CHW' (channels, height, width).\n

Returns:

  • result ( tensor ) \u2013

    The computed loss if successful, otherwise 0.0.

Source code in odak/learn/perception/learned_perceptual_losses.py
def forward(self, predictions, targets, dim_order = 'CHW'):\n    \"\"\"\n    Parameters\n    ----------\n    predictions   : torch.tensor\n                    The predicted images.\n    targets       : torch.tensor\n                    The ground truth images.\n    dim_order     : str\n                    The dimension order of the input images. Defaults to 'CHW' (channels, height, width).\n\n    Returns\n    -------\n    result        : torch.tensor\n                      The computed loss if successful, otherwise 0.0.\n    \"\"\"\n    try:\n        l_FovVideoVDP = self.fvvdp.predict(predictions, targets, dim_order = dim_order)[0]\n        return l_FovVideoVDP\n    except Exception as e:\n        logging.warning('FovVideoVDP failed to compute.')\n        logging.warning(e)\n        return torch.tensor(0.0)\n
"},{"location":"odak/learn_perception/#odak.learn.perception.LPIPS","title":"LPIPS","text":"

Bases: Module

Source code in odak/learn/perception/learned_perceptual_losses.py
class LPIPS(nn.Module):\n\n    def __init__(self):\n        \"\"\"\n        Initializes the LPIPS (Learned Perceptual Image Patch Similarity) model.\n\n        \"\"\"\n        super(LPIPS, self).__init__()\n        try:\n            import torchmetrics\n            self.lpips = torchmetrics.image.lpip.LearnedPerceptualImagePatchSimilarity(net_type = 'squeeze')\n        except Exception as e:\n            logging.warning('torchmetrics is missing, consider installing by running \"pip install torchmetrics\"')\n            logging.warning(e)\n\n\n    def forward(self, predictions, targets):\n        \"\"\"\n        Parameters\n        ----------\n        predictions   : torch.tensor\n                        The predicted images.\n        targets       : torch.tensor\n                        The ground truth images.\n\n        Returns\n        -------\n        result        : torch.tensor\n                        The computed loss if successful, otherwise 0.0.\n        \"\"\"\n        try:\n            lpips_image = predictions\n            lpips_target = targets\n            if len(lpips_image.shape) == 3:\n                lpips_image = lpips_image.unsqueeze(0)\n                lpips_target = lpips_target.unsqueeze(0)\n            if lpips_image.shape[1] == 1:\n                lpips_image = lpips_image.repeat(1, 3, 1, 1)\n                lpips_target = lpips_target.repeat(1, 3, 1, 1)\n            lpips_image = (lpips_image * 2 - 1).clamp(-1, 1)\n            lpips_target = (lpips_target * 2 - 1).clamp(-1, 1)\n            l_LPIPS = self.lpips(lpips_image, lpips_target)\n            return l_LPIPS\n        except Exception as e:\n            logging.warning('LPIPS failed to compute.')\n            logging.warning(e)\n            return torch.tensor(0.0)\n
"},{"location":"odak/learn_perception/#odak.learn.perception.LPIPS.__init__","title":"__init__()","text":"

Initializes the LPIPS (Learned Perceptual Image Patch Similarity) model.

Source code in odak/learn/perception/learned_perceptual_losses.py
def __init__(self):\n    \"\"\"\n    Initializes the LPIPS (Learned Perceptual Image Patch Similarity) model.\n\n    \"\"\"\n    super(LPIPS, self).__init__()\n    try:\n        import torchmetrics\n        self.lpips = torchmetrics.image.lpip.LearnedPerceptualImagePatchSimilarity(net_type = 'squeeze')\n    except Exception as e:\n        logging.warning('torchmetrics is missing, consider installing by running \"pip install torchmetrics\"')\n        logging.warning(e)\n
"},{"location":"odak/learn_perception/#odak.learn.perception.LPIPS.forward","title":"forward(predictions, targets)","text":"

Parameters:

  • predictions \u2013
            The predicted images.\n
  • targets \u2013
            The ground truth images.\n

Returns:

  • result ( tensor ) \u2013

    The computed loss if successful, otherwise 0.0.

Source code in odak/learn/perception/learned_perceptual_losses.py
def forward(self, predictions, targets):\n    \"\"\"\n    Parameters\n    ----------\n    predictions   : torch.tensor\n                    The predicted images.\n    targets       : torch.tensor\n                    The ground truth images.\n\n    Returns\n    -------\n    result        : torch.tensor\n                    The computed loss if successful, otherwise 0.0.\n    \"\"\"\n    try:\n        lpips_image = predictions\n        lpips_target = targets\n        if len(lpips_image.shape) == 3:\n            lpips_image = lpips_image.unsqueeze(0)\n            lpips_target = lpips_target.unsqueeze(0)\n        if lpips_image.shape[1] == 1:\n            lpips_image = lpips_image.repeat(1, 3, 1, 1)\n            lpips_target = lpips_target.repeat(1, 3, 1, 1)\n        lpips_image = (lpips_image * 2 - 1).clamp(-1, 1)\n        lpips_target = (lpips_target * 2 - 1).clamp(-1, 1)\n        l_LPIPS = self.lpips(lpips_image, lpips_target)\n        return l_LPIPS\n    except Exception as e:\n        logging.warning('LPIPS failed to compute.')\n        logging.warning(e)\n        return torch.tensor(0.0)\n
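A hedged usage sketch, assuming LPIPS is importable from odak.learn.perception and torchmetrics is installed. As the code above shows, 3D or single-channel inputs are batched and channel-repeated, and values are expected roughly in [0, 1] since they are remapped to [-1, 1] internally:

import torch
from odak.learn.perception import LPIPS  # import path assumed from the module location above

lpips_loss = LPIPS()                  # logs a warning if torchmetrics is missing
prediction = torch.rand(1, 512, 512)  # single-channel CHW image in [0, 1]
target = torch.rand(1, 512, 512)
loss = lpips_loss(prediction, target) # batched, repeated to 3 channels and remapped to [-1, 1] internally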
"},{"location":"odak/learn_perception/#odak.learn.perception.MSSSIM","title":"MSSSIM","text":"

Bases: Module

A class to calculate multi-scale structural similarity index of an image with respect to a ground truth image.

Source code in odak/learn/perception/image_quality_losses.py
class MSSSIM(nn.Module):\n    '''\n    A class to calculate multi-scale structural similarity index of an image with respect to a ground truth image.\n    '''\n\n    def __init__(self):\n        super(MSSSIM, self).__init__()\n\n    def forward(self, predictions, targets):\n        \"\"\"\n        Parameters\n        ----------\n        predictions : torch.tensor\n                      The predicted images.\n        targets     : torch.tensor\n                      The ground truth images.\n\n        Returns\n        -------\n        result      : torch.tensor \n                      The computed MS-SSIM value if successful, otherwise 0.0.\n        \"\"\"\n        try:\n            from torchmetrics.functional.image import multiscale_structural_similarity_index_measure\n            if len(predictions.shape) == 3:\n                predictions = predictions.unsqueeze(0)\n                targets = targets.unsqueeze(0)\n            l_MSSSIM = multiscale_structural_similarity_index_measure(predictions, targets, data_range = 1.0)\n            return l_MSSSIM  \n        except Exception as e:\n            logging.warning('MS-SSIM failed to compute.')\n            logging.warning(e)\n            return torch.tensor(0.0)\n
"},{"location":"odak/learn_perception/#odak.learn.perception.MSSSIM.forward","title":"forward(predictions, targets)","text":"

Parameters:

  • predictions (tensor) \u2013
          The predicted images.\n
  • targets \u2013
          The ground truth images.\n

Returns:

  • result ( tensor ) \u2013

    The computed MS-SSIM value if successful, otherwise 0.0.

Source code in odak/learn/perception/image_quality_losses.py
def forward(self, predictions, targets):\n    \"\"\"\n    Parameters\n    ----------\n    predictions : torch.tensor\n                  The predicted images.\n    targets     : torch.tensor\n                  The ground truth images.\n\n    Returns\n    -------\n    result      : torch.tensor \n                  The computed MS-SSIM value if successful, otherwise 0.0.\n    \"\"\"\n    try:\n        from torchmetrics.functional.image import multiscale_structural_similarity_index_measure\n        if len(predictions.shape) == 3:\n            predictions = predictions.unsqueeze(0)\n            targets = targets.unsqueeze(0)\n        l_MSSSIM = multiscale_structural_similarity_index_measure(predictions, targets, data_range = 1.0)\n        return l_MSSSIM  \n    except Exception as e:\n        logging.warning('MS-SSIM failed to compute.')\n        logging.warning(e)\n        return torch.tensor(0.0)\n
"},{"location":"odak/learn_perception/#odak.learn.perception.MetamerMSELoss","title":"MetamerMSELoss","text":"

The MetamerMSELoss class provides a perceptual loss function. This generates a metamer for the target image, and then optimises the source image to be the same as this target image metamer.

Please note this is different to MetamericLoss which optimises the source image to be any metamer of the target image.

Its interface is similar to other pytorch loss functions, but note that the gaze location must be provided in addition to the source and target images.

Source code in odak/learn/perception/metamer_mse_loss.py
class MetamerMSELoss():\n    \"\"\" \n    The `MetamerMSELoss` class provides a perceptual loss function. This generates a metamer for the target image, and then optimises the source image to be the same as this target image metamer.\n\n    Please note this is different to `MetamericLoss` which optimises the source image to be any metamer of the target image.\n\n    Its interface is similar to other `pytorch` loss functions, but note that the gaze location must be provided in addition to the source and target images.\n    \"\"\"\n\n\n    def __init__(self, device=torch.device(\"cpu\"),\n                 alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode=\"quadratic\",\n                 n_pyramid_levels=5, n_orientations=2, equi=False):\n        \"\"\"\n        Parameters\n        ----------\n        alpha                   : float\n                                    parameter controlling foveation - larger values mean bigger pooling regions.\n        real_image_width        : float \n                                    The real width of the image as displayed to the user.\n                                    Units don't matter as long as they are the same as for real_viewing_distance.\n        real_viewing_distance   : float \n                                    The real distance of the observer's eyes to the image plane.\n                                    Units don't matter as long as they are the same as for real_image_width.\n        n_pyramid_levels        : int \n                                    Number of levels of the steerable pyramid. Note that the image is padded\n                                    so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value\n                                    too high will slow down the calculation a lot.\n        mode                    : str \n                                    Foveation mode, either \"quadratic\" or \"linear\". Controls how pooling regions grow\n                                    as you move away from the fovea. We got best results with \"quadratic\".\n        n_orientations          : int \n                                    Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.\n                                    Increasing this will increase runtime.\n        equi                    : bool\n                                    If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular\n                                    format 360 image. 
The settings real_image_width and real_viewing distance are ignored.\n                                    The gaze argument is instead interpreted as gaze angles, and should be in the range\n                                    [-pi,pi]x[-pi/2,pi]\n        \"\"\"\n        self.target = None\n        self.target_metamer = None\n        self.metameric_loss = MetamericLoss(device=device, alpha=alpha, real_image_width=real_image_width,\n                                            real_viewing_distance=real_viewing_distance,\n                                            n_pyramid_levels=n_pyramid_levels, n_orientations=n_orientations, use_l2_foveal_loss=False, equi=equi)\n        self.loss_func = torch.nn.MSELoss()\n        self.noise = None\n\n    def gen_metamer(self, image, gaze):\n        \"\"\" \n        Generates a metamer for an image, following the method in [this paper](https://dl.acm.org/doi/abs/10.1145/3450626.3459943)\n        This function can be used on its own to generate a metamer for a desired image.\n\n        Parameters\n        ----------\n        image   : torch.tensor\n                Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions)\n        gaze    : list\n                Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.\n\n        Returns\n        -------\n\n        metamer : torch.tensor\n                The generated metamer image\n        \"\"\"\n        image = rgb_2_ycrcb(image)\n        image_size = image.size()\n        image = pad_image_for_pyramid(image, self.metameric_loss.n_pyramid_levels)\n\n        target_stats = self.metameric_loss.calc_statsmaps(\n            image, gaze=gaze, alpha=self.metameric_loss.alpha)\n        target_means = target_stats[::2]\n        target_stdevs = target_stats[1::2]\n        if self.noise is None or self.noise.size() != image.size():\n            torch.manual_seed(0)\n            noise_image = torch.rand_like(image)\n        noise_pyramid = self.metameric_loss.pyramid_maker.construct_pyramid(\n            noise_image, self.metameric_loss.n_pyramid_levels)\n        input_pyramid = self.metameric_loss.pyramid_maker.construct_pyramid(\n            image, self.metameric_loss.n_pyramid_levels)\n\n        def match_level(input_level, target_mean, target_std):\n            level = input_level.clone()\n            level -= torch.mean(level)\n            input_std = torch.sqrt(torch.mean(level * level))\n            eps = 1e-6\n            # Safeguard against divide by zero\n            input_std[input_std < eps] = eps\n            level /= input_std\n            level *= target_std\n            level += target_mean\n            return level\n\n        nbands = len(noise_pyramid[0][\"b\"])\n        noise_pyramid[0][\"h\"] = match_level(\n            noise_pyramid[0][\"h\"], target_means[0], target_stdevs[0])\n        for l in range(len(noise_pyramid)-1):\n            for b in range(nbands):\n                noise_pyramid[l][\"b\"][b] = match_level(\n                    noise_pyramid[l][\"b\"][b], target_means[1 + l * nbands + b], target_stdevs[1 + l * nbands + b])\n        noise_pyramid[-1][\"l\"] = input_pyramid[-1][\"l\"]\n\n        metamer = self.metameric_loss.pyramid_maker.reconstruct_from_pyramid(\n            noise_pyramid)\n        metamer = ycrcb_2_rgb(metamer)\n        # Crop to remove any padding\n        metamer = metamer[:image_size[0], :image_size[1], :image_size[2], :image_size[3]]\n        return metamer\n\n    def 
__call__(self, image, target, gaze=[0.5, 0.5]):\n        \"\"\" \n        Calculates the Metamer MSE Loss.\n\n        Parameters\n        ----------\n        image   : torch.tensor\n                Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)\n        target  : torch.tensor\n                Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)\n        gaze    : list\n                Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.\n\n        Returns\n        -------\n\n        loss                : torch.tensor\n                                The computed loss.\n        \"\"\"\n        check_loss_inputs(\"MetamerMSELoss\", image, target)\n        # Pad image and target if necessary\n        image = pad_image_for_pyramid(image, self.metameric_loss.n_pyramid_levels)\n        target = pad_image_for_pyramid(target, self.metameric_loss.n_pyramid_levels)\n\n        if target is not self.target or self.target is None:\n            self.target_metamer = self.gen_metamer(target, gaze)\n            self.target = target\n\n        return self.loss_func(image, self.target_metamer)\n\n    def to(self, device):\n        self.metameric_loss = self.metameric_loss.to(device)\n        return self\n
"},{"location":"odak/learn_perception/#odak.learn.perception.MetamerMSELoss.__call__","title":"__call__(image, target, gaze=[0.5, 0.5])","text":"

Calculates the Metamer MSE Loss.

Parameters:

  • image \u2013
    Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)\n
  • target \u2013
    Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)\n
  • gaze \u2013
    Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.\n

Returns:

  • loss ( tensor ) \u2013

    The computed loss.

Source code in odak/learn/perception/metamer_mse_loss.py
def __call__(self, image, target, gaze=[0.5, 0.5]):\n    \"\"\" \n    Calculates the Metamer MSE Loss.\n\n    Parameters\n    ----------\n    image   : torch.tensor\n            Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)\n    target  : torch.tensor\n            Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)\n    gaze    : list\n            Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.\n\n    Returns\n    -------\n\n    loss                : torch.tensor\n                            The computed loss.\n    \"\"\"\n    check_loss_inputs(\"MetamerMSELoss\", image, target)\n    # Pad image and target if necessary\n    image = pad_image_for_pyramid(image, self.metameric_loss.n_pyramid_levels)\n    target = pad_image_for_pyramid(target, self.metameric_loss.n_pyramid_levels)\n\n    if target is not self.target or self.target is None:\n        self.target_metamer = self.gen_metamer(target, gaze)\n        self.target = target\n\n    return self.loss_func(image, self.target_metamer)\n
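A hedged optimisation sketch, assuming MetamerMSELoss is importable from odak.learn.perception: the target metamer is generated once on the first call and cached, after which the source image is pulled towards it with plain MSE (sizes and hyperparameters are illustrative):

import torch
from odak.learn.perception import MetamerMSELoss  # import path assumed from the module location above

target = torch.rand(1, 3, 256, 256)                        # sides divisible by 2 ** n_pyramid_levels
image = torch.rand(1, 3, 256, 256, requires_grad = True)   # source image being optimised
loss_fn = MetamerMSELoss(alpha = 0.2, n_pyramid_levels = 5)
optimizer = torch.optim.Adam([image], lr = 1e-2)

for step in range(100):
    optimizer.zero_grad()
    loss = loss_fn(image, target, gaze = [0.5, 0.5])  # metamer of `target` generated and cached on first call
    loss.backward()
    optimizer.step()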
"},{"location":"odak/learn_perception/#odak.learn.perception.MetamerMSELoss.__init__","title":"__init__(device=torch.device('cpu'), alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode='quadratic', n_pyramid_levels=5, n_orientations=2, equi=False)","text":"

Parameters:

  • alpha \u2013
                        parameter controlling foveation - larger values mean bigger pooling regions.\n
  • real_image_width \u2013
                        The real width of the image as displayed to the user.\n                    Units don't matter as long as they are the same as for real_viewing_distance.\n
  • real_viewing_distance \u2013
                        The real distance of the observer's eyes to the image plane.\n                    Units don't matter as long as they are the same as for real_image_width.\n
  • n_pyramid_levels \u2013
                        Number of levels of the steerable pyramid. Note that the image is padded\n                    so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value\n                    too high will slow down the calculation a lot.\n
  • mode \u2013
                        Foveation mode, either \"quadratic\" or \"linear\". Controls how pooling regions grow\n                    as you move away from the fovea. We got best results with \"quadratic\".\n
  • n_orientations \u2013
                        Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.\n                    Increasing this will increase runtime.\n
  • equi \u2013
                        If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular\n                    format 360 image. The settings real_image_width and real_viewing distance are ignored.\n                    The gaze argument is instead interpreted as gaze angles, and should be in the range\n                    [-pi,pi]x[-pi/2,pi]\n
Source code in odak/learn/perception/metamer_mse_loss.py
def __init__(self, device=torch.device(\"cpu\"),\n             alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode=\"quadratic\",\n             n_pyramid_levels=5, n_orientations=2, equi=False):\n    \"\"\"\n    Parameters\n    ----------\n    alpha                   : float\n                                parameter controlling foveation - larger values mean bigger pooling regions.\n    real_image_width        : float \n                                The real width of the image as displayed to the user.\n                                Units don't matter as long as they are the same as for real_viewing_distance.\n    real_viewing_distance   : float \n                                The real distance of the observer's eyes to the image plane.\n                                Units don't matter as long as they are the same as for real_image_width.\n    n_pyramid_levels        : int \n                                Number of levels of the steerable pyramid. Note that the image is padded\n                                so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value\n                                too high will slow down the calculation a lot.\n    mode                    : str \n                                Foveation mode, either \"quadratic\" or \"linear\". Controls how pooling regions grow\n                                as you move away from the fovea. We got best results with \"quadratic\".\n    n_orientations          : int \n                                Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.\n                                Increasing this will increase runtime.\n    equi                    : bool\n                                If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular\n                                format 360 image. The settings real_image_width and real_viewing distance are ignored.\n                                The gaze argument is instead interpreted as gaze angles, and should be in the range\n                                [-pi,pi]x[-pi/2,pi]\n    \"\"\"\n    self.target = None\n    self.target_metamer = None\n    self.metameric_loss = MetamericLoss(device=device, alpha=alpha, real_image_width=real_image_width,\n                                        real_viewing_distance=real_viewing_distance,\n                                        n_pyramid_levels=n_pyramid_levels, n_orientations=n_orientations, use_l2_foveal_loss=False, equi=equi)\n    self.loss_func = torch.nn.MSELoss()\n    self.noise = None\n
"},{"location":"odak/learn_perception/#odak.learn.perception.MetamerMSELoss.gen_metamer","title":"gen_metamer(image, gaze)","text":"

Generates a metamer for an image, following the method in this paper. This function can be used on its own to generate a metamer for a desired image.

Parameters:

  • image \u2013
    Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions)\n
  • gaze \u2013
    Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.\n

Returns:

  • metamer ( tensor ) \u2013

    The generated metamer image

Source code in odak/learn/perception/metamer_mse_loss.py
def gen_metamer(self, image, gaze):\n    \"\"\" \n    Generates a metamer for an image, following the method in [this paper](https://dl.acm.org/doi/abs/10.1145/3450626.3459943)\n    This function can be used on its own to generate a metamer for a desired image.\n\n    Parameters\n    ----------\n    image   : torch.tensor\n            Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions)\n    gaze    : list\n            Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.\n\n    Returns\n    -------\n\n    metamer : torch.tensor\n            The generated metamer image\n    \"\"\"\n    image = rgb_2_ycrcb(image)\n    image_size = image.size()\n    image = pad_image_for_pyramid(image, self.metameric_loss.n_pyramid_levels)\n\n    target_stats = self.metameric_loss.calc_statsmaps(\n        image, gaze=gaze, alpha=self.metameric_loss.alpha)\n    target_means = target_stats[::2]\n    target_stdevs = target_stats[1::2]\n    if self.noise is None or self.noise.size() != image.size():\n        torch.manual_seed(0)\n        noise_image = torch.rand_like(image)\n    noise_pyramid = self.metameric_loss.pyramid_maker.construct_pyramid(\n        noise_image, self.metameric_loss.n_pyramid_levels)\n    input_pyramid = self.metameric_loss.pyramid_maker.construct_pyramid(\n        image, self.metameric_loss.n_pyramid_levels)\n\n    def match_level(input_level, target_mean, target_std):\n        level = input_level.clone()\n        level -= torch.mean(level)\n        input_std = torch.sqrt(torch.mean(level * level))\n        eps = 1e-6\n        # Safeguard against divide by zero\n        input_std[input_std < eps] = eps\n        level /= input_std\n        level *= target_std\n        level += target_mean\n        return level\n\n    nbands = len(noise_pyramid[0][\"b\"])\n    noise_pyramid[0][\"h\"] = match_level(\n        noise_pyramid[0][\"h\"], target_means[0], target_stdevs[0])\n    for l in range(len(noise_pyramid)-1):\n        for b in range(nbands):\n            noise_pyramid[l][\"b\"][b] = match_level(\n                noise_pyramid[l][\"b\"][b], target_means[1 + l * nbands + b], target_stdevs[1 + l * nbands + b])\n    noise_pyramid[-1][\"l\"] = input_pyramid[-1][\"l\"]\n\n    metamer = self.metameric_loss.pyramid_maker.reconstruct_from_pyramid(\n        noise_pyramid)\n    metamer = ycrcb_2_rgb(metamer)\n    # Crop to remove any padding\n    metamer = metamer[:image_size[0], :image_size[1], :image_size[2], :image_size[3]]\n    return metamer\n
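Because gen_metamer can be used on its own, a metamer of a given image can be produced without evaluating the loss. A hedged sketch, assuming an NCHW RGB tensor in [0, 1] whose sides are multiples of 2 ** n_pyramid_levels:

import torch
from odak.learn.perception import MetamerMSELoss  # import path assumed from the module location above

loss_fn = MetamerMSELoss(n_pyramid_levels = 5)
image = torch.rand(1, 3, 512, 512)  # 512 is a multiple of 2 ** 5, so no padding is required
with torch.no_grad():
    metamer = loss_fn.gen_metamer(image, gaze = [0.5, 0.5])
print(metamer.shape)                 # same NCHW shape as the input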
"},{"location":"odak/learn_perception/#odak.learn.perception.MetamericLoss","title":"MetamericLoss","text":"

The MetamericLoss class provides a perceptual loss function.

Rather than exactly match the source image to the target, it tries to ensure the source is a metamer to the target image.

Its interface is similar to other pytorch loss functions, but note that the gaze location must be provided in addition to the source and target images.

Source code in odak/learn/perception/metameric_loss.py
class MetamericLoss():\n    \"\"\"\n    The `MetamericLoss` class provides a perceptual loss function.\n\n    Rather than exactly match the source image to the target, it tries to ensure the source is a *metamer* to the target image.\n\n    Its interface is similar to other `pytorch` loss functions, but note that the gaze location must be provided in addition to the source and target images.\n    \"\"\"\n\n\n    def __init__(self, device=torch.device('cpu'), alpha=0.2, real_image_width=0.2,\n                 real_viewing_distance=0.7, n_pyramid_levels=5, mode=\"quadratic\",\n                 n_orientations=2, use_l2_foveal_loss=True, fovea_weight=20.0, use_radial_weight=False,\n                 use_fullres_l0=False, equi=False):\n        \"\"\"\n        Parameters\n        ----------\n\n        alpha                   : float\n                                    parameter controlling foveation - larger values mean bigger pooling regions.\n        real_image_width        : float \n                                    The real width of the image as displayed to the user.\n                                    Units don't matter as long as they are the same as for real_viewing_distance.\n        real_viewing_distance   : float \n                                    The real distance of the observer's eyes to the image plane.\n                                    Units don't matter as long as they are the same as for real_image_width.\n        n_pyramid_levels        : int \n                                    Number of levels of the steerable pyramid. Note that the image is padded\n                                    so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value\n                                    too high will slow down the calculation a lot.\n        mode                    : str \n                                    Foveation mode, either \"quadratic\" or \"linear\". Controls how pooling regions grow\n                                    as you move away from the fovea. We got best results with \"quadratic\".\n        n_orientations          : int \n                                    Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.\n                                    Increasing this will increase runtime.\n        use_l2_foveal_loss      : bool \n                                    If true, for all the pixels that have pooling size 1 pixel in the \n                                    largest scale will use direct L2 against target rather than pooling over pyramid levels.\n                                    In practice this gives better results when the loss is used for holography.\n        fovea_weight            : float \n                                    A weight to apply to the foveal region if use_l2_foveal_loss is set to True.\n        use_radial_weight       : bool \n                                    If True, will apply a radial weighting when calculating the difference between\n                                    the source and target stats maps. This weights stats closer to the fovea more than those\n                                    further away.\n        use_fullres_l0          : bool \n                                    If true, stats for the lowpass residual are replaced with blurred versions\n                                    of the full-resolution source and target images.\n        equi                    : bool\n                                    If true, run the loss in equirectangular mode. 
The input is assumed to be an equirectangular\n                                    format 360 image. The settings real_image_width and real_viewing distance are ignored.\n                                    The gaze argument is instead interpreted as gaze angles, and should be in the range\n                                    [-pi,pi]x[-pi/2,pi]\n        \"\"\"\n        self.target = None\n        self.device = device\n        self.pyramid_maker = None\n        self.alpha = alpha\n        self.real_image_width = real_image_width\n        self.real_viewing_distance = real_viewing_distance\n        self.blurs = None\n        self.n_pyramid_levels = n_pyramid_levels\n        self.n_orientations = n_orientations\n        self.mode = mode\n        self.use_l2_foveal_loss = use_l2_foveal_loss\n        self.fovea_weight = fovea_weight\n        self.use_radial_weight = use_radial_weight\n        self.use_fullres_l0 = use_fullres_l0\n        self.equi = equi\n        if self.use_fullres_l0 and self.use_l2_foveal_loss:\n            raise Exception(\n                \"Can't use use_fullres_l0 and use_l2_foveal_loss options together in MetamericLoss!\")\n\n    def calc_statsmaps(self, image, gaze=None, alpha=0.01, real_image_width=0.3,\n                       real_viewing_distance=0.6, mode=\"quadratic\", equi=False):\n\n        if self.pyramid_maker is None or \\\n                self.pyramid_maker.device != self.device or \\\n                len(self.pyramid_maker.band_filters) != self.n_orientations or\\\n                self.pyramid_maker.filt_h0.size(0) != image.size(1):\n            self.pyramid_maker = SpatialSteerablePyramid(\n                use_bilinear_downup=False, n_channels=image.size(1),\n                device=self.device, n_orientations=self.n_orientations, filter_type=\"cropped\", filter_size=5)\n\n        if self.blurs is None or len(self.blurs) != self.n_pyramid_levels:\n            self.blurs = [RadiallyVaryingBlur()\n                          for i in range(self.n_pyramid_levels)]\n\n        def find_stats(image_pyr_level, blur):\n            image_means = blur.blur(\n                image_pyr_level, alpha, real_image_width, real_viewing_distance, centre=gaze, mode=mode, equi=self.equi)\n            image_meansq = blur.blur(image_pyr_level*image_pyr_level, alpha,\n                                     real_image_width, real_viewing_distance, centre=gaze, mode=mode, equi=self.equi)\n\n            image_vars = image_meansq - (image_means*image_means)\n            image_vars[image_vars < 1e-7] = 1e-7\n            image_std = torch.sqrt(image_vars)\n            if torch.any(torch.isnan(image_means)):\n                print(image_means)\n                raise Exception(\"NaN in image means!\")\n            if torch.any(torch.isnan(image_std)):\n                print(image_std)\n                raise Exception(\"NaN in image stdevs!\")\n            if self.use_fullres_l0:\n                mask = blur.lod_map > 1e-6\n                mask = mask[None, None, ...]\n                if image_means.size(1) > 1:\n                    mask = mask.repeat(1, image_means.size(1), 1, 1)\n                matte = torch.zeros_like(image_means)\n                matte[mask] = 1.0\n                return image_means * matte, image_std * matte\n            return image_means, image_std\n        output_stats = []\n        image_pyramid = self.pyramid_maker.construct_pyramid(\n            image, self.n_pyramid_levels)\n        means, variances = find_stats(image_pyramid[0]['h'], self.blurs[0])\n        if 
self.use_l2_foveal_loss:\n            self.fovea_mask = torch.zeros(image.size(), device=image.device)\n            for i in range(self.fovea_mask.size(1)):\n                self.fovea_mask[0, i, ...] = 1.0 - \\\n                    (self.blurs[0].lod_map / torch.max(self.blurs[0].lod_map))\n                self.fovea_mask[0, i, self.blurs[0].lod_map < 1e-6] = 1.0\n            self.fovea_mask = torch.pow(self.fovea_mask, 10.0)\n            #self.fovea_mask     = torch.nn.functional.interpolate(self.fovea_mask, scale_factor=0.125, mode=\"area\")\n            #self.fovea_mask     = torch.nn.functional.interpolate(self.fovea_mask, size=(image.size(-2), image.size(-1)), mode=\"bilinear\")\n            periphery_mask = 1.0 - self.fovea_mask\n            self.periphery_mask = periphery_mask.clone()\n            output_stats.append(means * periphery_mask)\n            output_stats.append(variances * periphery_mask)\n        else:\n            output_stats.append(means)\n            output_stats.append(variances)\n\n        for l in range(0, len(image_pyramid)-1):\n            for o in range(len(image_pyramid[l]['b'])):\n                means, variances = find_stats(\n                    image_pyramid[l]['b'][o], self.blurs[l])\n                if self.use_l2_foveal_loss:\n                    output_stats.append(means * periphery_mask)\n                    output_stats.append(variances * periphery_mask)\n                else:\n                    output_stats.append(means)\n                    output_stats.append(variances)\n            if self.use_l2_foveal_loss:\n                periphery_mask = torch.nn.functional.interpolate(\n                    periphery_mask, scale_factor=0.5, mode=\"area\", recompute_scale_factor=False)\n\n        if self.use_l2_foveal_loss:\n            output_stats.append(image_pyramid[-1][\"l\"] * periphery_mask)\n        elif self.use_fullres_l0:\n            output_stats.append(self.blurs[0].blur(\n                image, alpha, real_image_width, real_viewing_distance, gaze, mode))\n        else:\n            output_stats.append(image_pyramid[-1][\"l\"])\n        return output_stats\n\n    def metameric_loss_stats(self, statsmap_a, statsmap_b, gaze):\n        loss = 0.0\n        for a, b in zip(statsmap_a, statsmap_b):\n            if self.use_radial_weight:\n                radii = make_radial_map(\n                    [a.size(-2), a.size(-1)], gaze).to(a.device)\n                weights = 1.1 - (radii * radii * radii * radii)\n                weights = weights[None, None, ...].repeat(1, a.size(1), 1, 1)\n                loss += torch.nn.MSELoss()(weights*a, weights*b)\n            else:\n                loss += torch.nn.MSELoss()(a, b)\n        loss /= len(statsmap_a)\n        return loss\n\n    def visualise_loss_map(self, image_stats):\n        loss_map = torch.zeros(image_stats[0].size()[-2:])\n        for i in range(len(image_stats)):\n            stats = image_stats[i]\n            target_stats = self.target_stats[i]\n            stat_mse_map = torch.sqrt(torch.pow(stats - target_stats, 2))\n            stat_mse_map = torch.nn.functional.interpolate(stat_mse_map, size=loss_map.size(\n            ), mode=\"bilinear\", align_corners=False, recompute_scale_factor=False)\n            loss_map += stat_mse_map[0, 0, ...]\n        self.loss_map = loss_map\n\n    def __call__(self, image, target, gaze=[0.5, 0.5], image_colorspace=\"RGB\", visualise_loss=False):\n        \"\"\" \n        Calculates the Metameric Loss.\n\n        Parameters\n        ----------\n        
image               : torch.tensor\n                                Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)\n        target              : torch.tensor\n                                Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)\n        image_colorspace    : str\n                                The current colorspace of your image and target. Ignored if input does not have 3 channels.\n                                accepted values: RGB, YCrCb.\n        gaze                : list\n                                Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.\n        visualise_loss      : bool\n                                Shows a heatmap indicating which parts of the image contributed most to the loss. \n\n        Returns\n        -------\n\n        loss                : torch.tensor\n                                The computed loss.\n        \"\"\"\n        check_loss_inputs(\"MetamericLoss\", image, target)\n        # Pad image and target if necessary\n        image = pad_image_for_pyramid(image, self.n_pyramid_levels)\n        target = pad_image_for_pyramid(target, self.n_pyramid_levels)\n        # If input is RGB, convert to YCrCb.\n        if image.size(1) == 3 and image_colorspace == \"RGB\":\n            image = rgb_2_ycrcb(image)\n            target = rgb_2_ycrcb(target)\n        if self.target is None:\n            self.target = torch.zeros(target.shape).to(target.device)\n        if type(target) == type(self.target):\n            if not torch.all(torch.eq(target, self.target)):\n                self.target = target.detach().clone()\n                self.target_stats = self.calc_statsmaps(\n                    self.target,\n                    gaze=gaze,\n                    alpha=self.alpha,\n                    real_image_width=self.real_image_width,\n                    real_viewing_distance=self.real_viewing_distance,\n                    mode=self.mode\n                )\n                self.target = target.detach().clone()\n            image_stats = self.calc_statsmaps(\n                image,\n                gaze=gaze,\n                alpha=self.alpha,\n                real_image_width=self.real_image_width,\n                real_viewing_distance=self.real_viewing_distance,\n                mode=self.mode\n            )\n            if visualise_loss:\n                self.visualise_loss_map(image_stats)\n            if self.use_l2_foveal_loss:\n                peripheral_loss = self.metameric_loss_stats(\n                    image_stats, self.target_stats, gaze)\n                foveal_loss = torch.nn.MSELoss()(self.fovea_mask*image, self.fovea_mask*target)\n                # New weighting - evenly weight fovea and periphery.\n                loss = peripheral_loss + self.fovea_weight * foveal_loss\n            else:\n                loss = self.metameric_loss_stats(\n                    image_stats, self.target_stats, gaze)\n            return loss\n        else:\n            raise Exception(\"Target of incorrect type\")\n\n    def to(self, device):\n        self.device = device\n        return self\n
"},{"location":"odak/learn_perception/#odak.learn.perception.MetamericLoss.__call__","title":"__call__(image, target, gaze=[0.5, 0.5], image_colorspace='RGB', visualise_loss=False)","text":"

Calculates the Metameric Loss.

Parameters:

  • image – Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions).
  • target – Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions).
  • image_colorspace – The current colorspace of your image and target. Ignored if the input does not have 3 channels. Accepted values: RGB, YCrCb.
  • gaze – Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.
  • visualise_loss – Shows a heatmap indicating which parts of the image contributed most to the loss.

Returns:

  • loss ( tensor ) –

    The computed loss.

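A minimal usage sketch based on the call signature above; the tensor shapes, values and gaze location below are illustrative assumptions only:

import torch
from odak.learn.perception import MetamericLoss

loss_fn = MetamericLoss()                        # default foveation settings
image = torch.rand(1, 3, 256, 256)               # candidate image, NCHW, values in [0, 1]
target = torch.rand(1, 3, 256, 256)              # ground truth image, NCHW
loss = loss_fn(image, target, gaze = [0.5, 0.5]) # gaze at the image centre
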
Source code in odak/learn/perception/metameric_loss.py
def __call__(self, image, target, gaze=[0.5, 0.5], image_colorspace=\"RGB\", visualise_loss=False):\n    \"\"\" \n    Calculates the Metameric Loss.\n\n    Parameters\n    ----------\n    image               : torch.tensor\n                            Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)\n    target              : torch.tensor\n                            Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)\n    image_colorspace    : str\n                            The current colorspace of your image and target. Ignored if input does not have 3 channels.\n                            accepted values: RGB, YCrCb.\n    gaze                : list\n                            Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.\n    visualise_loss      : bool\n                            Shows a heatmap indicating which parts of the image contributed most to the loss. \n\n    Returns\n    -------\n\n    loss                : torch.tensor\n                            The computed loss.\n    \"\"\"\n    check_loss_inputs(\"MetamericLoss\", image, target)\n    # Pad image and target if necessary\n    image = pad_image_for_pyramid(image, self.n_pyramid_levels)\n    target = pad_image_for_pyramid(target, self.n_pyramid_levels)\n    # If input is RGB, convert to YCrCb.\n    if image.size(1) == 3 and image_colorspace == \"RGB\":\n        image = rgb_2_ycrcb(image)\n        target = rgb_2_ycrcb(target)\n    if self.target is None:\n        self.target = torch.zeros(target.shape).to(target.device)\n    if type(target) == type(self.target):\n        if not torch.all(torch.eq(target, self.target)):\n            self.target = target.detach().clone()\n            self.target_stats = self.calc_statsmaps(\n                self.target,\n                gaze=gaze,\n                alpha=self.alpha,\n                real_image_width=self.real_image_width,\n                real_viewing_distance=self.real_viewing_distance,\n                mode=self.mode\n            )\n            self.target = target.detach().clone()\n        image_stats = self.calc_statsmaps(\n            image,\n            gaze=gaze,\n            alpha=self.alpha,\n            real_image_width=self.real_image_width,\n            real_viewing_distance=self.real_viewing_distance,\n            mode=self.mode\n        )\n        if visualise_loss:\n            self.visualise_loss_map(image_stats)\n        if self.use_l2_foveal_loss:\n            peripheral_loss = self.metameric_loss_stats(\n                image_stats, self.target_stats, gaze)\n            foveal_loss = torch.nn.MSELoss()(self.fovea_mask*image, self.fovea_mask*target)\n            # New weighting - evenly weight fovea and periphery.\n            loss = peripheral_loss + self.fovea_weight * foveal_loss\n        else:\n            loss = self.metameric_loss_stats(\n                image_stats, self.target_stats, gaze)\n        return loss\n    else:\n        raise Exception(\"Target of incorrect type\")\n
"},{"location":"odak/learn_perception/#odak.learn.perception.MetamericLoss.__init__","title":"__init__(device=torch.device('cpu'), alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, n_pyramid_levels=5, mode='quadratic', n_orientations=2, use_l2_foveal_loss=True, fovea_weight=20.0, use_radial_weight=False, use_fullres_l0=False, equi=False)","text":"

Parameters:

  • alpha – Parameter controlling foveation; larger values mean bigger pooling regions.
  • real_image_width – The real width of the image as displayed to the user. Units don't matter as long as they are the same as for real_viewing_distance.
  • real_viewing_distance – The real distance of the observer's eyes to the image plane. Units don't matter as long as they are the same as for real_image_width.
  • n_pyramid_levels – Number of levels of the steerable pyramid. Note that the image is padded so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value too high will slow down the calculation a lot.
  • mode – Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow as you move away from the fovea. We got best results with "quadratic".
  • n_orientations – Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6. Increasing this will increase runtime.
  • use_l2_foveal_loss – If True, pixels whose pooling size is 1 pixel at the largest scale use a direct L2 loss against the target rather than pooling over pyramid levels. In practice this gives better results when the loss is used for holography.
  • fovea_weight – A weight to apply to the foveal region if use_l2_foveal_loss is set to True.
  • use_radial_weight – If True, applies a radial weighting when calculating the difference between the source and target stats maps. This weights stats closer to the fovea more than those further away.
  • use_fullres_l0 – If True, stats for the lowpass residual are replaced with blurred versions of the full-resolution source and target images.
  • equi – If True, runs the loss in equirectangular mode. The input is assumed to be an equirectangular format 360 image. The settings real_image_width and real_viewing_distance are ignored. The gaze argument is instead interpreted as gaze angles, and should be in the range [-pi,pi]x[-pi/2,pi/2].
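
As a sketch of how the display-geometry parameters might be set; the numbers here are placeholder assumptions, not recommendations:

import torch
from odak.learn.perception import MetamericLoss

# Hypothetical setup: a 0.3 m wide image viewed from 0.6 m away, on CPU.
loss_fn = MetamericLoss(
                        device = torch.device('cpu'),
                        alpha = 0.2,
                        real_image_width = 0.3,
                        real_viewing_distance = 0.6,
                        n_pyramid_levels = 5,
                        n_orientations = 2
                       )
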
Source code in odak/learn/perception/metameric_loss.py
def __init__(self, device=torch.device('cpu'), alpha=0.2, real_image_width=0.2,\n             real_viewing_distance=0.7, n_pyramid_levels=5, mode=\"quadratic\",\n             n_orientations=2, use_l2_foveal_loss=True, fovea_weight=20.0, use_radial_weight=False,\n             use_fullres_l0=False, equi=False):\n    \"\"\"\n    Parameters\n    ----------\n\n    alpha                   : float\n                                parameter controlling foveation - larger values mean bigger pooling regions.\n    real_image_width        : float \n                                The real width of the image as displayed to the user.\n                                Units don't matter as long as they are the same as for real_viewing_distance.\n    real_viewing_distance   : float \n                                The real distance of the observer's eyes to the image plane.\n                                Units don't matter as long as they are the same as for real_image_width.\n    n_pyramid_levels        : int \n                                Number of levels of the steerable pyramid. Note that the image is padded\n                                so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value\n                                too high will slow down the calculation a lot.\n    mode                    : str \n                                Foveation mode, either \"quadratic\" or \"linear\". Controls how pooling regions grow\n                                as you move away from the fovea. We got best results with \"quadratic\".\n    n_orientations          : int \n                                Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.\n                                Increasing this will increase runtime.\n    use_l2_foveal_loss      : bool \n                                If true, for all the pixels that have pooling size 1 pixel in the \n                                largest scale will use direct L2 against target rather than pooling over pyramid levels.\n                                In practice this gives better results when the loss is used for holography.\n    fovea_weight            : float \n                                A weight to apply to the foveal region if use_l2_foveal_loss is set to True.\n    use_radial_weight       : bool \n                                If True, will apply a radial weighting when calculating the difference between\n                                the source and target stats maps. This weights stats closer to the fovea more than those\n                                further away.\n    use_fullres_l0          : bool \n                                If true, stats for the lowpass residual are replaced with blurred versions\n                                of the full-resolution source and target images.\n    equi                    : bool\n                                If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular\n                                format 360 image. 
The settings real_image_width and real_viewing distance are ignored.\n                                The gaze argument is instead interpreted as gaze angles, and should be in the range\n                                [-pi,pi]x[-pi/2,pi]\n    \"\"\"\n    self.target = None\n    self.device = device\n    self.pyramid_maker = None\n    self.alpha = alpha\n    self.real_image_width = real_image_width\n    self.real_viewing_distance = real_viewing_distance\n    self.blurs = None\n    self.n_pyramid_levels = n_pyramid_levels\n    self.n_orientations = n_orientations\n    self.mode = mode\n    self.use_l2_foveal_loss = use_l2_foveal_loss\n    self.fovea_weight = fovea_weight\n    self.use_radial_weight = use_radial_weight\n    self.use_fullres_l0 = use_fullres_l0\n    self.equi = equi\n    if self.use_fullres_l0 and self.use_l2_foveal_loss:\n        raise Exception(\n            \"Can't use use_fullres_l0 and use_l2_foveal_loss options together in MetamericLoss!\")\n
"},{"location":"odak/learn_perception/#odak.learn.perception.MetamericLossUniform","title":"MetamericLossUniform","text":"

Measures metameric loss between a given image and a metamer of the given target image. This variant of the metameric loss is not foveated; it applies uniform pooling sizes to the whole input image.

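A short usage sketch, assuming the class is imported from odak.learn.perception as documented here and using arbitrary tensor shapes:

import torch
from odak.learn.perception import MetamericLossUniform

loss_fn = MetamericLossUniform(pooling_size = 32)  # pool over 32x32 blocks everywhere
image = torch.rand(1, 3, 256, 256)                 # candidate image, NCHW
target = torch.rand(1, 3, 256, 256)                # ground truth image, NCHW
loss = loss_fn(image, target)
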
Source code in odak/learn/perception/metameric_loss_uniform.py
class MetamericLossUniform():\n    \"\"\"\n    Measures metameric loss between a given image and a metamer of the given target image.\n    This variant of the metameric loss is not foveated - it applies uniform pooling sizes to the whole input image.\n    \"\"\"\n\n    def __init__(self, device=torch.device('cpu'), pooling_size=32, n_pyramid_levels=5, n_orientations=2):\n        \"\"\"\n\n        Parameters\n        ----------\n        pooling_size            : int\n                                  Pooling size, in pixels. For example 32 will pool over 32x32 blocks of the image.\n        n_pyramid_levels        : int \n                                  Number of levels of the steerable pyramid. Note that the image is padded\n                                  so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value\n                                  too high will slow down the calculation a lot.\n        n_orientations          : int \n                                  Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.\n                                  Increasing this will increase runtime.\n\n        \"\"\"\n        self.target = None\n        self.device = device\n        self.pyramid_maker = None\n        self.pooling_size = pooling_size\n        self.n_pyramid_levels = n_pyramid_levels\n        self.n_orientations = n_orientations\n\n    def calc_statsmaps(self, image, pooling_size):\n\n        if self.pyramid_maker is None or \\\n                self.pyramid_maker.device != self.device or \\\n                len(self.pyramid_maker.band_filters) != self.n_orientations or\\\n                self.pyramid_maker.filt_h0.size(0) != image.size(1):\n            self.pyramid_maker = SpatialSteerablePyramid(\n                use_bilinear_downup=False, n_channels=image.size(1),\n                device=self.device, n_orientations=self.n_orientations, filter_type=\"cropped\", filter_size=5)\n\n\n        def find_stats(image_pyr_level, pooling_size):\n            image_means = uniform_blur(image_pyr_level, pooling_size)\n            image_meansq = uniform_blur(image_pyr_level*image_pyr_level, pooling_size)\n            image_vars = image_meansq - (image_means*image_means)\n            image_vars[image_vars < 1e-7] = 1e-7\n            image_std = torch.sqrt(image_vars)\n            if torch.any(torch.isnan(image_means)):\n                print(image_means)\n                raise Exception(\"NaN in image means!\")\n            if torch.any(torch.isnan(image_std)):\n                print(image_std)\n                raise Exception(\"NaN in image stdevs!\")\n            return image_means, image_std\n\n        output_stats = []\n        image_pyramid = self.pyramid_maker.construct_pyramid(\n            image, self.n_pyramid_levels)\n        curr_pooling_size = pooling_size\n        means, variances = find_stats(image_pyramid[0]['h'], curr_pooling_size)\n        output_stats.append(means)\n        output_stats.append(variances)\n\n        for l in range(0, len(image_pyramid)-1):\n            for o in range(len(image_pyramid[l]['b'])):\n                means, variances = find_stats(\n                    image_pyramid[l]['b'][o], curr_pooling_size)\n                output_stats.append(means)\n                output_stats.append(variances)\n            curr_pooling_size /= 2\n\n        output_stats.append(image_pyramid[-1][\"l\"])\n        return output_stats\n\n    def metameric_loss_stats(self, statsmap_a, statsmap_b):\n        loss = 0.0\n        for 
a, b in zip(statsmap_a, statsmap_b):\n            loss += torch.nn.MSELoss()(a, b)\n        loss /= len(statsmap_a)\n        return loss\n\n    def visualise_loss_map(self, image_stats):\n        loss_map = torch.zeros(image_stats[0].size()[-2:])\n        for i in range(len(image_stats)):\n            stats = image_stats[i]\n            target_stats = self.target_stats[i]\n            stat_mse_map = torch.sqrt(torch.pow(stats - target_stats, 2))\n            stat_mse_map = torch.nn.functional.interpolate(stat_mse_map, size=loss_map.size(\n            ), mode=\"bilinear\", align_corners=False, recompute_scale_factor=False)\n            loss_map += stat_mse_map[0, 0, ...]\n        self.loss_map = loss_map\n\n    def __call__(self, image, target, image_colorspace=\"RGB\", visualise_loss=False):\n        \"\"\" \n        Calculates the Metameric Loss.\n\n        Parameters\n        ----------\n        image               : torch.tensor\n                                Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)\n        target              : torch.tensor\n                                Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)\n        image_colorspace    : str\n                                The current colorspace of your image and target. Ignored if input does not have 3 channels.\n                                accepted values: RGB, YCrCb.\n        visualise_loss      : bool\n                                Shows a heatmap indicating which parts of the image contributed most to the loss. \n\n        Returns\n        -------\n\n        loss                : torch.tensor\n                                The computed loss.\n        \"\"\"\n        check_loss_inputs(\"MetamericLossUniform\", image, target)\n        # Pad image and target if necessary\n        image = pad_image_for_pyramid(image, self.n_pyramid_levels)\n        target = pad_image_for_pyramid(target, self.n_pyramid_levels)\n        # If input is RGB, convert to YCrCb.\n        if image.size(1) == 3 and image_colorspace == \"RGB\":\n            image = rgb_2_ycrcb(image)\n            target = rgb_2_ycrcb(target)\n        if self.target is None:\n            self.target = torch.zeros(target.shape).to(target.device)\n        if type(target) == type(self.target):\n            if not torch.all(torch.eq(target, self.target)):\n                self.target = target.detach().clone()\n                self.target_stats = self.calc_statsmaps(self.target, self.pooling_size)\n                self.target = target.detach().clone()\n            image_stats = self.calc_statsmaps(image, self.pooling_size)\n\n            if visualise_loss:\n                self.visualise_loss_map(image_stats)\n            loss = self.metameric_loss_stats(\n                image_stats, self.target_stats)\n            return loss\n        else:\n            raise Exception(\"Target of incorrect type\")\n\n    def gen_metamer(self, image):\n        \"\"\" \n        Generates a metamer for an image, following the method in [this paper](https://dl.acm.org/doi/abs/10.1145/3450626.3459943)\n        This function can be used on its own to generate a metamer for a desired image.\n\n        Parameters\n        ----------\n        image   : torch.tensor\n                  Image to compute metamer for. 
Should be an RGB image in NCHW format (4 dimensions)\n\n        Returns\n        -------\n        metamer : torch.tensor\n                  The generated metamer image\n        \"\"\"\n        image = rgb_2_ycrcb(image)\n        image_size = image.size()\n        image = pad_image_for_pyramid(image, self.n_pyramid_levels)\n\n        target_stats = self.calc_statsmaps(\n            image, self.pooling_size)\n        target_means = target_stats[::2]\n        target_stdevs = target_stats[1::2]\n        torch.manual_seed(0)\n        noise_image = torch.rand_like(image)\n        noise_pyramid = self.pyramid_maker.construct_pyramid(\n            noise_image, self.n_pyramid_levels)\n        input_pyramid = self.pyramid_maker.construct_pyramid(\n            image, self.n_pyramid_levels)\n\n        def match_level(input_level, target_mean, target_std):\n            level = input_level.clone()\n            level -= torch.mean(level)\n            input_std = torch.sqrt(torch.mean(level * level))\n            eps = 1e-6\n            # Safeguard against divide by zero\n            input_std[input_std < eps] = eps\n            level /= input_std\n            level *= target_std\n            level += target_mean\n            return level\n\n        nbands = len(noise_pyramid[0][\"b\"])\n        noise_pyramid[0][\"h\"] = match_level(\n            noise_pyramid[0][\"h\"], target_means[0], target_stdevs[0])\n        for l in range(len(noise_pyramid)-1):\n            for b in range(nbands):\n                noise_pyramid[l][\"b\"][b] = match_level(\n                    noise_pyramid[l][\"b\"][b], target_means[1 + l * nbands + b], target_stdevs[1 + l * nbands + b])\n        noise_pyramid[-1][\"l\"] = input_pyramid[-1][\"l\"]\n\n        metamer = self.pyramid_maker.reconstruct_from_pyramid(\n            noise_pyramid)\n        metamer = ycrcb_2_rgb(metamer)\n        # Crop to remove any padding\n        metamer = metamer[:image_size[0], :image_size[1], :image_size[2], :image_size[3]]\n        return metamer\n\n    def to(self, device):\n        self.device = device\n        return self\n
"},{"location":"odak/learn_perception/#odak.learn.perception.MetamericLossUniform.__call__","title":"__call__(image, target, image_colorspace='RGB', visualise_loss=False)","text":"

Calculates the Metameric Loss.

Parameters:

  • image – Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions).
  • target – Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions).
  • image_colorspace – The current colorspace of your image and target. Ignored if the input does not have 3 channels. Accepted values: RGB, YCrCb.
  • visualise_loss – Shows a heatmap indicating which parts of the image contributed most to the loss.

Returns:

  • loss ( tensor ) –

    The computed loss.

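If visualise_loss is set, the loss object stores the heatmap in its loss_map attribute after the call; a brief sketch with assumed shapes:

import torch
from odak.learn.perception import MetamericLossUniform

loss_fn = MetamericLossUniform(pooling_size = 32)
image = torch.rand(1, 3, 256, 256)
target = torch.rand(1, 3, 256, 256)
loss = loss_fn(image, target, visualise_loss = True)
heatmap = loss_fn.loss_map  # 2D tensor indicating which regions contributed most to the loss
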
Source code in odak/learn/perception/metameric_loss_uniform.py
def __call__(self, image, target, image_colorspace=\"RGB\", visualise_loss=False):\n    \"\"\" \n    Calculates the Metameric Loss.\n\n    Parameters\n    ----------\n    image               : torch.tensor\n                            Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)\n    target              : torch.tensor\n                            Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)\n    image_colorspace    : str\n                            The current colorspace of your image and target. Ignored if input does not have 3 channels.\n                            accepted values: RGB, YCrCb.\n    visualise_loss      : bool\n                            Shows a heatmap indicating which parts of the image contributed most to the loss. \n\n    Returns\n    -------\n\n    loss                : torch.tensor\n                            The computed loss.\n    \"\"\"\n    check_loss_inputs(\"MetamericLossUniform\", image, target)\n    # Pad image and target if necessary\n    image = pad_image_for_pyramid(image, self.n_pyramid_levels)\n    target = pad_image_for_pyramid(target, self.n_pyramid_levels)\n    # If input is RGB, convert to YCrCb.\n    if image.size(1) == 3 and image_colorspace == \"RGB\":\n        image = rgb_2_ycrcb(image)\n        target = rgb_2_ycrcb(target)\n    if self.target is None:\n        self.target = torch.zeros(target.shape).to(target.device)\n    if type(target) == type(self.target):\n        if not torch.all(torch.eq(target, self.target)):\n            self.target = target.detach().clone()\n            self.target_stats = self.calc_statsmaps(self.target, self.pooling_size)\n            self.target = target.detach().clone()\n        image_stats = self.calc_statsmaps(image, self.pooling_size)\n\n        if visualise_loss:\n            self.visualise_loss_map(image_stats)\n        loss = self.metameric_loss_stats(\n            image_stats, self.target_stats)\n        return loss\n    else:\n        raise Exception(\"Target of incorrect type\")\n
"},{"location":"odak/learn_perception/#odak.learn.perception.MetamericLossUniform.__init__","title":"__init__(device=torch.device('cpu'), pooling_size=32, n_pyramid_levels=5, n_orientations=2)","text":"

Parameters:

  • pooling_size – Pooling size, in pixels. For example, 32 will pool over 32x32 blocks of the image.
  • n_pyramid_levels – Number of levels of the steerable pyramid. Note that the image is padded so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value too high will slow down the calculation a lot.
  • n_orientations – Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6. Increasing this will increase runtime.
Source code in odak/learn/perception/metameric_loss_uniform.py
def __init__(self, device=torch.device('cpu'), pooling_size=32, n_pyramid_levels=5, n_orientations=2):\n    \"\"\"\n\n    Parameters\n    ----------\n    pooling_size            : int\n                              Pooling size, in pixels. For example 32 will pool over 32x32 blocks of the image.\n    n_pyramid_levels        : int \n                              Number of levels of the steerable pyramid. Note that the image is padded\n                              so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value\n                              too high will slow down the calculation a lot.\n    n_orientations          : int \n                              Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.\n                              Increasing this will increase runtime.\n\n    \"\"\"\n    self.target = None\n    self.device = device\n    self.pyramid_maker = None\n    self.pooling_size = pooling_size\n    self.n_pyramid_levels = n_pyramid_levels\n    self.n_orientations = n_orientations\n
"},{"location":"odak/learn_perception/#odak.learn.perception.MetamericLossUniform.gen_metamer","title":"gen_metamer(image)","text":"

Generates a metamer for an image, following the method in this paper. This function can be used on its own to generate a metamer for a desired image.

Parameters:

  • image – Image to compute a metamer for. Should be an RGB image in NCHW format (4 dimensions).

Returns:

  • metamer ( tensor ) –

    The generated metamer image.

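A brief sketch of generating a metamer directly, with an assumed input shape:

import torch
from odak.learn.perception import MetamericLossUniform

loss_fn = MetamericLossUniform(pooling_size = 32)
image = torch.rand(1, 3, 256, 256)    # RGB image, NCHW
metamer = loss_fn.gen_metamer(image)  # metamer with the same spatial size as the input
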
Source code in odak/learn/perception/metameric_loss_uniform.py
def gen_metamer(self, image):\n    \"\"\" \n    Generates a metamer for an image, following the method in [this paper](https://dl.acm.org/doi/abs/10.1145/3450626.3459943)\n    This function can be used on its own to generate a metamer for a desired image.\n\n    Parameters\n    ----------\n    image   : torch.tensor\n              Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions)\n\n    Returns\n    -------\n    metamer : torch.tensor\n              The generated metamer image\n    \"\"\"\n    image = rgb_2_ycrcb(image)\n    image_size = image.size()\n    image = pad_image_for_pyramid(image, self.n_pyramid_levels)\n\n    target_stats = self.calc_statsmaps(\n        image, self.pooling_size)\n    target_means = target_stats[::2]\n    target_stdevs = target_stats[1::2]\n    torch.manual_seed(0)\n    noise_image = torch.rand_like(image)\n    noise_pyramid = self.pyramid_maker.construct_pyramid(\n        noise_image, self.n_pyramid_levels)\n    input_pyramid = self.pyramid_maker.construct_pyramid(\n        image, self.n_pyramid_levels)\n\n    def match_level(input_level, target_mean, target_std):\n        level = input_level.clone()\n        level -= torch.mean(level)\n        input_std = torch.sqrt(torch.mean(level * level))\n        eps = 1e-6\n        # Safeguard against divide by zero\n        input_std[input_std < eps] = eps\n        level /= input_std\n        level *= target_std\n        level += target_mean\n        return level\n\n    nbands = len(noise_pyramid[0][\"b\"])\n    noise_pyramid[0][\"h\"] = match_level(\n        noise_pyramid[0][\"h\"], target_means[0], target_stdevs[0])\n    for l in range(len(noise_pyramid)-1):\n        for b in range(nbands):\n            noise_pyramid[l][\"b\"][b] = match_level(\n                noise_pyramid[l][\"b\"][b], target_means[1 + l * nbands + b], target_stdevs[1 + l * nbands + b])\n    noise_pyramid[-1][\"l\"] = input_pyramid[-1][\"l\"]\n\n    metamer = self.pyramid_maker.reconstruct_from_pyramid(\n        noise_pyramid)\n    metamer = ycrcb_2_rgb(metamer)\n    # Crop to remove any padding\n    metamer = metamer[:image_size[0], :image_size[1], :image_size[2], :image_size[3]]\n    return metamer\n
"},{"location":"odak/learn_perception/#odak.learn.perception.PSNR","title":"PSNR","text":"

Bases: Module

A class to calculate the peak signal-to-noise ratio (PSNR) of an image with respect to a ground truth image.

Source code in odak/learn/perception/image_quality_losses.py
class PSNR(nn.Module):\n    '''\n    A class to calculate peak-signal-to-noise ratio of an image with respect to a ground truth image.\n    '''\n\n    def __init__(self):\n        super(PSNR, self).__init__()\n\n    def forward(self, predictions, targets, peak_value = 1.0):\n        \"\"\"\n        A function to calculate peak-signal-to-noise ratio of an image with respect to a ground truth image.\n\n        Parameters\n        ----------\n        predictions   : torch.tensor\n                        Image to be tested.\n        targets       : torch.tensor\n                        Ground truth image.\n        peak_value    : float\n                        Peak value that given tensors could have.\n\n        Returns\n        -------\n        result        : torch.tensor\n                        Peak-signal-to-noise ratio.\n        \"\"\"\n        mse = torch.mean((targets - predictions) ** 2)\n        result = 20 * torch.log10(peak_value / torch.sqrt(mse))\n        return result\n
"},{"location":"odak/learn_perception/#odak.learn.perception.PSNR.forward","title":"forward(predictions, targets, peak_value=1.0)","text":"

A function to calculate the peak signal-to-noise ratio of an image with respect to a ground truth image.

Parameters:

  • predictions – Image to be tested.
  • targets – Ground truth image.
  • peak_value – Peak value that the given tensors can take.

Returns:

  • result ( tensor ) –

    Peak signal-to-noise ratio.

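A worked sketch of the formula above, 20 * log10(peak_value / sqrt(MSE)), using a constant error so the result is easy to verify:

import torch
from odak.learn.perception import PSNR

psnr = PSNR()
targets = torch.zeros(1, 3, 64, 64)
predictions = targets + 0.1          # constant error of 0.1, so MSE = 0.01
result = psnr(predictions, targets)  # 20 * log10(1.0 / 0.1) = 20 dB
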
Source code in odak/learn/perception/image_quality_losses.py
def forward(self, predictions, targets, peak_value = 1.0):\n    \"\"\"\n    A function to calculate peak-signal-to-noise ratio of an image with respect to a ground truth image.\n\n    Parameters\n    ----------\n    predictions   : torch.tensor\n                    Image to be tested.\n    targets       : torch.tensor\n                    Ground truth image.\n    peak_value    : float\n                    Peak value that given tensors could have.\n\n    Returns\n    -------\n    result        : torch.tensor\n                    Peak-signal-to-noise ratio.\n    \"\"\"\n    mse = torch.mean((targets - predictions) ** 2)\n    result = 20 * torch.log10(peak_value / torch.sqrt(mse))\n    return result\n
"},{"location":"odak/learn_perception/#odak.learn.perception.RadiallyVaryingBlur","title":"RadiallyVaryingBlur","text":"

The RadiallyVaryingBlur class provides a way to apply a radially varying blur to an image. Given a gaze location and information about the image and foveation, it applies a blur that will achieve the proper pooling size. The pooling size is chosen to appear the same at a range of display sizes and viewing distances, for a given alpha parameter value. For more information on how the pooling sizes are computed, please see link coming soon.

The blur is accelerated by generating and sampling from MIP maps of the input image.

This class caches the foveation information. This means that if it is run repeatedly with the same foveation parameters, gaze location and image size (e.g. in an optimisation loop) it won't recalculate the pooling maps.

If you are repeatedly applying blur to images of different sizes (e.g. a pyramid), use one instance of this class per image size for best performance.

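A minimal sketch of applying the blur, assuming an arbitrary input size and the same placeholder display geometry used elsewhere on this page:

import torch
from odak.learn.perception import RadiallyVaryingBlur

blur = RadiallyVaryingBlur()
image = torch.rand(1, 3, 512, 512)  # NCHW input
output = blur.blur(
                   image,
                   alpha = 0.2,
                   real_image_width = 0.3,
                   real_viewing_distance = 0.6,
                   centre = (0.5, 0.5),  # gaze at the image centre, normalised coordinates
                   mode = "quadratic"
                  )
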
Source code in odak/learn/perception/radially_varying_blur.py
class RadiallyVaryingBlur():\n    \"\"\" \n\n    The `RadiallyVaryingBlur` class provides a way to apply a radially varying blur to an image. Given a gaze location and information about the image and foveation, it applies a blur that will achieve the proper pooling size. The pooling size is chosen to appear the same at a range of display sizes and viewing distances, for a given `alpha` parameter value. For more information on how the pooling sizes are computed, please see [link coming soon]().\n\n    The blur is accelerated by generating and sampling from MIP maps of the input image.\n\n    This class caches the foveation information. This means that if it is run repeatedly with the same foveation parameters, gaze location and image size (e.g. in an optimisation loop) it won't recalculate the pooling maps.\n\n    If you are repeatedly applying blur to images of different sizes (e.g. a pyramid) for best performance use one instance of this class per image size.\n\n    \"\"\"\n\n    def __init__(self):\n        self.lod_map = None\n        self.equi = None\n\n    def blur(self, image, alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, centre=None, mode=\"quadratic\", equi=False):\n        \"\"\"\n        Apply the radially varying blur to an image.\n\n        Parameters\n        ----------\n\n        image                   : torch.tensor\n                                    The image to blur, in NCHW format.\n        alpha                   : float\n                                    parameter controlling foveation - larger values mean bigger pooling regions.\n        real_image_width        : float \n                                    The real width of the image as displayed to the user.\n                                    Units don't matter as long as they are the same as for real_viewing_distance.\n                                    Ignored in equirectangular mode (equi==True)\n        real_viewing_distance   : float \n                                    The real distance of the observer's eyes to the image plane.\n                                    Units don't matter as long as they are the same as for real_image_width.\n                                    Ignored in equirectangular mode (equi==True)\n        centre                  : tuple of floats\n                                    The centre of the radially varying blur (the gaze location).\n                                    Should be a tuple of floats containing normalised image coordinates in range [0,1]\n                                    In equirectangular mode this should be yaw & pitch angles in [-pi,pi]x[-pi/2,pi/2]\n        mode                    : str \n                                    Foveation mode, either \"quadratic\" or \"linear\". Controls how pooling regions grow\n                                    as you move away from the fovea. We got best results with \"quadratic\".\n        equi                    : bool\n                                    If true, run the blur function in equirectangular mode. The input is assumed to be an equirectangular\n                                    format 360 image. 
The settings real_image_width and real_viewing distance are ignored.\n                                    The centre argument is instead interpreted as gaze angles, and should be in the range\n                                    [-pi,pi]x[-pi/2,pi]\n\n        Returns\n        -------\n\n        output                  : torch.tensor\n                                    The blurred image\n        \"\"\"\n        size = (image.size(-2), image.size(-1))\n\n        # LOD map caching\n        if self.lod_map is None or\\\n                self.size != size or\\\n                self.n_channels != image.size(1) or\\\n                self.alpha != alpha or\\\n                self.real_image_width != real_image_width or\\\n                self.real_viewing_distance != real_viewing_distance or\\\n                self.centre != centre or\\\n                self.mode != mode or\\\n                self.equi != equi:\n            if not equi:\n                self.lod_map = make_pooling_size_map_lod(\n                    centre, (image.size(-2), image.size(-1)), alpha, real_image_width, real_viewing_distance, mode)\n            else:\n                self.lod_map = make_equi_pooling_size_map_lod(\n                    centre, (image.size(-2), image.size(-1)), alpha, mode)\n            self.size = size\n            self.n_channels = image.size(1)\n            self.alpha = alpha\n            self.real_image_width = real_image_width\n            self.real_viewing_distance = real_viewing_distance\n            self.centre = centre\n            self.lod_map = self.lod_map.to(image.device)\n            self.lod_fraction = torch.fmod(self.lod_map, 1.0)\n            self.lod_fraction = self.lod_fraction[None, None, ...].repeat(\n                1, image.size(1), 1, 1)\n            self.mode = mode\n            self.equi = equi\n\n        if self.lod_map.device != image.device:\n            self.lod_map = self.lod_map.to(image.device)\n        if self.lod_fraction.device != image.device:\n            self.lod_fraction = self.lod_fraction.to(image.device)\n\n        mipmap = [image]\n        while mipmap[-1].size(-1) > 1 and mipmap[-1].size(-2) > 1:\n            mipmap.append(torch.nn.functional.interpolate(\n                mipmap[-1], scale_factor=0.5, mode=\"area\", recompute_scale_factor=False))\n        if mipmap[-1].size(-1) == 2:\n            final_mip = torch.mean(mipmap[-1], axis=-1)[..., None]\n            mipmap.append(final_mip)\n        if mipmap[-1].size(-2) == 2:\n            final_mip = torch.mean(mipmap[-2], axis=-2)[..., None, :]\n            mipmap.append(final_mip)\n\n        for l in range(len(mipmap)):\n            if l == len(mipmap)-1:\n                mipmap[l] = mipmap[l] * \\\n                    torch.ones(image.size(), device=image.device)\n            else:\n                for l2 in range(l-1, -1, -1):\n                    mipmap[l] = torch.nn.functional.interpolate(mipmap[l], size=(\n                        image.size(-2), image.size(-1)), mode=\"bilinear\", align_corners=False, recompute_scale_factor=False)\n\n        output = torch.zeros(image.size(), device=image.device)\n        for l in range(len(mipmap)):\n            if l == 0:\n                mask = self.lod_map < (l+1)\n            elif l == len(mipmap)-1:\n                mask = self.lod_map >= l\n            else:\n                mask = torch.logical_and(\n                    self.lod_map >= l, self.lod_map < (l+1))\n\n            if l == len(mipmap)-1:\n                blended_levels = mipmap[l]\n            else:\n     
           blended_levels = (1 - self.lod_fraction) * \\\n                    mipmap[l] + self.lod_fraction*mipmap[l+1]\n            mask = mask[None, None, ...]\n            mask = mask.repeat(1, image.size(1), 1, 1)\n            output[mask] = blended_levels[mask]\n\n        return output\n
"},{"location":"odak/learn_perception/#odak.learn.perception.RadiallyVaryingBlur.blur","title":"blur(image, alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, centre=None, mode='quadratic', equi=False)","text":"

Apply the radially varying blur to an image.

Parameters:

  • image – The image to blur, in NCHW format.
  • alpha – Parameter controlling foveation; larger values mean bigger pooling regions.
  • real_image_width – The real width of the image as displayed to the user. Units don't matter as long as they are the same as for real_viewing_distance. Ignored in equirectangular mode (equi==True).
  • real_viewing_distance – The real distance of the observer's eyes to the image plane. Units don't matter as long as they are the same as for real_image_width. Ignored in equirectangular mode (equi==True).
  • centre – The centre of the radially varying blur (the gaze location). Should be a tuple of floats containing normalised image coordinates in the range [0,1]. In equirectangular mode this should be yaw & pitch angles in [-pi,pi]x[-pi/2,pi/2].
  • mode – Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow as you move away from the fovea. We got best results with "quadratic".
  • equi – If True, runs the blur function in equirectangular mode. The input is assumed to be an equirectangular format 360 image. The settings real_image_width and real_viewing_distance are ignored. The centre argument is instead interpreted as gaze angles, and should be in the range [-pi,pi]x[-pi/2,pi/2].

Returns:

  • output ( tensor ) –

    The blurred image.

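For equirectangular input, a sketch along the lines of the equi parameter description; the panorama resolution here is an assumption:

import torch
from odak.learn.perception import RadiallyVaryingBlur

blur = RadiallyVaryingBlur()
panorama = torch.rand(1, 3, 256, 512)  # assumed equirectangular 360 image, NCHW
output = blur.blur(
                   panorama,
                   alpha = 0.2,
                   centre = (0.0, 0.0),  # yaw and pitch gaze angles in radians
                   equi = True
                  )
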
Source code in odak/learn/perception/radially_varying_blur.py
def blur(self, image, alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, centre=None, mode=\"quadratic\", equi=False):\n    \"\"\"\n    Apply the radially varying blur to an image.\n\n    Parameters\n    ----------\n\n    image                   : torch.tensor\n                                The image to blur, in NCHW format.\n    alpha                   : float\n                                parameter controlling foveation - larger values mean bigger pooling regions.\n    real_image_width        : float \n                                The real width of the image as displayed to the user.\n                                Units don't matter as long as they are the same as for real_viewing_distance.\n                                Ignored in equirectangular mode (equi==True)\n    real_viewing_distance   : float \n                                The real distance of the observer's eyes to the image plane.\n                                Units don't matter as long as they are the same as for real_image_width.\n                                Ignored in equirectangular mode (equi==True)\n    centre                  : tuple of floats\n                                The centre of the radially varying blur (the gaze location).\n                                Should be a tuple of floats containing normalised image coordinates in range [0,1]\n                                In equirectangular mode this should be yaw & pitch angles in [-pi,pi]x[-pi/2,pi/2]\n    mode                    : str \n                                Foveation mode, either \"quadratic\" or \"linear\". Controls how pooling regions grow\n                                as you move away from the fovea. We got best results with \"quadratic\".\n    equi                    : bool\n                                If true, run the blur function in equirectangular mode. The input is assumed to be an equirectangular\n                                format 360 image. 
The settings real_image_width and real_viewing distance are ignored.\n                                The centre argument is instead interpreted as gaze angles, and should be in the range\n                                [-pi,pi]x[-pi/2,pi]\n\n    Returns\n    -------\n\n    output                  : torch.tensor\n                                The blurred image\n    \"\"\"\n    size = (image.size(-2), image.size(-1))\n\n    # LOD map caching\n    if self.lod_map is None or\\\n            self.size != size or\\\n            self.n_channels != image.size(1) or\\\n            self.alpha != alpha or\\\n            self.real_image_width != real_image_width or\\\n            self.real_viewing_distance != real_viewing_distance or\\\n            self.centre != centre or\\\n            self.mode != mode or\\\n            self.equi != equi:\n        if not equi:\n            self.lod_map = make_pooling_size_map_lod(\n                centre, (image.size(-2), image.size(-1)), alpha, real_image_width, real_viewing_distance, mode)\n        else:\n            self.lod_map = make_equi_pooling_size_map_lod(\n                centre, (image.size(-2), image.size(-1)), alpha, mode)\n        self.size = size\n        self.n_channels = image.size(1)\n        self.alpha = alpha\n        self.real_image_width = real_image_width\n        self.real_viewing_distance = real_viewing_distance\n        self.centre = centre\n        self.lod_map = self.lod_map.to(image.device)\n        self.lod_fraction = torch.fmod(self.lod_map, 1.0)\n        self.lod_fraction = self.lod_fraction[None, None, ...].repeat(\n            1, image.size(1), 1, 1)\n        self.mode = mode\n        self.equi = equi\n\n    if self.lod_map.device != image.device:\n        self.lod_map = self.lod_map.to(image.device)\n    if self.lod_fraction.device != image.device:\n        self.lod_fraction = self.lod_fraction.to(image.device)\n\n    mipmap = [image]\n    while mipmap[-1].size(-1) > 1 and mipmap[-1].size(-2) > 1:\n        mipmap.append(torch.nn.functional.interpolate(\n            mipmap[-1], scale_factor=0.5, mode=\"area\", recompute_scale_factor=False))\n    if mipmap[-1].size(-1) == 2:\n        final_mip = torch.mean(mipmap[-1], axis=-1)[..., None]\n        mipmap.append(final_mip)\n    if mipmap[-1].size(-2) == 2:\n        final_mip = torch.mean(mipmap[-2], axis=-2)[..., None, :]\n        mipmap.append(final_mip)\n\n    for l in range(len(mipmap)):\n        if l == len(mipmap)-1:\n            mipmap[l] = mipmap[l] * \\\n                torch.ones(image.size(), device=image.device)\n        else:\n            for l2 in range(l-1, -1, -1):\n                mipmap[l] = torch.nn.functional.interpolate(mipmap[l], size=(\n                    image.size(-2), image.size(-1)), mode=\"bilinear\", align_corners=False, recompute_scale_factor=False)\n\n    output = torch.zeros(image.size(), device=image.device)\n    for l in range(len(mipmap)):\n        if l == 0:\n            mask = self.lod_map < (l+1)\n        elif l == len(mipmap)-1:\n            mask = self.lod_map >= l\n        else:\n            mask = torch.logical_and(\n                self.lod_map >= l, self.lod_map < (l+1))\n\n        if l == len(mipmap)-1:\n            blended_levels = mipmap[l]\n        else:\n            blended_levels = (1 - self.lod_fraction) * \\\n                mipmap[l] + self.lod_fraction*mipmap[l+1]\n        mask = mask[None, None, ...]\n        mask = mask.repeat(1, image.size(1), 1, 1)\n        output[mask] = blended_levels[mask]\n\n    return output\n
"},{"location":"odak/learn_perception/#odak.learn.perception.SSIM","title":"SSIM","text":"

Bases: Module

A class to calculate the structural similarity index (SSIM) of an image with respect to a ground truth image.

Source code in odak/learn/perception/image_quality_losses.py
class SSIM(nn.Module):\n    '''\n    A class to calculate structural similarity index of an image with respect to a ground truth image.\n    '''\n\n    def __init__(self):\n        super(SSIM, self).__init__()\n\n    def forward(self, predictions, targets):\n        \"\"\"\n        Parameters\n        ----------\n        predictions : torch.tensor\n                      The predicted images.\n        targets     : torch.tensor\n                      The ground truth images.\n\n        Returns\n        -------\n        result      : torch.tensor \n                      The computed SSIM value if successful, otherwise 0.0.\n        \"\"\"\n        try:\n            from torchmetrics.functional.image import structural_similarity_index_measure\n            if len(predictions.shape) == 3:\n                predictions = predictions.unsqueeze(0)\n                targets = targets.unsqueeze(0)\n            l_SSIM = structural_similarity_index_measure(predictions, targets)\n            return l_SSIM\n        except Exception as e:\n            logging.warning('SSIM failed to compute.')\n            logging.warning(e)\n            return torch.tensor(0.0)\n
"},{"location":"odak/learn_perception/#odak.learn.perception.SSIM.forward","title":"forward(predictions, targets)","text":"

Parameters:

  • predictions (tensor) – The predicted images.
  • targets – The ground truth images.

Returns:

  • result ( tensor ) –

    The computed SSIM value if successful, otherwise 0.0.

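A short sketch; note from the source below that the computation relies on torchmetrics and falls back to 0.0 with a warning if it fails:

import torch
from odak.learn.perception import SSIM

ssim = SSIM()
predictions = torch.rand(1, 3, 128, 128)
targets = torch.rand(1, 3, 128, 128)
result = ssim(predictions, targets)  # scalar SSIM value
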
Source code in odak/learn/perception/image_quality_losses.py
def forward(self, predictions, targets):\n    \"\"\"\n    Parameters\n    ----------\n    predictions : torch.tensor\n                  The predicted images.\n    targets     : torch.tensor\n                  The ground truth images.\n\n    Returns\n    -------\n    result      : torch.tensor \n                  The computed SSIM value if successful, otherwise 0.0.\n    \"\"\"\n    try:\n        from torchmetrics.functional.image import structural_similarity_index_measure\n        if len(predictions.shape) == 3:\n            predictions = predictions.unsqueeze(0)\n            targets = targets.unsqueeze(0)\n        l_SSIM = structural_similarity_index_measure(predictions, targets)\n        return l_SSIM\n    except Exception as e:\n        logging.warning('SSIM failed to compute.')\n        logging.warning(e)\n        return torch.tensor(0.0)\n
"},{"location":"odak/learn_perception/#odak.learn.perception.SpatialSteerablePyramid","title":"SpatialSteerablePyramid","text":"

This implements a real-valued steerable pyramid where the filtering is carried out spatially (using convolution) as opposed to multiplication in the Fourier domain. This has a number of optimisations over previous implementations that increase efficiency, but introduce some reconstruction error.

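A minimal round-trip sketch using the documented construct/reconstruct methods; the image size and level count are assumptions:

import torch
from odak.learn.perception import SpatialSteerablePyramid

pyramid_maker = SpatialSteerablePyramid(n_channels = 3)
image = torch.rand(1, 3, 256, 256)                                # NCHW, 3 channels to match n_channels
pyramid = pyramid_maker.construct_pyramid(image, n_levels = 5)    # list of levels with 'h', 'l' and 'b' entries
reconstruction = pyramid_maker.reconstruct_from_pyramid(pyramid)  # approximately recovers the input
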
Source code in odak/learn/perception/spatial_steerable_pyramid.py
class SpatialSteerablePyramid():\n    \"\"\"\n    This implements a real-valued steerable pyramid where the filtering is carried out spatially (using convolution)\n    as opposed to multiplication in the Fourier domain.\n    This has a number of optimisations over previous implementations that increase efficiency, but introduce some\n    reconstruction error.\n    \"\"\"\n\n\n    def __init__(self, use_bilinear_downup=True, n_channels=1,\n                 filter_size=9, n_orientations=6, filter_type=\"full\",\n                 device=torch.device('cpu')):\n        \"\"\"\n        Parameters\n        ----------\n\n        use_bilinear_downup     : bool\n                                    This uses bilinear filtering when upsampling/downsampling, rather than the original approach\n                                    of applying a large lowpass kernel and sampling even rows/columns\n        n_channels              : int\n                                    Number of channels in the input images (e.g. 3 for RGB input)\n        filter_size             : int\n                                    Desired size of filters (e.g. 3 will use 3x3 filters).\n        n_orientations          : int\n                                    Number of oriented bands in each level of the pyramid.\n        filter_type             : str\n                                    This can be used to select smaller filters than the original ones if desired.\n                                    full: Original filter sizes\n                                    cropped: Some filters are cut back in size by extracting the centre and scaling as appropriate.\n                                    trained: Same as reduced, but the oriented kernels are replaced by learned 5x5 kernels.\n        device                  : torch.device\n                                    torch device the input images will be supplied from.\n        \"\"\"\n        self.use_bilinear_downup = use_bilinear_downup\n        self.device = device\n\n        filters = get_steerable_pyramid_filters(\n            filter_size, n_orientations, filter_type)\n\n        def make_pad(filter):\n            filter_size = filter.size(-1)\n            pad_amt = (filter_size-1) // 2\n            return torch.nn.ReflectionPad2d((pad_amt, pad_amt, pad_amt, pad_amt))\n\n        if not self.use_bilinear_downup:\n            self.filt_l = filters[\"l\"].to(device)\n            self.pad_l = make_pad(self.filt_l)\n        self.filt_l0 = filters[\"l0\"].to(device)\n        self.pad_l0 = make_pad(self.filt_l0)\n        self.filt_h0 = filters[\"h0\"].to(device)\n        self.pad_h0 = make_pad(self.filt_h0)\n        for b in range(len(filters[\"b\"])):\n            filters[\"b\"][b] = filters[\"b\"][b].to(device)\n        self.band_filters = filters[\"b\"]\n        self.pad_b = make_pad(self.band_filters[0])\n\n        if n_channels != 1:\n            def add_channels_to_filter(filter):\n                padded = torch.zeros(n_channels, n_channels, filter.size()[\n                                     2], filter.size()[3]).to(device)\n                for channel in range(n_channels):\n                    padded[channel, channel, :, :] = filter\n                return padded\n            self.filt_h0 = add_channels_to_filter(self.filt_h0)\n            for b in range(len(self.band_filters)):\n                self.band_filters[b] = add_channels_to_filter(\n                    self.band_filters[b])\n            self.filt_l0 = add_channels_to_filter(self.filt_l0)\n            if not 
self.use_bilinear_downup:\n                self.filt_l = add_channels_to_filter(self.filt_l)\n\n    def construct_pyramid(self, image, n_levels, multiple_highpass=False):\n        \"\"\"\n        Constructs and returns a steerable pyramid for the provided image.\n\n        Parameters\n        ----------\n\n        image               : torch.tensor\n                                The input image, in NCHW format. The number of channels C should match num_channels\n                                when the pyramid maker was created.\n        n_levels            : int\n                                Number of levels in the constructed steerable pyramid.\n        multiple_highpass   : bool\n                                If true, computes a highpass for each level of the pyramid.\n                                These extra levels are redundant (not used for reconstruction).\n\n        Returns\n        -------\n\n        pyramid             : list of dicts of torch.tensor\n                                The computed steerable pyramid.\n                                Each level is an entry in a list. The pyramid is ordered from largest levels to smallest levels.\n                                Each level is stored as a dict, with the following keys:\n                                \"h\" Highpass residual\n                                \"l\" Lowpass residual\n                                \"b\" Oriented bands (a list of torch.tensor)\n        \"\"\"\n        pyramid = []\n\n        # Make level 0, containing highpass, lowpass and the bands\n        level0 = {}\n        level0['h'] = torch.nn.functional.conv2d(\n            self.pad_h0(image), self.filt_h0)\n        lowpass = torch.nn.functional.conv2d(self.pad_l0(image), self.filt_l0)\n        level0['l'] = lowpass.clone()\n        bands = []\n        for filt_b in self.band_filters:\n            bands.append(torch.nn.functional.conv2d(\n                self.pad_b(lowpass), filt_b))\n        level0['b'] = bands\n        pyramid.append(level0)\n\n        # Make intermediate levels\n        for l in range(n_levels-2):\n            level = {}\n            if self.use_bilinear_downup:\n                lowpass = torch.nn.functional.interpolate(\n                    lowpass, scale_factor=0.5, mode=\"area\", recompute_scale_factor=False)\n            else:\n                lowpass = torch.nn.functional.conv2d(\n                    self.pad_l(lowpass), self.filt_l)\n                lowpass = lowpass[:, :, ::2, ::2]\n            level['l'] = lowpass.clone()\n            bands = []\n            for filt_b in self.band_filters:\n                bands.append(torch.nn.functional.conv2d(\n                    self.pad_b(lowpass), filt_b))\n            level['b'] = bands\n            if multiple_highpass:\n                level['h'] = torch.nn.functional.conv2d(\n                    self.pad_h0(lowpass), self.filt_h0)\n            pyramid.append(level)\n\n        # Make final level (lowpass residual)\n        level = {}\n        if self.use_bilinear_downup:\n            lowpass = torch.nn.functional.interpolate(\n                lowpass, scale_factor=0.5, mode=\"area\", recompute_scale_factor=False)\n        else:\n            lowpass = torch.nn.functional.conv2d(\n                self.pad_l(lowpass), self.filt_l)\n            lowpass = lowpass[:, :, ::2, ::2]\n        level['l'] = lowpass\n        pyramid.append(level)\n\n        return pyramid\n\n    def reconstruct_from_pyramid(self, pyramid):\n        \"\"\"\n        Reconstructs an input image 
from a steerable pyramid.\n\n        Parameters\n        ----------\n\n        pyramid : list of dicts of torch.tensor\n                    The steerable pyramid.\n                    Should be in the same format as output by construct_steerable_pyramid().\n                    The number of channels should match num_channels when the pyramid maker was created.\n\n        Returns\n        -------\n\n        image   : torch.tensor\n                    The reconstructed image, in NCHW format.         \n        \"\"\"\n        def upsample(image, size):\n            if self.use_bilinear_downup:\n                return torch.nn.functional.interpolate(image, size=size, mode=\"bilinear\", align_corners=False, recompute_scale_factor=False)\n            else:\n                zeros = torch.zeros((image.size()[0], image.size()[1], image.size()[\n                                    2]*2, image.size()[3]*2)).to(self.device)\n                zeros[:, :, ::2, ::2] = image\n                zeros = torch.nn.functional.conv2d(\n                    self.pad_l(zeros), self.filt_l)\n                return zeros\n\n        image = pyramid[-1]['l']\n        for level in reversed(pyramid[:-1]):\n            image = upsample(image, level['b'][0].size()[2:])\n            for b in range(len(level['b'])):\n                b_filtered = torch.nn.functional.conv2d(\n                    self.pad_b(level['b'][b]), -self.band_filters[b])\n                image += b_filtered\n\n        image = torch.nn.functional.conv2d(self.pad_l0(image), self.filt_l0)\n        image += torch.nn.functional.conv2d(\n            self.pad_h0(pyramid[0]['h']), self.filt_h0)\n\n        return image\n
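As a quick illustration (not part of the odak documentation), the sketch below builds a pyramid from a random greyscale image with the default settings and reconstructs it; as the class docstring notes, a small reconstruction error is expected.

```python
# Minimal usage sketch, assuming odak and torch are installed.
import torch
from odak.learn.perception import SpatialSteerablePyramid

pyramid_maker = SpatialSteerablePyramid(n_channels = 1)   # defaults: 9x9 'full' filters, 6 orientations
image = torch.rand(1, 1, 256, 256)                        # NCHW greyscale input
pyramid = pyramid_maker.construct_pyramid(image, n_levels = 5)
reconstruction = pyramid_maker.reconstruct_from_pyramid(pyramid)
print(torch.mean((reconstruction - image) ** 2).item())   # small but non-zero reconstruction error
```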
"},{"location":"odak/learn_perception/#odak.learn.perception.SpatialSteerablePyramid.__init__","title":"__init__(use_bilinear_downup=True, n_channels=1, filter_size=9, n_orientations=6, filter_type='full', device=torch.device('cpu'))","text":"

Parameters:

  • use_bilinear_downup – Use bilinear filtering when upsampling/downsampling, rather than the original approach of applying a large lowpass kernel and sampling even rows/columns.
  • n_channels – Number of channels in the input images (e.g. 3 for RGB input).
  • filter_size – Desired size of filters (e.g. 3 will use 3x3 filters).
  • n_orientations – Number of oriented bands in each level of the pyramid.
  • filter_type – Selects smaller filters than the original ones if desired. full: original filter sizes. cropped: some filters are cut back in size by extracting the centre and scaling as appropriate. trained: same as cropped, but the oriented kernels are replaced by learned 5x5 kernels.
  • device – torch device the input images will be supplied from.
Source code in odak/learn/perception/spatial_steerable_pyramid.py
def __init__(self, use_bilinear_downup=True, n_channels=1,\n             filter_size=9, n_orientations=6, filter_type=\"full\",\n             device=torch.device('cpu')):\n    \"\"\"\n    Parameters\n    ----------\n\n    use_bilinear_downup     : bool\n                                This uses bilinear filtering when upsampling/downsampling, rather than the original approach\n                                of applying a large lowpass kernel and sampling even rows/columns\n    n_channels              : int\n                                Number of channels in the input images (e.g. 3 for RGB input)\n    filter_size             : int\n                                Desired size of filters (e.g. 3 will use 3x3 filters).\n    n_orientations          : int\n                                Number of oriented bands in each level of the pyramid.\n    filter_type             : str\n                                This can be used to select smaller filters than the original ones if desired.\n                                full: Original filter sizes\n                                cropped: Some filters are cut back in size by extracting the centre and scaling as appropriate.\n                                trained: Same as reduced, but the oriented kernels are replaced by learned 5x5 kernels.\n    device                  : torch.device\n                                torch device the input images will be supplied from.\n    \"\"\"\n    self.use_bilinear_downup = use_bilinear_downup\n    self.device = device\n\n    filters = get_steerable_pyramid_filters(\n        filter_size, n_orientations, filter_type)\n\n    def make_pad(filter):\n        filter_size = filter.size(-1)\n        pad_amt = (filter_size-1) // 2\n        return torch.nn.ReflectionPad2d((pad_amt, pad_amt, pad_amt, pad_amt))\n\n    if not self.use_bilinear_downup:\n        self.filt_l = filters[\"l\"].to(device)\n        self.pad_l = make_pad(self.filt_l)\n    self.filt_l0 = filters[\"l0\"].to(device)\n    self.pad_l0 = make_pad(self.filt_l0)\n    self.filt_h0 = filters[\"h0\"].to(device)\n    self.pad_h0 = make_pad(self.filt_h0)\n    for b in range(len(filters[\"b\"])):\n        filters[\"b\"][b] = filters[\"b\"][b].to(device)\n    self.band_filters = filters[\"b\"]\n    self.pad_b = make_pad(self.band_filters[0])\n\n    if n_channels != 1:\n        def add_channels_to_filter(filter):\n            padded = torch.zeros(n_channels, n_channels, filter.size()[\n                                 2], filter.size()[3]).to(device)\n            for channel in range(n_channels):\n                padded[channel, channel, :, :] = filter\n            return padded\n        self.filt_h0 = add_channels_to_filter(self.filt_h0)\n        for b in range(len(self.band_filters)):\n            self.band_filters[b] = add_channels_to_filter(\n                self.band_filters[b])\n        self.filt_l0 = add_channels_to_filter(self.filt_l0)\n        if not self.use_bilinear_downup:\n            self.filt_l = add_channels_to_filter(self.filt_l)\n
"},{"location":"odak/learn_perception/#odak.learn.perception.SpatialSteerablePyramid.construct_pyramid","title":"construct_pyramid(image, n_levels, multiple_highpass=False)","text":"

Constructs and returns a steerable pyramid for the provided image.

Parameters:

  • image – The input image, in NCHW format. The number of channels C should match n_channels when the pyramid maker was created.
  • n_levels – Number of levels in the constructed steerable pyramid.
  • multiple_highpass – If true, computes a highpass for each level of the pyramid. These extra levels are redundant (not used for reconstruction).

Returns:

  • pyramid ( list of dicts of torch.tensor ) – The computed steerable pyramid. Each level is an entry in a list, ordered from the largest level to the smallest. Each level is stored as a dict with the following keys: "h" highpass residual, "l" lowpass residual, "b" oriented bands (a list of torch.tensor).

Source code in odak/learn/perception/spatial_steerable_pyramid.py
def construct_pyramid(self, image, n_levels, multiple_highpass=False):\n    \"\"\"\n    Constructs and returns a steerable pyramid for the provided image.\n\n    Parameters\n    ----------\n\n    image               : torch.tensor\n                            The input image, in NCHW format. The number of channels C should match num_channels\n                            when the pyramid maker was created.\n    n_levels            : int\n                            Number of levels in the constructed steerable pyramid.\n    multiple_highpass   : bool\n                            If true, computes a highpass for each level of the pyramid.\n                            These extra levels are redundant (not used for reconstruction).\n\n    Returns\n    -------\n\n    pyramid             : list of dicts of torch.tensor\n                            The computed steerable pyramid.\n                            Each level is an entry in a list. The pyramid is ordered from largest levels to smallest levels.\n                            Each level is stored as a dict, with the following keys:\n                            \"h\" Highpass residual\n                            \"l\" Lowpass residual\n                            \"b\" Oriented bands (a list of torch.tensor)\n    \"\"\"\n    pyramid = []\n\n    # Make level 0, containing highpass, lowpass and the bands\n    level0 = {}\n    level0['h'] = torch.nn.functional.conv2d(\n        self.pad_h0(image), self.filt_h0)\n    lowpass = torch.nn.functional.conv2d(self.pad_l0(image), self.filt_l0)\n    level0['l'] = lowpass.clone()\n    bands = []\n    for filt_b in self.band_filters:\n        bands.append(torch.nn.functional.conv2d(\n            self.pad_b(lowpass), filt_b))\n    level0['b'] = bands\n    pyramid.append(level0)\n\n    # Make intermediate levels\n    for l in range(n_levels-2):\n        level = {}\n        if self.use_bilinear_downup:\n            lowpass = torch.nn.functional.interpolate(\n                lowpass, scale_factor=0.5, mode=\"area\", recompute_scale_factor=False)\n        else:\n            lowpass = torch.nn.functional.conv2d(\n                self.pad_l(lowpass), self.filt_l)\n            lowpass = lowpass[:, :, ::2, ::2]\n        level['l'] = lowpass.clone()\n        bands = []\n        for filt_b in self.band_filters:\n            bands.append(torch.nn.functional.conv2d(\n                self.pad_b(lowpass), filt_b))\n        level['b'] = bands\n        if multiple_highpass:\n            level['h'] = torch.nn.functional.conv2d(\n                self.pad_h0(lowpass), self.filt_h0)\n        pyramid.append(level)\n\n    # Make final level (lowpass residual)\n    level = {}\n    if self.use_bilinear_downup:\n        lowpass = torch.nn.functional.interpolate(\n            lowpass, scale_factor=0.5, mode=\"area\", recompute_scale_factor=False)\n    else:\n        lowpass = torch.nn.functional.conv2d(\n            self.pad_l(lowpass), self.filt_l)\n        lowpass = lowpass[:, :, ::2, ::2]\n    level['l'] = lowpass\n    pyramid.append(level)\n\n    return pyramid\n
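As an illustrative example (not from the odak docs), the sketch below prints the layout of the returned pyramid: level 0 carries the highpass residual, lowpass and bands; intermediate levels carry a lowpass and bands; the last level only holds the lowpass residual.

```python
# Sketch: inspect what construct_pyramid() returns, assuming odak is installed.
import torch
from odak.learn.perception import SpatialSteerablePyramid

pyramid_maker = SpatialSteerablePyramid(n_channels = 1)
pyramid = pyramid_maker.construct_pyramid(torch.rand(1, 1, 128, 128), n_levels = 4)
for i, level in enumerate(pyramid):
    summary = {k: (len(level[k]) if k == 'b' else tuple(level[k].shape)) for k in sorted(level.keys())}
    print(i, summary)
# Level 0 holds 'b', 'h' and 'l'; intermediate levels hold 'b' and 'l';
# the final level holds only 'l', with the spatial size halving at each level.
```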
"},{"location":"odak/learn_perception/#odak.learn.perception.SpatialSteerablePyramid.reconstruct_from_pyramid","title":"reconstruct_from_pyramid(pyramid)","text":"

Reconstructs an input image from a steerable pyramid.

Parameters:

  • pyramid (list of dicts of torch.tensor) – The steerable pyramid, in the same format as output by construct_pyramid(). The number of channels should match n_channels when the pyramid maker was created.

Returns:

  • image ( tensor ) – The reconstructed image, in NCHW format.

Source code in odak/learn/perception/spatial_steerable_pyramid.py
def reconstruct_from_pyramid(self, pyramid):\n    \"\"\"\n    Reconstructs an input image from a steerable pyramid.\n\n    Parameters\n    ----------\n\n    pyramid : list of dicts of torch.tensor\n                The steerable pyramid.\n                Should be in the same format as output by construct_steerable_pyramid().\n                The number of channels should match num_channels when the pyramid maker was created.\n\n    Returns\n    -------\n\n    image   : torch.tensor\n                The reconstructed image, in NCHW format.         \n    \"\"\"\n    def upsample(image, size):\n        if self.use_bilinear_downup:\n            return torch.nn.functional.interpolate(image, size=size, mode=\"bilinear\", align_corners=False, recompute_scale_factor=False)\n        else:\n            zeros = torch.zeros((image.size()[0], image.size()[1], image.size()[\n                                2]*2, image.size()[3]*2)).to(self.device)\n            zeros[:, :, ::2, ::2] = image\n            zeros = torch.nn.functional.conv2d(\n                self.pad_l(zeros), self.filt_l)\n            return zeros\n\n    image = pyramid[-1]['l']\n    for level in reversed(pyramid[:-1]):\n        image = upsample(image, level['b'][0].size()[2:])\n        for b in range(len(level['b'])):\n            b_filtered = torch.nn.functional.conv2d(\n                self.pad_b(level['b'][b]), -self.band_filters[b])\n            image += b_filtered\n\n    image = torch.nn.functional.conv2d(self.pad_l0(image), self.filt_l0)\n    image += torch.nn.functional.conv2d(\n        self.pad_h0(pyramid[0]['h']), self.filt_h0)\n\n    return image\n
"},{"location":"odak/learn_perception/#odak.learn.perception.display_color_hvs","title":"display_color_hvs","text":"Source code in odak/learn/perception/color_conversion.py
class display_color_hvs():\n\n    def __init__(\n                 self,\n                 resolution = [1920, 1080],\n                 distance_from_screen = 800,\n                 pixel_pitch = 0.311,\n                 read_spectrum = 'tensor',\n                 primaries_spectrum = torch.rand(3, 301),\n                 device = torch.device('cpu')):\n        '''\n        Parameters\n        ----------\n        resolution                  : list\n                                      Resolution of the display in pixels.\n        distance_from_screen        : int\n                                      Distance from the screen in mm.\n        pixel_pitch                 : float\n                                      Pixel pitch of the display in mm.\n        read_spectrum               : str\n                                      Spectrum of the display. Default is 'default' which is the spectrum of the Dell U2415 display.\n        device                      : torch.device\n                                      Device to run the code on. Default is None which means the code will run on CPU.\n\n        '''\n        self.device = device\n        self.read_spectrum = read_spectrum\n        self.primaries_spectrum = primaries_spectrum.to(self.device)\n        self.resolution = resolution\n        self.distance_from_screen = distance_from_screen\n        self.pixel_pitch = pixel_pitch\n        self.l_normalized, self.m_normalized, self.s_normalized = self.initialize_cones_normalized()\n        self.lms_tensor = self.construct_matrix_lms(\n                                                    self.l_normalized,\n                                                    self.m_normalized,\n                                                    self.s_normalized\n                                                   )   \n        self.primaries_tensor = self.construct_matrix_primaries(\n                                                                self.l_normalized,\n                                                                self.m_normalized,\n                                                                self.s_normalized\n                                                               )   \n        return\n\n\n    def __call__(self, input_image, ground_truth, gaze=None):\n        \"\"\"\n        Evaluating an input image against a target ground truth image for a given gaze of a viewer.\n        \"\"\"\n        lms_image_second = self.primaries_to_lms(input_image.to(self.device))\n        lms_ground_truth_second = self.primaries_to_lms(ground_truth.to(self.device))\n        lms_image_third = self.second_to_third_stage(lms_image_second)\n        lms_ground_truth_third = self.second_to_third_stage(lms_ground_truth_second)\n        loss_metamer_color = torch.mean((lms_ground_truth_third - lms_image_third) ** 2)\n        return loss_metamer_color\n\n\n    def initialize_cones_normalized(self):\n        \"\"\"\n        Internal function to initialize normalized L,M,S cones as normal distribution with given sigma, and mu values. 
\n\n        Returns\n        -------\n        l_cone_n                     : torch.tensor\n                                       Normalised L cone distribution.\n        m_cone_n                     : torch.tensor\n                                       Normalised M cone distribution.\n        s_cone_n                     : torch.tensor\n                                       Normalised S cone distribution.\n        \"\"\"\n        wavelength_range = torch.linspace(400, 700, steps = 301, device = self.device)\n        dist_l = 1 / (32.5 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 567.5) ** 2 / (2 * 32.5 ** 2))\n        dist_m = 1 / (27.5 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 545.0) ** 2 / (2 * 27.5 ** 2))\n        dist_s = 1 / (17.0 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 447.5) ** 2 / (2 * 17.0 ** 2))\n\n        l_cone_n = dist_l / dist_l.max()\n        m_cone_n = dist_m / dist_m.max()\n        s_cone_n = dist_s / dist_s.max()\n        return l_cone_n, m_cone_n, s_cone_n\n\n\n    def initialize_rgb_backlight_spectrum(self):\n        \"\"\"\n        Internal function to initialize baclight spectrum for color primaries. \n\n        Returns\n        -------\n        red_spectrum                 : torch.tensor\n                                       Normalised backlight spectrum for red color primary.\n        green_spectrum               : torch.tensor\n                                       Normalised backlight spectrum for green color primary.\n        blue_spectrum                : torch.tensor\n                                       Normalised backlight spectrum for blue color primary.\n        \"\"\"\n        wavelength_range = torch.linspace(400, 700, steps = 301, device = self.device)\n        red_spectrum = 1 / (14.5 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 650) ** 2 / (2 * 14.5 ** 2))\n        green_spectrum = 1 / (12 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 550) ** 2 / (2 * 12.0 ** 2))\n        blue_spectrum = 1 / (12 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 450) ** 2 / (2 * 12.0 ** 2))\n\n        red_spectrum = red_spectrum / red_spectrum.max()\n        green_spectrum = green_spectrum / green_spectrum.max()\n        blue_spectrum = blue_spectrum / blue_spectrum.max()\n\n        return red_spectrum, green_spectrum, blue_spectrum\n\n\n    def initialize_random_spectrum_normalized(self, dataset):\n        \"\"\"\n        Initialize normalized light spectrum via combination of 3 gaussian distribution curve fitting [L-BFGS]. 
\n\n        Parameters\n        ----------\n        dataset                                : torch.tensor \n                                                 spectrum value against wavelength \n        \"\"\"\n        dataset = torch.swapaxes(dataset, 0, 1)\n        x_spectrum = torch.linspace(400, 700, steps = 301) - 550\n        y_spectrum = torch.from_numpy(np_cpu.interp(x_spectrum, dataset[0].numpy(), dataset[1].numpy()))\n        max_spectrum = torch.max(y_spectrum)\n        y_spectrum /= max_spectrum\n\n        def gaussian(x, A = 1, sigma = 1, centre = 0): return A * \\\n            torch.exp(-(x - centre) ** 2 / (2 * sigma ** 2))\n\n        def function(x, weights): \n            return gaussian(x, *weights[:3]) + gaussian(x, *weights[3:6]) + gaussian(x, *weights[6:9])\n\n        weights = torch.tensor([1.0, 1.0, -0.2, 1.0, 1.0, 0.0, 1.0, 1.0, 0.2], requires_grad = True)\n        optimizer = torch.optim.LBFGS([weights], max_iter = 1000, lr = 0.1, line_search_fn = None)\n\n        def closure():\n            optimizer.zero_grad()\n            output = function(x_spectrum, weights)\n            loss = F.mse_loss(output, y_spectrum)\n            loss.backward()\n            return loss\n        optimizer.step(closure)\n        spectrum = function(x_spectrum, weights)\n        return spectrum.detach().to(self.device)\n\n\n    def display_spectrum_response(wavelength, function):\n        \"\"\"\n        Internal function to provide light spectrum response at particular wavelength\n\n        Parameters\n        ----------\n        wavelength                          : torch.tensor\n                                              Wavelength in nm [400...700]\n        function                            : torch.tensor\n                                              Display light spectrum distribution function\n\n        Returns\n        -------\n        ligth_response_dict                  : float\n                                               Display light spectrum response value\n        \"\"\"\n        wavelength = int(round(wavelength, 0))\n        if wavelength >= 400 and wavelength <= 700:\n            return function[wavelength - 400].item()\n        elif wavelength < 400:\n            return function[0].item()\n        else:\n            return function[300].item()\n\n\n    def cone_response_to_spectrum(self, cone_spectrum, light_spectrum):\n        \"\"\"\n        Internal function to calculate cone response at particular light spectrum. \n\n        Parameters\n        ----------\n        cone_spectrum                         : torch.tensor\n                                                Spectrum, Wavelength [2,300] tensor \n        light_spectrum                        : torch.tensor\n                                                Spectrum, Wavelength [2,300] tensor \n\n\n        Returns\n        -------\n        response_to_spectrum                  : float\n                                                Response of cone to light spectrum [1x1] \n        \"\"\"\n        response_to_spectrum = torch.mul(cone_spectrum, light_spectrum)\n        response_to_spectrum = torch.sum(response_to_spectrum)\n        return response_to_spectrum.item()\n\n\n    def construct_matrix_lms(self, l_response, m_response, s_response):\n        '''\n        Internal function to calculate cone  response at particular light spectrum. 
\n\n        Parameters\n        ----------\n        l_response                             : torch.tensor\n                                                 Cone response spectrum tensor (normalized response vs wavelength)\n        m_response                             : torch.tensor\n                                                 Cone response spectrum tensor (normalized response vs wavelength)\n        s_response                             : torch.tensor\n                                                 Cone response spectrum tensor (normalized response vs wavelength)\n\n\n\n        Returns\n        -------\n        lms_image_tensor                      : torch.tensor\n                                                3x3 LMSrgb tensor\n\n        '''\n        if self.read_spectrum == 'tensor':\n            logging.warning('Tensor primary spectrum is used')\n            logging.warning('The number of primaries used is {}'.format(self.primaries_spectrum.shape[0]))\n        else:\n            logging.warning(\"No Spectrum data is provided\")\n\n        self.lms_tensor = torch.zeros(self.primaries_spectrum.shape[0], 3).to(self.device)\n        for i in range(self.primaries_spectrum.shape[0]):\n            self.lms_tensor[i, 0] = self.cone_response_to_spectrum(l_response, self.primaries_spectrum[i])\n            self.lms_tensor[i, 1] = self.cone_response_to_spectrum(m_response, self.primaries_spectrum[i])\n            self.lms_tensor[i, 2] = self.cone_response_to_spectrum(s_response, self.primaries_spectrum[i]) \n        return self.lms_tensor    \n\n\n    def construct_matrix_primaries(self, l_response, m_response, s_response):\n        '''\n        Internal function to calculate cone  response at particular light spectrum. \n\n        Parameters\n        ----------\n        l_response                             : torch.tensor\n                                                 Cone response spectrum tensor (normalized response vs wavelength)\n        m_response                             : torch.tensor\n                                                 Cone response spectrum tensor (normalized response vs wavelength)\n        s_response                             : torch.tensor\n                                                 Cone response spectrum tensor (normalized response vs wavelength)\n\n\n\n        Returns\n        -------\n        lms_image_tensor                      : torch.tensor\n                                                3x3 LMSrgb tensor\n\n        '''\n        if self.read_spectrum == 'tensor':\n            logging.warning('Tensor primary spectrum is used')\n            logging.warning('The number of primaries used is {}'.format(self.primaries_spectrum.shape[0]))\n        else:\n            logging.warning(\"No Spectrum data is provided\")\n\n        self.primaries_tensor = torch.zeros(3, self.primaries_spectrum.shape[0]).to(self.device)\n        for i in range(self.primaries_spectrum.shape[0]):\n            self.primaries_tensor[0, i] = self.cone_response_to_spectrum(\n                                                                         l_response,\n                                                                         self.primaries_spectrum[i]\n                                                                        )\n            self.primaries_tensor[1, i] = self.cone_response_to_spectrum(\n                                                                         m_response,\n                                                                         
self.primaries_spectrum[i]\n                                                                        )\n            self.primaries_tensor[2, i] = self.cone_response_to_spectrum(\n                                                                         s_response,\n                                                                         self.primaries_spectrum[i]\n                                                                        ) \n        return self.primaries_tensor    \n\n\n    def primaries_to_lms(self, primaries):\n        \"\"\"\n        Internal function to convert primaries space to LMS space \n\n        Parameters\n        ----------\n        primaries                              : torch.tensor\n                                                 Primaries data to be transformed to LMS space [BxPHxW]\n\n\n        Returns\n        -------\n        lms_color                              : torch.tensor\n                                                 LMS data transformed from Primaries space [BxPxHxW]\n        \"\"\"                \n        primaries_flatten = primaries.reshape(primaries.shape[0], primaries.shape[1], 1, -1)\n        lms = self.lms_tensor.unsqueeze(0).unsqueeze(-1)\n        lms_color = torch.sum(primaries_flatten * lms, axis = 1).reshape(primaries.shape)\n        return lms_color\n\n\n    def lms_to_primaries(self, lms_color_tensor):\n        \"\"\"\n        Internal function to convert LMS image to primaries space\n\n        Parameters\n        ----------\n        lms_color_tensor                        : torch.tensor\n                                                  LMS data to be transformed to primaries space [Bx3xHxW]\n\n\n        Returns\n        -------\n        primaries                              : torch.tensor\n                                               : Primaries data transformed from LMS space [BxPxHxW]\n        \"\"\"\n        lms_color_tensor = lms_color_tensor.permute(0, 2, 3, 1).to(self.device)\n        lms_color_flatten = torch.flatten(lms_color_tensor, start_dim=0, end_dim=1)\n        unflatten = torch.nn.Unflatten(0, (lms_color_tensor.size(0), lms_color_tensor.size(1)))\n        converted_unflatten = torch.matmul(lms_color_flatten.double(), self.lms_tensor.pinverse().double())\n        primaries = unflatten(converted_unflatten)     \n        primaries = primaries.permute(0, 3, 1, 2)   \n        return primaries\n\n\n    def second_to_third_stage(self, lms_image):\n        '''\n        This function turns second stage [L,M,S] values into third stage [(M+S)-L, (L+S)-M, L+M+S], \n        See table 1 from Schmidt et al. \"Neurobiological hypothesis of color appearance and hue perception,\" Optics Express 2014.\n\n        Parameters\n        ----------\n        lms_image                             : torch.tensor\n                                                 Image data at LMS space (second stage)\n\n        Returns\n        -------\n        third_stage                            : torch.tensor\n                                                 Image data at LMS space (third stage)\n\n        '''\n        third_stage = torch.zeros_like(lms_image)\n        third_stage[:, 0] = (lms_image[:, 1] + lms_image[:, 2]) - lms_image[:, 1]\n        third_stage[:, 1] = (lms_image[:, 0] + lms_image[:, 2]) - lms_image[:, 1]\n        third_stage[:, 2] = lms_image[:, 0] + lms_image[:, 1]  + lms_image[:, 2]\n        return third_stage\n
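As an illustration (with random placeholder spectra and images rather than measured display data), display_color_hvs can be used as a differentiable colour loss:

```python
# Sketch: use display_color_hvs as a differentiable loss, assuming odak is installed.
import torch
from odak.learn.perception import display_color_hvs

display_hvs = display_color_hvs(primaries_spectrum = torch.rand(3, 301))   # placeholder spectra
estimate = torch.rand(1, 3, 64, 64, requires_grad = True)
target = torch.rand(1, 3, 64, 64)
loss = display_hvs(estimate, target)    # mean squared error between third-stage (opponent) responses
loss.backward()
print(loss.item(), estimate.grad.abs().mean().item())
```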
"},{"location":"odak/learn_perception/#odak.learn.perception.display_color_hvs.__call__","title":"__call__(input_image, ground_truth, gaze=None)","text":"

Evaluates an input image against a target ground-truth image for a given gaze of a viewer.

Source code in odak/learn/perception/color_conversion.py
def __call__(self, input_image, ground_truth, gaze=None):\n    \"\"\"\n    Evaluating an input image against a target ground truth image for a given gaze of a viewer.\n    \"\"\"\n    lms_image_second = self.primaries_to_lms(input_image.to(self.device))\n    lms_ground_truth_second = self.primaries_to_lms(ground_truth.to(self.device))\n    lms_image_third = self.second_to_third_stage(lms_image_second)\n    lms_ground_truth_third = self.second_to_third_stage(lms_ground_truth_second)\n    loss_metamer_color = torch.mean((lms_ground_truth_third - lms_image_third) ** 2)\n    return loss_metamer_color\n
"},{"location":"odak/learn_perception/#odak.learn.perception.display_color_hvs.__init__","title":"__init__(resolution=[1920, 1080], distance_from_screen=800, pixel_pitch=0.311, read_spectrum='tensor', primaries_spectrum=torch.rand(3, 301), device=torch.device('cpu'))","text":"

Parameters:

  • resolution – Resolution of the display in pixels.
  • distance_from_screen – Distance from the screen in mm.
  • pixel_pitch – Pixel pitch of the display in mm.
  • read_spectrum – Spectrum source for the display. The default is 'tensor', in which case the provided primaries_spectrum is used.
  • primaries_spectrum – Spectra of the display primaries, one row per primary sampled at 301 wavelengths over 400-700 nm.
  • device – Device to run the code on. The default is CPU.
Source code in odak/learn/perception/color_conversion.py
def __init__(\n             self,\n             resolution = [1920, 1080],\n             distance_from_screen = 800,\n             pixel_pitch = 0.311,\n             read_spectrum = 'tensor',\n             primaries_spectrum = torch.rand(3, 301),\n             device = torch.device('cpu')):\n    '''\n    Parameters\n    ----------\n    resolution                  : list\n                                  Resolution of the display in pixels.\n    distance_from_screen        : int\n                                  Distance from the screen in mm.\n    pixel_pitch                 : float\n                                  Pixel pitch of the display in mm.\n    read_spectrum               : str\n                                  Spectrum of the display. Default is 'default' which is the spectrum of the Dell U2415 display.\n    device                      : torch.device\n                                  Device to run the code on. Default is None which means the code will run on CPU.\n\n    '''\n    self.device = device\n    self.read_spectrum = read_spectrum\n    self.primaries_spectrum = primaries_spectrum.to(self.device)\n    self.resolution = resolution\n    self.distance_from_screen = distance_from_screen\n    self.pixel_pitch = pixel_pitch\n    self.l_normalized, self.m_normalized, self.s_normalized = self.initialize_cones_normalized()\n    self.lms_tensor = self.construct_matrix_lms(\n                                                self.l_normalized,\n                                                self.m_normalized,\n                                                self.s_normalized\n                                               )   \n    self.primaries_tensor = self.construct_matrix_primaries(\n                                                            self.l_normalized,\n                                                            self.m_normalized,\n                                                            self.s_normalized\n                                                           )   \n    return\n
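The sketch below constructs the class for a hypothetical display; the primaries_spectrum is a random stand-in for measured spectra (one row per primary, 301 samples over 400-700 nm).

```python
# Sketch with assumed, hypothetical display parameters.
import torch
from odak.learn.perception import display_color_hvs

display_hvs = display_color_hvs(
                                resolution = [1920, 1080],
                                distance_from_screen = 800,              # mm
                                pixel_pitch = 0.311,                     # mm
                                primaries_spectrum = torch.rand(3, 301), # placeholder spectra
                                device = torch.device('cpu')
                               )
print(display_hvs.lms_tensor.shape)   # one row of (L, M, S) responses per primary: torch.Size([3, 3])
```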
"},{"location":"odak/learn_perception/#odak.learn.perception.display_color_hvs.cone_response_to_spectrum","title":"cone_response_to_spectrum(cone_spectrum, light_spectrum)","text":"

Internal function to calculate the response of a cone to a particular light spectrum.

Parameters:

  • cone_spectrum – Cone sensitivity spectrum sampled against wavelength.
  • light_spectrum – Light spectrum sampled against wavelength.

Returns:

  • response_to_spectrum ( float ) – Response of the cone to the light spectrum (a scalar).

Source code in odak/learn/perception/color_conversion.py
def cone_response_to_spectrum(self, cone_spectrum, light_spectrum):\n    \"\"\"\n    Internal function to calculate cone response at particular light spectrum. \n\n    Parameters\n    ----------\n    cone_spectrum                         : torch.tensor\n                                            Spectrum, Wavelength [2,300] tensor \n    light_spectrum                        : torch.tensor\n                                            Spectrum, Wavelength [2,300] tensor \n\n\n    Returns\n    -------\n    response_to_spectrum                  : float\n                                            Response of cone to light spectrum [1x1] \n    \"\"\"\n    response_to_spectrum = torch.mul(cone_spectrum, light_spectrum)\n    response_to_spectrum = torch.sum(response_to_spectrum)\n    return response_to_spectrum.item()\n
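In effect this is a discrete inner product between a cone sensitivity curve and a light spectrum sampled on the same wavelength grid; a small illustrative check:

```python
# Sketch: cone_response_to_spectrum() equals an elementwise product followed by a sum.
import torch
from odak.learn.perception import display_color_hvs

display_hvs = display_color_hvs()
l_cone, m_cone, s_cone = display_hvs.initialize_cones_normalized()
red, green, blue = display_hvs.initialize_rgb_backlight_spectrum()
print(display_hvs.cone_response_to_spectrum(l_cone, red))   # returned as a plain float
print(torch.sum(l_cone * red).item())                       # same value, computed directly
```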
"},{"location":"odak/learn_perception/#odak.learn.perception.display_color_hvs.construct_matrix_lms","title":"construct_matrix_lms(l_response, m_response, s_response)","text":"

Internal function to construct the LMS matrix, i.e. the response of each cone type to each display primary spectrum.

Parameters:

  • l_response – L cone response spectrum tensor (normalized response vs wavelength).
  • m_response – M cone response spectrum tensor (normalized response vs wavelength).
  • s_response – S cone response spectrum tensor (normalized response vs wavelength).

Returns:

  • lms_image_tensor ( tensor ) – LMS response matrix for the display primaries (one row per primary, one column per cone type).

Source code in odak/learn/perception/color_conversion.py
def construct_matrix_lms(self, l_response, m_response, s_response):\n    '''\n    Internal function to calculate cone  response at particular light spectrum. \n\n    Parameters\n    ----------\n    l_response                             : torch.tensor\n                                             Cone response spectrum tensor (normalized response vs wavelength)\n    m_response                             : torch.tensor\n                                             Cone response spectrum tensor (normalized response vs wavelength)\n    s_response                             : torch.tensor\n                                             Cone response spectrum tensor (normalized response vs wavelength)\n\n\n\n    Returns\n    -------\n    lms_image_tensor                      : torch.tensor\n                                            3x3 LMSrgb tensor\n\n    '''\n    if self.read_spectrum == 'tensor':\n        logging.warning('Tensor primary spectrum is used')\n        logging.warning('The number of primaries used is {}'.format(self.primaries_spectrum.shape[0]))\n    else:\n        logging.warning(\"No Spectrum data is provided\")\n\n    self.lms_tensor = torch.zeros(self.primaries_spectrum.shape[0], 3).to(self.device)\n    for i in range(self.primaries_spectrum.shape[0]):\n        self.lms_tensor[i, 0] = self.cone_response_to_spectrum(l_response, self.primaries_spectrum[i])\n        self.lms_tensor[i, 1] = self.cone_response_to_spectrum(m_response, self.primaries_spectrum[i])\n        self.lms_tensor[i, 2] = self.cone_response_to_spectrum(s_response, self.primaries_spectrum[i]) \n    return self.lms_tensor    \n
"},{"location":"odak/learn_perception/#odak.learn.perception.display_color_hvs.construct_matrix_primaries","title":"construct_matrix_primaries(l_response, m_response, s_response)","text":"

Internal function to construct the primaries matrix from the responses of each cone type to each display primary spectrum.

Parameters:

  • l_response – L cone response spectrum tensor (normalized response vs wavelength).
  • m_response – M cone response spectrum tensor (normalized response vs wavelength).
  • s_response – S cone response spectrum tensor (normalized response vs wavelength).

Returns:

  • lms_image_tensor ( tensor ) – Primaries response matrix of shape (3, n_primaries).

Source code in odak/learn/perception/color_conversion.py
def construct_matrix_primaries(self, l_response, m_response, s_response):\n    '''\n    Internal function to calculate cone  response at particular light spectrum. \n\n    Parameters\n    ----------\n    l_response                             : torch.tensor\n                                             Cone response spectrum tensor (normalized response vs wavelength)\n    m_response                             : torch.tensor\n                                             Cone response spectrum tensor (normalized response vs wavelength)\n    s_response                             : torch.tensor\n                                             Cone response spectrum tensor (normalized response vs wavelength)\n\n\n\n    Returns\n    -------\n    lms_image_tensor                      : torch.tensor\n                                            3x3 LMSrgb tensor\n\n    '''\n    if self.read_spectrum == 'tensor':\n        logging.warning('Tensor primary spectrum is used')\n        logging.warning('The number of primaries used is {}'.format(self.primaries_spectrum.shape[0]))\n    else:\n        logging.warning(\"No Spectrum data is provided\")\n\n    self.primaries_tensor = torch.zeros(3, self.primaries_spectrum.shape[0]).to(self.device)\n    for i in range(self.primaries_spectrum.shape[0]):\n        self.primaries_tensor[0, i] = self.cone_response_to_spectrum(\n                                                                     l_response,\n                                                                     self.primaries_spectrum[i]\n                                                                    )\n        self.primaries_tensor[1, i] = self.cone_response_to_spectrum(\n                                                                     m_response,\n                                                                     self.primaries_spectrum[i]\n                                                                    )\n        self.primaries_tensor[2, i] = self.cone_response_to_spectrum(\n                                                                     s_response,\n                                                                     self.primaries_spectrum[i]\n                                                                    ) \n    return self.primaries_tensor    \n
"},{"location":"odak/learn_perception/#odak.learn.perception.display_color_hvs.display_spectrum_response","title":"display_spectrum_response(wavelength, function)","text":"

Internal function to provide the light spectrum response at a particular wavelength.

Parameters:

  • wavelength – Wavelength in nm [400...700].
  • function – Display light spectrum distribution function.

Returns:

  • light_response ( float ) – Display light spectrum response value.

Source code in odak/learn/perception/color_conversion.py
def display_spectrum_response(wavelength, function):\n    \"\"\"\n    Internal function to provide light spectrum response at particular wavelength\n\n    Parameters\n    ----------\n    wavelength                          : torch.tensor\n                                          Wavelength in nm [400...700]\n    function                            : torch.tensor\n                                          Display light spectrum distribution function\n\n    Returns\n    -------\n    ligth_response_dict                  : float\n                                           Display light spectrum response value\n    \"\"\"\n    wavelength = int(round(wavelength, 0))\n    if wavelength >= 400 and wavelength <= 700:\n        return function[wavelength - 400].item()\n    elif wavelength < 400:\n        return function[0].item()\n    else:\n        return function[300].item()\n
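A small illustrative check of the clamping behaviour; note that in the listed source the method takes no self argument, so it is called through the class here.

```python
# Sketch: responses are read from a 301-sample spectrum (400-700 nm), clamped at both ends.
import torch
from odak.learn.perception import display_color_hvs

spectrum = torch.linspace(0., 1., steps = 301)                        # placeholder spectrum
print(display_color_hvs.display_spectrum_response(550, spectrum))     # spectrum[150] = 0.5
print(display_color_hvs.display_spectrum_response(380, spectrum))     # below 400 nm -> spectrum[0]
print(display_color_hvs.display_spectrum_response(720, spectrum))     # above 700 nm -> spectrum[300]
```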
"},{"location":"odak/learn_perception/#odak.learn.perception.display_color_hvs.initialize_cones_normalized","title":"initialize_cones_normalized()","text":"

Internal function to initialize the normalized L, M and S cone sensitivities as normal distributions with fixed sigma and mu values.

Returns:

  • l_cone_n ( tensor ) – Normalised L cone distribution.
  • m_cone_n ( tensor ) – Normalised M cone distribution.
  • s_cone_n ( tensor ) – Normalised S cone distribution.

Source code in odak/learn/perception/color_conversion.py
def initialize_cones_normalized(self):\n    \"\"\"\n    Internal function to initialize normalized L,M,S cones as normal distribution with given sigma, and mu values. \n\n    Returns\n    -------\n    l_cone_n                     : torch.tensor\n                                   Normalised L cone distribution.\n    m_cone_n                     : torch.tensor\n                                   Normalised M cone distribution.\n    s_cone_n                     : torch.tensor\n                                   Normalised S cone distribution.\n    \"\"\"\n    wavelength_range = torch.linspace(400, 700, steps = 301, device = self.device)\n    dist_l = 1 / (32.5 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 567.5) ** 2 / (2 * 32.5 ** 2))\n    dist_m = 1 / (27.5 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 545.0) ** 2 / (2 * 27.5 ** 2))\n    dist_s = 1 / (17.0 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 447.5) ** 2 / (2 * 17.0 ** 2))\n\n    l_cone_n = dist_l / dist_l.max()\n    m_cone_n = dist_m / dist_m.max()\n    s_cone_n = dist_s / dist_s.max()\n    return l_cone_n, m_cone_n, s_cone_n\n
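Each curve is a Gaussian over 400-700 nm divided by its own maximum; an illustrative check for the L cone (mu = 567.5 nm, sigma = 32.5 nm in the source above):

```python
# Sketch: the normalised L cone peaks at the sample nearest 567.5 nm and has maximum 1.
import torch
from odak.learn.perception import display_color_hvs

display_hvs = display_color_hvs()
l_cone_n, m_cone_n, s_cone_n = display_hvs.initialize_cones_normalized()
wavelengths = torch.linspace(400, 700, steps = 301)
print(wavelengths[torch.argmax(l_cone_n)].item())   # 567.0, the nearest 1 nm sample to 567.5
print(l_cone_n.max().item())                        # 1.0 after normalisation
```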
"},{"location":"odak/learn_perception/#odak.learn.perception.display_color_hvs.initialize_random_spectrum_normalized","title":"initialize_random_spectrum_normalized(dataset)","text":"

Initializes a normalized light spectrum by fitting a combination of three Gaussian curves to the provided data (using L-BFGS).

Parameters:

  • dataset – Spectrum values against wavelength.

Returns:

  • spectrum ( tensor ) – Fitted, normalized spectrum sampled at 301 wavelengths.
Source code in odak/learn/perception/color_conversion.py
def initialize_random_spectrum_normalized(self, dataset):\n    \"\"\"\n    Initialize normalized light spectrum via combination of 3 gaussian distribution curve fitting [L-BFGS]. \n\n    Parameters\n    ----------\n    dataset                                : torch.tensor \n                                             spectrum value against wavelength \n    \"\"\"\n    dataset = torch.swapaxes(dataset, 0, 1)\n    x_spectrum = torch.linspace(400, 700, steps = 301) - 550\n    y_spectrum = torch.from_numpy(np_cpu.interp(x_spectrum, dataset[0].numpy(), dataset[1].numpy()))\n    max_spectrum = torch.max(y_spectrum)\n    y_spectrum /= max_spectrum\n\n    def gaussian(x, A = 1, sigma = 1, centre = 0): return A * \\\n        torch.exp(-(x - centre) ** 2 / (2 * sigma ** 2))\n\n    def function(x, weights): \n        return gaussian(x, *weights[:3]) + gaussian(x, *weights[3:6]) + gaussian(x, *weights[6:9])\n\n    weights = torch.tensor([1.0, 1.0, -0.2, 1.0, 1.0, 0.0, 1.0, 1.0, 0.2], requires_grad = True)\n    optimizer = torch.optim.LBFGS([weights], max_iter = 1000, lr = 0.1, line_search_fn = None)\n\n    def closure():\n        optimizer.zero_grad()\n        output = function(x_spectrum, weights)\n        loss = F.mse_loss(output, y_spectrum)\n        loss.backward()\n        return loss\n    optimizer.step(closure)\n    spectrum = function(x_spectrum, weights)\n    return spectrum.detach().to(self.device)\n
"},{"location":"odak/learn_perception/#odak.learn.perception.display_color_hvs.initialize_rgb_backlight_spectrum","title":"initialize_rgb_backlight_spectrum()","text":"

Internal function to initialize backlight spectra for the color primaries.

Returns:

  • red_spectrum ( tensor ) – Normalised backlight spectrum for the red color primary.
  • green_spectrum ( tensor ) – Normalised backlight spectrum for the green color primary.
  • blue_spectrum ( tensor ) – Normalised backlight spectrum for the blue color primary.

Source code in odak/learn/perception/color_conversion.py
def initialize_rgb_backlight_spectrum(self):\n    \"\"\"\n    Internal function to initialize baclight spectrum for color primaries. \n\n    Returns\n    -------\n    red_spectrum                 : torch.tensor\n                                   Normalised backlight spectrum for red color primary.\n    green_spectrum               : torch.tensor\n                                   Normalised backlight spectrum for green color primary.\n    blue_spectrum                : torch.tensor\n                                   Normalised backlight spectrum for blue color primary.\n    \"\"\"\n    wavelength_range = torch.linspace(400, 700, steps = 301, device = self.device)\n    red_spectrum = 1 / (14.5 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 650) ** 2 / (2 * 14.5 ** 2))\n    green_spectrum = 1 / (12 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 550) ** 2 / (2 * 12.0 ** 2))\n    blue_spectrum = 1 / (12 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 450) ** 2 / (2 * 12.0 ** 2))\n\n    red_spectrum = red_spectrum / red_spectrum.max()\n    green_spectrum = green_spectrum / green_spectrum.max()\n    blue_spectrum = blue_spectrum / blue_spectrum.max()\n\n    return red_spectrum, green_spectrum, blue_spectrum\n
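These synthetic spectra can serve as a stand-in primaries_spectrum when no measured display data is available (illustrative sketch):

```python
# Sketch: stack the synthetic R, G, B backlight spectra into a (3, 301) primaries_spectrum.
import torch
from odak.learn.perception import display_color_hvs

display_hvs = display_color_hvs()
red, green, blue = display_hvs.initialize_rgb_backlight_spectrum()
primaries_spectrum = torch.stack((red, green, blue), dim = 0)     # [3, 301]
display_hvs_synthetic = display_color_hvs(primaries_spectrum = primaries_spectrum)
print(primaries_spectrum.shape, display_hvs_synthetic.lms_tensor.shape)
```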
"},{"location":"odak/learn_perception/#odak.learn.perception.display_color_hvs.lms_to_primaries","title":"lms_to_primaries(lms_color_tensor)","text":"

Internal function to convert an LMS image to primaries space.

Parameters:

  • lms_color_tensor – LMS data to be transformed to primaries space [Bx3xHxW].

Returns:

  • primaries ( tensor ) – Primaries data transformed from LMS space [BxPxHxW].

Source code in odak/learn/perception/color_conversion.py
def lms_to_primaries(self, lms_color_tensor):\n    \"\"\"\n    Internal function to convert LMS image to primaries space\n\n    Parameters\n    ----------\n    lms_color_tensor                        : torch.tensor\n                                              LMS data to be transformed to primaries space [Bx3xHxW]\n\n\n    Returns\n    -------\n    primaries                              : torch.tensor\n                                           : Primaries data transformed from LMS space [BxPxHxW]\n    \"\"\"\n    lms_color_tensor = lms_color_tensor.permute(0, 2, 3, 1).to(self.device)\n    lms_color_flatten = torch.flatten(lms_color_tensor, start_dim=0, end_dim=1)\n    unflatten = torch.nn.Unflatten(0, (lms_color_tensor.size(0), lms_color_tensor.size(1)))\n    converted_unflatten = torch.matmul(lms_color_flatten.double(), self.lms_tensor.pinverse().double())\n    primaries = unflatten(converted_unflatten)     \n    primaries = primaries.permute(0, 3, 1, 2)   \n    return primaries\n
"},{"location":"odak/learn_perception/#odak.learn.perception.display_color_hvs.primaries_to_lms","title":"primaries_to_lms(primaries)","text":"

Internal function to convert from primaries space to LMS space.

Parameters:

  • primaries – Primaries data to be transformed to LMS space [BxPxHxW].

Returns:

  • lms_color ( tensor ) – LMS data transformed from primaries space [BxPxHxW].

Source code in odak/learn/perception/color_conversion.py
def primaries_to_lms(self, primaries):\n    \"\"\"\n    Internal function to convert primaries space to LMS space \n\n    Parameters\n    ----------\n    primaries                              : torch.tensor\n                                             Primaries data to be transformed to LMS space [BxPHxW]\n\n\n    Returns\n    -------\n    lms_color                              : torch.tensor\n                                             LMS data transformed from Primaries space [BxPxHxW]\n    \"\"\"                \n    primaries_flatten = primaries.reshape(primaries.shape[0], primaries.shape[1], 1, -1)\n    lms = self.lms_tensor.unsqueeze(0).unsqueeze(-1)\n    lms_color = torch.sum(primaries_flatten * lms, axis = 1).reshape(primaries.shape)\n    return lms_color\n
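Together with lms_to_primaries above, this gives an approximately invertible per-pixel mapping; an illustrative round trip with placeholder spectra:

```python
# Sketch: primaries -> LMS -> primaries round trip should closely recover the input.
import torch
from odak.learn.perception import display_color_hvs

display_hvs = display_color_hvs(primaries_spectrum = torch.rand(3, 301))
image = torch.rand(1, 3, 32, 32)
lms = display_hvs.primaries_to_lms(image)
recovered = display_hvs.lms_to_primaries(lms)
print(torch.mean((recovered - image) ** 2).item())   # close to zero, up to numerical error
```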
"},{"location":"odak/learn_perception/#odak.learn.perception.display_color_hvs.second_to_third_stage","title":"second_to_third_stage(lms_image)","text":"

This function turns second-stage [L, M, S] values into third-stage [(M+S)-L, (L+S)-M, L+M+S] opponent channels. See Table 1 of Schmidt et al., "Neurobiological hypothesis of color appearance and hue perception," Optics Express 2014.

Parameters:

  • lms_image – Image data in LMS space (second stage).

Returns:

  • third_stage ( tensor ) – Image data in opponent space (third stage).

Source code in odak/learn/perception/color_conversion.py
def second_to_third_stage(self, lms_image):\n    '''\n    This function turns second stage [L,M,S] values into third stage [(M+S)-L, (L+S)-M, L+M+S], \n    See table 1 from Schmidt et al. \"Neurobiological hypothesis of color appearance and hue perception,\" Optics Express 2014.\n\n    Parameters\n    ----------\n    lms_image                             : torch.tensor\n                                             Image data at LMS space (second stage)\n\n    Returns\n    -------\n    third_stage                            : torch.tensor\n                                             Image data at LMS space (third stage)\n\n    '''\n    third_stage = torch.zeros_like(lms_image)\n    third_stage[:, 0] = (lms_image[:, 1] + lms_image[:, 2]) - lms_image[:, 1]\n    third_stage[:, 1] = (lms_image[:, 0] + lms_image[:, 2]) - lms_image[:, 1]\n    third_stage[:, 2] = lms_image[:, 0] + lms_image[:, 1]  + lms_image[:, 2]\n    return third_stage\n
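An illustrative call with placeholder data; note that the listed source subtracts M (index 1) rather than L (index 0) in the first channel, which looks like a typo relative to the docstring formula.

```python
# Sketch: compute third-stage (opponent) channels from an LMS image.
import torch
from odak.learn.perception import display_color_hvs

display_hvs = display_color_hvs(primaries_spectrum = torch.rand(3, 301))
lms_image = display_hvs.primaries_to_lms(torch.rand(1, 3, 16, 16))
third_stage = display_hvs.second_to_third_stage(lms_image)
print(third_stage.shape)   # same NCHW shape as the LMS input
```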
"},{"location":"odak/learn_perception/#odak.learn.perception.color_map","title":"color_map(input_image, target_image, model='Lab Stats')","text":"

Internal function to map the color of an image to another image. Reference: Color transfer between images, Reinhard et al., 2001.

Parameters:

  • input_image – Input image in RGB color space [3 x m x n].
  • target_image – Target image in RGB color space [3 x m x n] whose color statistics are transferred to the input.

Returns:

  • mapped_image ( Tensor ) – Input image with the color distribution of the target image [3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def color_map(input_image, target_image, model = 'Lab Stats'):\n    \"\"\"\n    Internal function to map the color of an image to another image.\n    Reference: Color transfer between images, Reinhard et al., 2001.\n\n    Parameters\n    ----------\n    input_image         : torch.Tensor\n                          Input image in RGB color space [3 x m x n].\n    target_image        : torch.Tensor\n\n    Returns\n    -------\n    mapped_image           : torch.Tensor\n                             Input image with the color the distribution of the target image [3 x m x n].\n    \"\"\"\n    if model == 'Lab Stats':\n        lab_input = srgb_to_lab(input_image)\n        lab_target = srgb_to_lab(target_image)\n        input_mean_L = torch.mean(lab_input[0, :, :])\n        input_mean_a = torch.mean(lab_input[1, :, :])\n        input_mean_b = torch.mean(lab_input[2, :, :])\n        input_std_L = torch.std(lab_input[0, :, :])\n        input_std_a = torch.std(lab_input[1, :, :])\n        input_std_b = torch.std(lab_input[2, :, :])\n        target_mean_L = torch.mean(lab_target[0, :, :])\n        target_mean_a = torch.mean(lab_target[1, :, :])\n        target_mean_b = torch.mean(lab_target[2, :, :])\n        target_std_L = torch.std(lab_target[0, :, :])\n        target_std_a = torch.std(lab_target[1, :, :])\n        target_std_b = torch.std(lab_target[2, :, :])\n        lab_input[0, :, :] = (lab_input[0, :, :] - input_mean_L) * (target_std_L / input_std_L) + target_mean_L\n        lab_input[1, :, :] = (lab_input[1, :, :] - input_mean_a) * (target_std_a / input_std_a) + target_mean_a\n        lab_input[2, :, :] = (lab_input[2, :, :] - input_mean_b) * (target_std_b / input_std_b) + target_mean_b\n        mapped_image = lab_to_srgb(lab_input.permute(1, 2, 0))\n        return mapped_image\n
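An illustrative call with random placeholder images (a real use would pass photographs in [3 x m x n] RGB layout):

```python
# Sketch: transfer the colour statistics of a target image onto an input image.
import torch
from odak.learn.perception import color_map

input_image = torch.rand(3, 128, 128)     # placeholder RGB image
target_image = torch.rand(3, 128, 128)    # placeholder RGB image
mapped_image = color_map(input_image, target_image, model = 'Lab Stats')
print(mapped_image.shape)
```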
"},{"location":"odak/learn_perception/#odak.learn.perception.crop_steerable_pyramid_filters","title":"crop_steerable_pyramid_filters(filters, size)","text":"

Given the original 9x9 NYU filters, this crops them to the desired size. The size must be an odd number >= 3. Note that this only crops the h0, l0 and band filters (not the l downsampling filter).

Parameters:

  • filters – Filters to crop (should be in the format used by get_steerable_pyramid_filters).
  • size – Size to crop to. For example, an input of 3 will crop the filters to a size of 3x3.

Returns:

  • filters ( dict of torch.tensor ) – The cropped filters.

Source code in odak/learn/perception/steerable_pyramid_filters.py
def crop_steerable_pyramid_filters(filters, size):\n    \"\"\"\n    Given original 9x9 NYU filters, this crops them to the desired size.\n    The size must be an odd number >= 3\n    Note this only crops the h0, l0 and band filters (not the l downsampling filter)\n\n    Parameters\n    ----------\n    filters     : dict of torch.tensor\n                    Filters to crop (should in format used by get_steerable_pyramid_filters.)\n    size        : int\n                    Size to crop to. For example, an input of 3 will crop the filters to a size of 3x3.\n\n    Returns\n    -------\n    filters     : dict of torch.tensor\n                    The cropped filters.\n    \"\"\"\n    assert(size >= 3)\n    assert(size % 2 == 1)\n    r = (size-1) // 2\n\n    def crop_filter(filter, r, normalise=True):\n        r2 = (filter.size(-1)-1)//2\n        filter = filter[:, :, r2-r:r2+r+1, r2-r:r2+r+1]\n        if normalise:\n            filter -= torch.sum(filter)\n        return filter\n\n    filters[\"h0\"] = crop_filter(filters[\"h0\"], r, normalise=False)\n    sum_l = torch.sum(filters[\"l\"])\n    filters[\"l\"] = crop_filter(filters[\"l\"], 6, normalise=False)\n    filters[\"l\"] *= sum_l / torch.sum(filters[\"l\"])\n    sum_l0 = torch.sum(filters[\"l0\"])\n    filters[\"l0\"] = crop_filter(filters[\"l0\"], 2, normalise=False)\n    filters[\"l0\"] *= sum_l0 / torch.sum(filters[\"l0\"])\n    for b in range(len(filters[\"b\"])):\n        filters[\"b\"][b] = crop_filter(filters[\"b\"][b], r, normalise=True)\n    return filters\n
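An illustrative sketch: fetch the full filter set and crop the h0, l0 and band filters down to 5x5.

```python
# Sketch: crop the default 'full' filters to a smaller spatial size, assuming odak is installed.
from odak.learn.perception import get_steerable_pyramid_filters, crop_steerable_pyramid_filters

filters = get_steerable_pyramid_filters(size = 9, n_orientations = 6, filter_type = "full")
cropped = crop_steerable_pyramid_filters(filters, size = 5)
print(cropped["h0"].shape, cropped["b"][0].shape)   # h0 and band filters become 5x5 spatially
```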
"},{"location":"odak/learn_perception/#odak.learn.perception.get_steerable_pyramid_filters","title":"get_steerable_pyramid_filters(size, n_orientations, filter_type)","text":"

This returns filters for a real-valued steerable pyramid.

Parameters:

  • size – Width of the filters (e.g. 3 will return 3x3 filters).
  • n_orientations – Number of oriented band filters.
  • filter_type – Selects between the original NYU filters and cropped or trained alternatives. full: original NYU filters from https://github.com/LabForComputationalVision/pyrtools/blob/master/pyrtools/pyramids/filters.py. cropped: some filters are cut back in size by extracting the centre and scaling as appropriate. trained: same as cropped, but the oriented kernels are replaced by learned 5x5 kernels.

Returns:

  • filters ( dict of torch.tensor ) – The steerable pyramid filters, returned as a dict with the following keys: "l" the lowpass downsampling filter, "l0" the lowpass residual filter, "h0" the highpass residual filter, "b" the band filters (a list of torch.tensor filters, one for each orientation).

Source code in odak/learn/perception/steerable_pyramid_filters.py
def get_steerable_pyramid_filters(size, n_orientations, filter_type):\n    \"\"\"\n    This returns filters for a real-valued steerable pyramid.\n\n    Parameters\n    ----------\n\n    size            : int\n                        Width of the filters (e.g. 3 will return 3x3 filters)\n    n_orientations  : int\n                        Number of oriented band filters\n    filter_type     :  str\n                        This can be used to select between the original NYU filters and cropped or trained alternatives.\n                        full: Original NYU filters from https://github.com/LabForComputationalVision/pyrtools/blob/master/pyrtools/pyramids/filters.py\n                        cropped: Some filters are cut back in size by extracting the centre and scaling as appropriate.\n                        trained: Same as reduced, but the oriented kernels are replaced by learned 5x5 kernels.\n\n    Returns\n    -------\n    filters         : dict of torch.tensor\n                        The steerable pyramid filters. Returned as a dict with the following keys:\n                        \"l\" The lowpass downsampling filter\n                        \"l0\" The lowpass residual filter\n                        \"h0\" The highpass residual filter\n                        \"b\" The band filters (a list of torch.tensor filters, one for each orientation).\n    \"\"\"\n\n    if filter_type != \"full\" and filter_type != \"cropped\" and filter_type != \"trained\":\n        raise Exception(\n            \"Unknown filter type %s! Only filter types are full, cropped or trained.\" % filter_type)\n\n    filters = {}\n    if n_orientations == 1:\n        filters[\"l\"] = torch.tensor([\n            [-2.257000e-04, -8.064400e-04, -5.686000e-05, 8.741400e-04, -1.862800e-04, -1.031640e-03, -\n                1.871920e-03, -1.031640e-03, -1.862800e-04, 8.741400e-04, -5.686000e-05, -8.064400e-04, -2.257000e-04],\n            [-8.064400e-04, 1.417620e-03, -1.903800e-04, -2.449060e-03, -4.596420e-03, -7.006740e-03, -\n                6.948900e-03, -7.006740e-03, -4.596420e-03, -2.449060e-03, -1.903800e-04, 1.417620e-03, -8.064400e-04],\n            [-5.686000e-05, -1.903800e-04, -3.059760e-03, -6.401000e-03, -6.720800e-03, -5.236180e-03, -\n                3.781600e-03, -5.236180e-03, -6.720800e-03, -6.401000e-03, -3.059760e-03, -1.903800e-04, -5.686000e-05],\n            [8.741400e-04, -2.449060e-03, -6.401000e-03, -5.260020e-03, 3.938620e-03, 1.722078e-02, 2.449600e-02,\n                1.722078e-02, 3.938620e-03, -5.260020e-03, -6.401000e-03, -2.449060e-03, 8.741400e-04],\n            [-1.862800e-04, -4.596420e-03, -6.720800e-03, 3.938620e-03, 3.220744e-02, 6.306262e-02, 7.624674e-02,\n                6.306262e-02, 3.220744e-02, 3.938620e-03, -6.720800e-03, -4.596420e-03, -1.862800e-04],\n            [-1.031640e-03, -7.006740e-03, -5.236180e-03, 1.722078e-02, 6.306262e-02, 1.116388e-01, 1.348999e-01,\n                1.116388e-01, 6.306262e-02, 1.722078e-02, -5.236180e-03, -7.006740e-03, -1.031640e-03],\n            [-1.871920e-03, -6.948900e-03, -3.781600e-03, 2.449600e-02, 7.624674e-02, 1.348999e-01, 1.576508e-01,\n                1.348999e-01, 7.624674e-02, 2.449600e-02, -3.781600e-03, -6.948900e-03, -1.871920e-03],\n            [-1.031640e-03, -7.006740e-03, -5.236180e-03, 1.722078e-02, 6.306262e-02, 1.116388e-01, 1.348999e-01,\n                1.116388e-01, 6.306262e-02, 1.722078e-02, -5.236180e-03, -7.006740e-03, -1.031640e-03],\n            [-1.862800e-04, -4.596420e-03, -6.720800e-03, 3.938620e-03, 
3.220744e-02, 6.306262e-02, 7.624674e-02,\n                6.306262e-02, 3.220744e-02, 3.938620e-03, -6.720800e-03, -4.596420e-03, -1.862800e-04],\n            [8.741400e-04, -2.449060e-03, -6.401000e-03, -5.260020e-03, 3.938620e-03, 1.722078e-02, 2.449600e-02,\n                1.722078e-02, 3.938620e-03, -5.260020e-03, -6.401000e-03, -2.449060e-03, 8.741400e-04],\n            [-5.686000e-05, -1.903800e-04, -3.059760e-03, -6.401000e-03, -6.720800e-03, -5.236180e-03, -\n                3.781600e-03, -5.236180e-03, -6.720800e-03, -6.401000e-03, -3.059760e-03, -1.903800e-04, -5.686000e-05],\n            [-8.064400e-04, 1.417620e-03, -1.903800e-04, -2.449060e-03, -4.596420e-03, -7.006740e-03, -\n                6.948900e-03, -7.006740e-03, -4.596420e-03, -2.449060e-03, -1.903800e-04, 1.417620e-03, -8.064400e-04],\n            [-2.257000e-04, -8.064400e-04, -5.686000e-05, 8.741400e-04, -1.862800e-04, -1.031640e-03, -1.871920e-03, -1.031640e-03, -1.862800e-04, 8.741400e-04, -5.686000e-05, -8.064400e-04, -2.257000e-04]]\n        ).reshape(1, 1, 13, 13)\n        filters[\"l0\"] = torch.tensor([\n            [-4.514000e-04, -1.137100e-04, -3.725800e-04, -\n                3.743860e-03, -3.725800e-04, -1.137100e-04, -4.514000e-04],\n            [-1.137100e-04, -6.119520e-03, -1.344160e-02, -\n                7.563200e-03, -1.344160e-02, -6.119520e-03, -1.137100e-04],\n            [-3.725800e-04, -1.344160e-02, 6.441488e-02, 1.524935e-01,\n                6.441488e-02, -1.344160e-02, -3.725800e-04],\n            [-3.743860e-03, -7.563200e-03, 1.524935e-01, 3.153017e-01,\n                1.524935e-01, -7.563200e-03, -3.743860e-03],\n            [-3.725800e-04, -1.344160e-02, 6.441488e-02, 1.524935e-01,\n                6.441488e-02, -1.344160e-02, -3.725800e-04],\n            [-1.137100e-04, -6.119520e-03, -1.344160e-02, -\n                7.563200e-03, -1.344160e-02, -6.119520e-03, -1.137100e-04],\n            [-4.514000e-04, -1.137100e-04, -3.725800e-04, -3.743860e-03, -3.725800e-04, -1.137100e-04, -4.514000e-04]]\n        ).reshape(1, 1, 7, 7)\n        filters[\"h0\"] = torch.tensor([\n            [5.997200e-04, -6.068000e-05, -3.324900e-04, -3.325600e-04, -\n                2.406600e-04, -3.325600e-04, -3.324900e-04, -6.068000e-05, 5.997200e-04],\n            [-6.068000e-05, 1.263100e-04, 4.927100e-04, 1.459700e-04, -\n                3.732100e-04, 1.459700e-04, 4.927100e-04, 1.263100e-04, -6.068000e-05],\n            [-3.324900e-04, 4.927100e-04, -1.616650e-03, -1.437358e-02, -\n                2.420138e-02, -1.437358e-02, -1.616650e-03, 4.927100e-04, -3.324900e-04],\n            [-3.325600e-04, 1.459700e-04, -1.437358e-02, -6.300923e-02, -\n                9.623594e-02, -6.300923e-02, -1.437358e-02, 1.459700e-04, -3.325600e-04],\n            [-2.406600e-04, -3.732100e-04, -2.420138e-02, -9.623594e-02,\n                8.554893e-01, -9.623594e-02, -2.420138e-02, -3.732100e-04, -2.406600e-04],\n            [-3.325600e-04, 1.459700e-04, -1.437358e-02, -6.300923e-02, -\n                9.623594e-02, -6.300923e-02, -1.437358e-02, 1.459700e-04, -3.325600e-04],\n            [-3.324900e-04, 4.927100e-04, -1.616650e-03, -1.437358e-02, -\n                2.420138e-02, -1.437358e-02, -1.616650e-03, 4.927100e-04, -3.324900e-04],\n            [-6.068000e-05, 1.263100e-04, 4.927100e-04, 1.459700e-04, -\n                3.732100e-04, 1.459700e-04, 4.927100e-04, 1.263100e-04, -6.068000e-05],\n            [5.997200e-04, -6.068000e-05, -3.324900e-04, -3.325600e-04, -2.406600e-04, -3.325600e-04, -3.324900e-04, 
-6.068000e-05, 5.997200e-04]]\n        ).reshape(1, 1, 9, 9)\n        filters[\"b\"] = []\n        filters[\"b\"].append(torch.tensor([\n            -9.066000e-05, -1.738640e-03, -4.942500e-03, -7.889390e-03, -\n            1.009473e-02, -7.889390e-03, -4.942500e-03, -1.738640e-03, -9.066000e-05,\n            -1.738640e-03, -4.625150e-03, -7.272540e-03, -7.623410e-03, -\n            9.091950e-03, -7.623410e-03, -7.272540e-03, -4.625150e-03, -1.738640e-03,\n            -4.942500e-03, -7.272540e-03, -2.129540e-02, -2.435662e-02, -\n            3.487008e-02, -2.435662e-02, -2.129540e-02, -7.272540e-03, -4.942500e-03,\n            -7.889390e-03, -7.623410e-03, -2.435662e-02, -1.730466e-02, -\n            3.158605e-02, -1.730466e-02, -2.435662e-02, -7.623410e-03, -7.889390e-03,\n            -1.009473e-02, -9.091950e-03, -3.487008e-02, -3.158605e-02, 9.464195e-01, -\n            3.158605e-02, -3.487008e-02, -9.091950e-03, -1.009473e-02,\n            -7.889390e-03, -7.623410e-03, -2.435662e-02, -1.730466e-02, -\n            3.158605e-02, -1.730466e-02, -2.435662e-02, -7.623410e-03, -7.889390e-03,\n            -4.942500e-03, -7.272540e-03, -2.129540e-02, -2.435662e-02, -\n            3.487008e-02, -2.435662e-02, -2.129540e-02, -7.272540e-03, -4.942500e-03,\n            -1.738640e-03, -4.625150e-03, -7.272540e-03, -7.623410e-03, -\n            9.091950e-03, -7.623410e-03, -7.272540e-03, -4.625150e-03, -1.738640e-03,\n            -9.066000e-05, -1.738640e-03, -4.942500e-03, -7.889390e-03, -1.009473e-02, -7.889390e-03, -4.942500e-03, -1.738640e-03, -9.066000e-05]\n        ).reshape(1, 1, 9, 9).permute(0, 1, 3, 2))\n\n    elif n_orientations == 2:\n        filters[\"l\"] = torch.tensor(\n            [[-4.350000e-05, 1.207800e-04, -6.771400e-04, -1.243400e-04, -8.006400e-04, -1.597040e-03, -2.516800e-04, -4.202000e-04, 1.262000e-03, -4.202000e-04, -2.516800e-04, -1.597040e-03, -8.006400e-04, -1.243400e-04, -6.771400e-04, 1.207800e-04, -4.350000e-05],\n             [1.207800e-04, 4.460600e-04, -5.814600e-04, 5.621600e-04, -1.368800e-04, 2.325540e-03, 2.889860e-03, 4.287280e-03, 5.589400e-03,\n                 4.287280e-03, 2.889860e-03, 2.325540e-03, -1.368800e-04, 5.621600e-04, -5.814600e-04, 4.460600e-04, 1.207800e-04],\n             [-6.771400e-04, -5.814600e-04, 1.460780e-03, 2.160540e-03, 3.761360e-03, 3.080980e-03, 4.112200e-03, 2.221220e-03, 5.538200e-04,\n                 2.221220e-03, 4.112200e-03, 3.080980e-03, 3.761360e-03, 2.160540e-03, 1.460780e-03, -5.814600e-04, -6.771400e-04],\n             [-1.243400e-04, 5.621600e-04, 2.160540e-03, 3.175780e-03, 3.184680e-03, -1.777480e-03, -7.431700e-03, -9.056920e-03, -\n                 9.637220e-03, -9.056920e-03, -7.431700e-03, -1.777480e-03, 3.184680e-03, 3.175780e-03, 2.160540e-03, 5.621600e-04, -1.243400e-04],\n             [-8.006400e-04, -1.368800e-04, 3.761360e-03, 3.184680e-03, -3.530640e-03, -1.260420e-02, -1.884744e-02, -1.750818e-02, -\n                 1.648568e-02, -1.750818e-02, -1.884744e-02, -1.260420e-02, -3.530640e-03, 3.184680e-03, 3.761360e-03, -1.368800e-04, -8.006400e-04],\n             [-1.597040e-03, 2.325540e-03, 3.080980e-03, -1.777480e-03, -1.260420e-02, -2.022938e-02, -1.109170e-02, 3.955660e-03, 1.438512e-02,\n                 3.955660e-03, -1.109170e-02, -2.022938e-02, -1.260420e-02, -1.777480e-03, 3.080980e-03, 2.325540e-03, -1.597040e-03],\n             [-2.516800e-04, 2.889860e-03, 4.112200e-03, -7.431700e-03, -1.884744e-02, -1.109170e-02, 2.190660e-02, 6.806584e-02, 9.058014e-02,\n                 6.806584e-02, 
2.190660e-02, -1.109170e-02, -1.884744e-02, -7.431700e-03, 4.112200e-03, 2.889860e-03, -2.516800e-04],\n             [-4.202000e-04, 4.287280e-03, 2.221220e-03, -9.056920e-03, -1.750818e-02, 3.955660e-03, 6.806584e-02, 1.445500e-01, 1.773651e-01,\n                 1.445500e-01, 6.806584e-02, 3.955660e-03, -1.750818e-02, -9.056920e-03, 2.221220e-03, 4.287280e-03, -4.202000e-04],\n             [1.262000e-03, 5.589400e-03, 5.538200e-04, -9.637220e-03, -1.648568e-02, 1.438512e-02, 9.058014e-02, 1.773651e-01, 2.120374e-01,\n                 1.773651e-01, 9.058014e-02, 1.438512e-02, -1.648568e-02, -9.637220e-03, 5.538200e-04, 5.589400e-03, 1.262000e-03],\n             [-4.202000e-04, 4.287280e-03, 2.221220e-03, -9.056920e-03, -1.750818e-02, 3.955660e-03, 6.806584e-02, 1.445500e-01, 1.773651e-01,\n                 1.445500e-01, 6.806584e-02, 3.955660e-03, -1.750818e-02, -9.056920e-03, 2.221220e-03, 4.287280e-03, -4.202000e-04],\n             [-2.516800e-04, 2.889860e-03, 4.112200e-03, -7.431700e-03, -1.884744e-02, -1.109170e-02, 2.190660e-02, 6.806584e-02, 9.058014e-02,\n                 6.806584e-02, 2.190660e-02, -1.109170e-02, -1.884744e-02, -7.431700e-03, 4.112200e-03, 2.889860e-03, -2.516800e-04],\n             [-1.597040e-03, 2.325540e-03, 3.080980e-03, -1.777480e-03, -1.260420e-02, -2.022938e-02, -1.109170e-02, 3.955660e-03, 1.438512e-02,\n                 3.955660e-03, -1.109170e-02, -2.022938e-02, -1.260420e-02, -1.777480e-03, 3.080980e-03, 2.325540e-03, -1.597040e-03],\n             [-8.006400e-04, -1.368800e-04, 3.761360e-03, 3.184680e-03, -3.530640e-03, -1.260420e-02, -1.884744e-02, -1.750818e-02, -\n                 1.648568e-02, -1.750818e-02, -1.884744e-02, -1.260420e-02, -3.530640e-03, 3.184680e-03, 3.761360e-03, -1.368800e-04, -8.006400e-04],\n             [-1.243400e-04, 5.621600e-04, 2.160540e-03, 3.175780e-03, 3.184680e-03, -1.777480e-03, -7.431700e-03, -9.056920e-03, -\n                 9.637220e-03, -9.056920e-03, -7.431700e-03, -1.777480e-03, 3.184680e-03, 3.175780e-03, 2.160540e-03, 5.621600e-04, -1.243400e-04],\n             [-6.771400e-04, -5.814600e-04, 1.460780e-03, 2.160540e-03, 3.761360e-03, 3.080980e-03, 4.112200e-03, 2.221220e-03, 5.538200e-04,\n                 2.221220e-03, 4.112200e-03, 3.080980e-03, 3.761360e-03, 2.160540e-03, 1.460780e-03, -5.814600e-04, -6.771400e-04],\n             [1.207800e-04, 4.460600e-04, -5.814600e-04, 5.621600e-04, -1.368800e-04, 2.325540e-03, 2.889860e-03, 4.287280e-03, 5.589400e-03,\n                 4.287280e-03, 2.889860e-03, 2.325540e-03, -1.368800e-04, 5.621600e-04, -5.814600e-04, 4.460600e-04, 1.207800e-04],\n             [-4.350000e-05, 1.207800e-04, -6.771400e-04, -1.243400e-04, -8.006400e-04, -1.597040e-03, -2.516800e-04, -4.202000e-04, 1.262000e-03, -4.202000e-04, -2.516800e-04, -1.597040e-03, -8.006400e-04, -1.243400e-04, -6.771400e-04, 1.207800e-04, -4.350000e-05]]\n        ).reshape(1, 1, 17, 17)\n        filters[\"l0\"] = torch.tensor(\n            [[-8.701000e-05, -1.354280e-03, -1.601260e-03, -5.033700e-04, 2.524010e-03, -5.033700e-04, -1.601260e-03, -1.354280e-03, -8.701000e-05],\n             [-1.354280e-03, 2.921580e-03, 7.522720e-03, 8.224420e-03, 1.107620e-03,\n                 8.224420e-03, 7.522720e-03, 2.921580e-03, -1.354280e-03],\n             [-1.601260e-03, 7.522720e-03, -7.061290e-03, -3.769487e-02, -\n                 3.297137e-02, -3.769487e-02, -7.061290e-03, 7.522720e-03, -1.601260e-03],\n             [-5.033700e-04, 8.224420e-03, -3.769487e-02, 4.381320e-02, 1.811603e-01,\n                 
4.381320e-02, -3.769487e-02, 8.224420e-03, -5.033700e-04],\n             [2.524010e-03, 1.107620e-03, -3.297137e-02, 1.811603e-01, 4.376250e-01,\n                 1.811603e-01, -3.297137e-02, 1.107620e-03, 2.524010e-03],\n             [-5.033700e-04, 8.224420e-03, -3.769487e-02, 4.381320e-02, 1.811603e-01,\n                 4.381320e-02, -3.769487e-02, 8.224420e-03, -5.033700e-04],\n             [-1.601260e-03, 7.522720e-03, -7.061290e-03, -3.769487e-02, -\n                 3.297137e-02, -3.769487e-02, -7.061290e-03, 7.522720e-03, -1.601260e-03],\n             [-1.354280e-03, 2.921580e-03, 7.522720e-03, 8.224420e-03, 1.107620e-03,\n                 8.224420e-03, 7.522720e-03, 2.921580e-03, -1.354280e-03],\n             [-8.701000e-05, -1.354280e-03, -1.601260e-03, -5.033700e-04, 2.524010e-03, -5.033700e-04, -1.601260e-03, -1.354280e-03, -8.701000e-05]]\n        ).reshape(1, 1, 9, 9)\n        filters[\"h0\"] = torch.tensor(\n            [[-9.570000e-04, -2.424100e-04, -1.424720e-03, -8.742600e-04, -1.166810e-03, -8.742600e-04, -1.424720e-03, -2.424100e-04, -9.570000e-04],\n             [-2.424100e-04, -4.317530e-03, 8.998600e-04, 9.156420e-03, 1.098012e-02,\n                 9.156420e-03, 8.998600e-04, -4.317530e-03, -2.424100e-04],\n             [-1.424720e-03, 8.998600e-04, 1.706347e-02, 1.094866e-02, -\n                 5.897780e-03, 1.094866e-02, 1.706347e-02, 8.998600e-04, -1.424720e-03],\n             [-8.742600e-04, 9.156420e-03, 1.094866e-02, -7.841370e-02, -\n                 1.562827e-01, -7.841370e-02, 1.094866e-02, 9.156420e-03, -8.742600e-04],\n             [-1.166810e-03, 1.098012e-02, -5.897780e-03, -1.562827e-01,\n                 7.282593e-01, -1.562827e-01, -5.897780e-03, 1.098012e-02, -1.166810e-03],\n             [-8.742600e-04, 9.156420e-03, 1.094866e-02, -7.841370e-02, -\n                 1.562827e-01, -7.841370e-02, 1.094866e-02, 9.156420e-03, -8.742600e-04],\n             [-1.424720e-03, 8.998600e-04, 1.706347e-02, 1.094866e-02, -\n                 5.897780e-03, 1.094866e-02, 1.706347e-02, 8.998600e-04, -1.424720e-03],\n             [-2.424100e-04, -4.317530e-03, 8.998600e-04, 9.156420e-03, 1.098012e-02,\n                 9.156420e-03, 8.998600e-04, -4.317530e-03, -2.424100e-04],\n             [-9.570000e-04, -2.424100e-04, -1.424720e-03, -8.742600e-04, -1.166810e-03, -8.742600e-04, -1.424720e-03, -2.424100e-04, -9.570000e-04]]\n        ).reshape(1, 1, 9, 9)\n        filters[\"b\"] = []\n        filters[\"b\"].append(torch.tensor(\n            [6.125880e-03, -8.052600e-03, -2.103714e-02, -1.536890e-02, -1.851466e-02, -1.536890e-02, -2.103714e-02, -8.052600e-03, 6.125880e-03,\n             -1.287416e-02, -9.611520e-03, 1.023569e-02, 6.009450e-03, 1.872620e-03, 6.009450e-03, 1.023569e-02, -\n             9.611520e-03, -1.287416e-02,\n             -5.641530e-03, 4.168400e-03, -2.382180e-02, -5.375324e-02, -\n             2.076086e-02, -5.375324e-02, -2.382180e-02, 4.168400e-03, -5.641530e-03,\n             -8.957260e-03, -1.751170e-03, -1.836909e-02, 1.265655e-01, 2.996168e-01, 1.265655e-01, -\n             1.836909e-02, -1.751170e-03, -8.957260e-03,\n             0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,\n             8.957260e-03, 1.751170e-03, 1.836909e-02, -1.265655e-01, -\n             2.996168e-01, -1.265655e-01, 1.836909e-02, 1.751170e-03, 8.957260e-03,\n             5.641530e-03, -4.168400e-03, 2.382180e-02, 5.375324e-02, 2.076086e-02, 5.375324e-02, 2.382180e-02, -\n             
4.168400e-03, 5.641530e-03,\n             1.287416e-02, 9.611520e-03, -1.023569e-02, -6.009450e-03, -\n             1.872620e-03, -6.009450e-03, -1.023569e-02, 9.611520e-03, 1.287416e-02,\n             -6.125880e-03, 8.052600e-03, 2.103714e-02, 1.536890e-02, 1.851466e-02, 1.536890e-02, 2.103714e-02, 8.052600e-03, -6.125880e-03]).reshape(1, 1, 9, 9).permute(0, 1, 3, 2))\n        filters[\"b\"].append(torch.tensor(\n            [-6.125880e-03, 1.287416e-02, 5.641530e-03, 8.957260e-03, 0.000000e+00, -8.957260e-03, -5.641530e-03, -1.287416e-02, 6.125880e-03,\n             8.052600e-03, 9.611520e-03, -4.168400e-03, 1.751170e-03, 0.000000e+00, -\n             1.751170e-03, 4.168400e-03, -9.611520e-03, -8.052600e-03,\n             2.103714e-02, -1.023569e-02, 2.382180e-02, 1.836909e-02, 0.000000e+00, -\n             1.836909e-02, -2.382180e-02, 1.023569e-02, -2.103714e-02,\n             1.536890e-02, -6.009450e-03, 5.375324e-02, -\n             1.265655e-01, 0.000000e+00, 1.265655e-01, -\n             5.375324e-02, 6.009450e-03, -1.536890e-02,\n             1.851466e-02, -1.872620e-03, 2.076086e-02, -\n             2.996168e-01, 0.000000e+00, 2.996168e-01, -\n             2.076086e-02, 1.872620e-03, -1.851466e-02,\n             1.536890e-02, -6.009450e-03, 5.375324e-02, -\n             1.265655e-01, 0.000000e+00, 1.265655e-01, -\n             5.375324e-02, 6.009450e-03, -1.536890e-02,\n             2.103714e-02, -1.023569e-02, 2.382180e-02, 1.836909e-02, 0.000000e+00, -\n             1.836909e-02, -2.382180e-02, 1.023569e-02, -2.103714e-02,\n             8.052600e-03, 9.611520e-03, -4.168400e-03, 1.751170e-03, 0.000000e+00, -\n             1.751170e-03, 4.168400e-03, -9.611520e-03, -8.052600e-03,\n             -6.125880e-03, 1.287416e-02, 5.641530e-03, 8.957260e-03, 0.000000e+00, -8.957260e-03, -5.641530e-03, -1.287416e-02, 6.125880e-03]).reshape(1, 1, 9, 9).permute(0, 1, 3, 2))\n\n    elif n_orientations == 4:\n        filters[\"l\"] = torch.tensor([\n            [-4.3500000174E-5, 1.2078000145E-4, -6.7714002216E-4, -1.2434000382E-4, -8.0063997302E-4, -1.5970399836E-3, -2.5168000138E-4, -4.2019999819E-4,\n                1.2619999470E-3, -4.2019999819E-4, -2.5168000138E-4, -1.5970399836E-3, -8.0063997302E-4, -1.2434000382E-4, -6.7714002216E-4, 1.2078000145E-4, -4.3500000174E-5],\n            [1.2078000145E-4, 4.4606000301E-4, -5.8146001538E-4, 5.6215998484E-4, -1.3688000035E-4, 2.3255399428E-3, 2.8898599558E-3, 4.2872801423E-3, 5.5893999524E-3,\n                4.2872801423E-3, 2.8898599558E-3, 2.3255399428E-3, -1.3688000035E-4, 5.6215998484E-4, -5.8146001538E-4, 4.4606000301E-4, 1.2078000145E-4],\n            [-6.7714002216E-4, -5.8146001538E-4, 1.4607800404E-3, 2.1605400834E-3, 3.7613599561E-3, 3.0809799209E-3, 4.1121998802E-3, 2.2212199401E-3, 5.5381999118E-4,\n                2.2212199401E-3, 4.1121998802E-3, 3.0809799209E-3, 3.7613599561E-3, 2.1605400834E-3, 1.4607800404E-3, -5.8146001538E-4, -6.7714002216E-4],\n            [-1.2434000382E-4, 5.6215998484E-4, 2.1605400834E-3, 3.1757799443E-3, 3.1846798956E-3, -1.7774800071E-3, -7.4316998944E-3, -9.0569201857E-3, -\n                9.6372198313E-3, -9.0569201857E-3, -7.4316998944E-3, -1.7774800071E-3, 3.1846798956E-3, 3.1757799443E-3, 2.1605400834E-3, 5.6215998484E-4, -1.2434000382E-4],\n            [-8.0063997302E-4, -1.3688000035E-4, 3.7613599561E-3, 3.1846798956E-3, -3.5306399222E-3, -1.2604200281E-2, -1.8847439438E-2, -1.7508180812E-2, -\n                1.6485679895E-2, -1.7508180812E-2, -1.8847439438E-2, -1.2604200281E-2, 
-3.5306399222E-3, 3.1846798956E-3, 3.7613599561E-3, -1.3688000035E-4, -8.0063997302E-4],\n            [-1.5970399836E-3, 2.3255399428E-3, 3.0809799209E-3, -1.7774800071E-3, -1.2604200281E-2, -2.0229380578E-2, -1.1091699824E-2, 3.9556599222E-3, 1.4385120012E-2,\n                3.9556599222E-3, -1.1091699824E-2, -2.0229380578E-2, -1.2604200281E-2, -1.7774800071E-3, 3.0809799209E-3, 2.3255399428E-3, -1.5970399836E-3],\n            [-2.5168000138E-4, 2.8898599558E-3, 4.1121998802E-3, -7.4316998944E-3, -1.8847439438E-2, -1.1091699824E-2, 2.1906599402E-2, 6.8065837026E-2, 9.0580143034E-2,\n                6.8065837026E-2, 2.1906599402E-2, -1.1091699824E-2, -1.8847439438E-2, -7.4316998944E-3, 4.1121998802E-3, 2.8898599558E-3, -2.5168000138E-4],\n            [-4.2019999819E-4, 4.2872801423E-3, 2.2212199401E-3, -9.0569201857E-3, -1.7508180812E-2, 3.9556599222E-3, 6.8065837026E-2, 0.1445499808, 0.1773651242,\n                0.1445499808, 6.8065837026E-2, 3.9556599222E-3, -1.7508180812E-2, -9.0569201857E-3, 2.2212199401E-3, 4.2872801423E-3, -4.2019999819E-4],\n            [1.2619999470E-3, 5.5893999524E-3, 5.5381999118E-4, -9.6372198313E-3, -1.6485679895E-2, 1.4385120012E-2, 9.0580143034E-2, 0.1773651242, 0.2120374441,\n                0.1773651242, 9.0580143034E-2, 1.4385120012E-2, -1.6485679895E-2, -9.6372198313E-3, 5.5381999118E-4, 5.5893999524E-3, 1.2619999470E-3],\n            [-4.2019999819E-4, 4.2872801423E-3, 2.2212199401E-3, -9.0569201857E-3, -1.7508180812E-2, 3.9556599222E-3, 6.8065837026E-2, 0.1445499808, 0.1773651242,\n                0.1445499808, 6.8065837026E-2, 3.9556599222E-3, -1.7508180812E-2, -9.0569201857E-3, 2.2212199401E-3, 4.2872801423E-3, -4.2019999819E-4],\n            [-2.5168000138E-4, 2.8898599558E-3, 4.1121998802E-3, -7.4316998944E-3, -1.8847439438E-2, -1.1091699824E-2, 2.1906599402E-2, 6.8065837026E-2, 9.0580143034E-2,\n                6.8065837026E-2, 2.1906599402E-2, -1.1091699824E-2, -1.8847439438E-2, -7.4316998944E-3, 4.1121998802E-3, 2.8898599558E-3, -2.5168000138E-4],\n            [-1.5970399836E-3, 2.3255399428E-3, 3.0809799209E-3, -1.7774800071E-3, -1.2604200281E-2, -2.0229380578E-2, -1.1091699824E-2, 3.9556599222E-3, 1.4385120012E-2,\n                3.9556599222E-3, -1.1091699824E-2, -2.0229380578E-2, -1.2604200281E-2, -1.7774800071E-3, 3.0809799209E-3, 2.3255399428E-3, -1.5970399836E-3],\n            [-8.0063997302E-4, -1.3688000035E-4, 3.7613599561E-3, 3.1846798956E-3, -3.5306399222E-3, -1.2604200281E-2, -1.8847439438E-2, -1.7508180812E-2, -\n                1.6485679895E-2, -1.7508180812E-2, -1.8847439438E-2, -1.2604200281E-2, -3.5306399222E-3, 3.1846798956E-3, 3.7613599561E-3, -1.3688000035E-4, -8.0063997302E-4],\n            [-1.2434000382E-4, 5.6215998484E-4, 2.1605400834E-3, 3.1757799443E-3, 3.1846798956E-3, -1.7774800071E-3, -7.4316998944E-3, -9.0569201857E-3, -\n                9.6372198313E-3, -9.0569201857E-3, -7.4316998944E-3, -1.7774800071E-3, 3.1846798956E-3, 3.1757799443E-3, 2.1605400834E-3, 5.6215998484E-4, -1.2434000382E-4],\n            [-6.7714002216E-4, -5.8146001538E-4, 1.4607800404E-3, 2.1605400834E-3, 3.7613599561E-3, 3.0809799209E-3, 4.1121998802E-3, 2.2212199401E-3, 5.5381999118E-4,\n                2.2212199401E-3, 4.1121998802E-3, 3.0809799209E-3, 3.7613599561E-3, 2.1605400834E-3, 1.4607800404E-3, -5.8146001538E-4, -6.7714002216E-4],\n            [1.2078000145E-4, 4.4606000301E-4, -5.8146001538E-4, 5.6215998484E-4, -1.3688000035E-4, 2.3255399428E-3, 2.8898599558E-3, 4.2872801423E-3, 5.5893999524E-3,\n                4.2872801423E-3, 
2.8898599558E-3, 2.3255399428E-3, -1.3688000035E-4, 5.6215998484E-4, -5.8146001538E-4, 4.4606000301E-4, 1.2078000145E-4],\n            [-4.3500000174E-5, 1.2078000145E-4, -6.7714002216E-4, -1.2434000382E-4, -8.0063997302E-4, -1.5970399836E-3, -2.5168000138E-4, -4.2019999819E-4, 1.2619999470E-3, -4.2019999819E-4, -2.5168000138E-4, -1.5970399836E-3, -8.0063997302E-4, -1.2434000382E-4, -6.7714002216E-4, 1.2078000145E-4, -4.3500000174E-5]]\n        ).reshape(1, 1, 17, 17)\n        filters[\"l0\"] = torch.tensor([\n            [-8.7009997515E-5, -1.3542800443E-3, -1.6012600390E-3, -5.0337001448E-4,\n                2.5240099058E-3, -5.0337001448E-4, -1.6012600390E-3, -1.3542800443E-3, -8.7009997515E-5],\n            [-1.3542800443E-3, 2.9215801042E-3, 7.5227199122E-3, 8.2244202495E-3, 1.1076199589E-3,\n                8.2244202495E-3, 7.5227199122E-3, 2.9215801042E-3, -1.3542800443E-3],\n            [-1.6012600390E-3, 7.5227199122E-3, -7.0612900890E-3, -3.7694871426E-2, -\n                3.2971370965E-2, -3.7694871426E-2, -7.0612900890E-3, 7.5227199122E-3, -1.6012600390E-3],\n            [-5.0337001448E-4, 8.2244202495E-3, -3.7694871426E-2, 4.3813198805E-2, 0.1811603010,\n                4.3813198805E-2, -3.7694871426E-2, 8.2244202495E-3, -5.0337001448E-4],\n            [2.5240099058E-3, 1.1076199589E-3, -3.2971370965E-2, 0.1811603010, 0.4376249909,\n                0.1811603010, -3.2971370965E-2, 1.1076199589E-3, 2.5240099058E-3],\n            [-5.0337001448E-4, 8.2244202495E-3, -3.7694871426E-2, 4.3813198805E-2, 0.1811603010,\n                4.3813198805E-2, -3.7694871426E-2, 8.2244202495E-3, -5.0337001448E-4],\n            [-1.6012600390E-3, 7.5227199122E-3, -7.0612900890E-3, -3.7694871426E-2, -\n                3.2971370965E-2, -3.7694871426E-2, -7.0612900890E-3, 7.5227199122E-3, -1.6012600390E-3],\n            [-1.3542800443E-3, 2.9215801042E-3, 7.5227199122E-3, 8.2244202495E-3, 1.1076199589E-3,\n                8.2244202495E-3, 7.5227199122E-3, 2.9215801042E-3, -1.3542800443E-3],\n            [-8.7009997515E-5, -1.3542800443E-3, -1.6012600390E-3, -5.0337001448E-4, 2.5240099058E-3, -5.0337001448E-4, -1.6012600390E-3, -1.3542800443E-3, -8.7009997515E-5]]\n        ).reshape(1, 1, 9, 9)\n        filters[\"h0\"] = torch.tensor([\n            [-4.0483998600E-4, -6.2596000498E-4, -3.7829999201E-5, 8.8387000142E-4, 1.5450799838E-3, 1.9235999789E-3, 2.0687500946E-3, 2.0898699295E-3,\n                2.0687500946E-3, 1.9235999789E-3, 1.5450799838E-3, 8.8387000142E-4, -3.7829999201E-5, -6.2596000498E-4, -4.0483998600E-4],\n            [-6.2596000498E-4, -3.2734998967E-4, 7.7435001731E-4, 1.5874400269E-3, 2.1750701126E-3, 2.5626500137E-3, 2.2892199922E-3, 1.9755100366E-3,\n                2.2892199922E-3, 2.5626500137E-3, 2.1750701126E-3, 1.5874400269E-3, 7.7435001731E-4, -3.2734998967E-4, -6.2596000498E-4],\n            [-3.7829999201E-5, 7.7435001731E-4, 1.1793200392E-3, 1.4050999889E-3, 2.2253401112E-3, 2.1145299543E-3, 3.3578000148E-4, -\n                8.3368999185E-4, 3.3578000148E-4, 2.1145299543E-3, 2.2253401112E-3, 1.4050999889E-3, 1.1793200392E-3, 7.7435001731E-4, -3.7829999201E-5],\n            [8.8387000142E-4, 1.5874400269E-3, 1.4050999889E-3, 1.2960999738E-3, -4.9274001503E-4, -3.1295299996E-3, -4.5751798898E-3, -\n                5.1014497876E-3, -4.5751798898E-3, -3.1295299996E-3, -4.9274001503E-4, 1.2960999738E-3, 1.4050999889E-3, 1.5874400269E-3, 8.8387000142E-4],\n            [1.5450799838E-3, 2.1750701126E-3, 2.2253401112E-3, -4.9274001503E-4, -6.3222697936E-3, 
-2.7556000277E-3, 5.3632198833E-3, 7.3032598011E-3,\n                5.3632198833E-3, -2.7556000277E-3, -6.3222697936E-3, -4.9274001503E-4, 2.2253401112E-3, 2.1750701126E-3, 1.5450799838E-3],\n            [1.9235999789E-3, 2.5626500137E-3, 2.1145299543E-3, -3.1295299996E-3, -2.7556000277E-3, 1.3962360099E-2, 7.8046298586E-3, -\n                9.3812197447E-3, 7.8046298586E-3, 1.3962360099E-2, -2.7556000277E-3, -3.1295299996E-3, 2.1145299543E-3, 2.5626500137E-3, 1.9235999789E-3],\n            [2.0687500946E-3, 2.2892199922E-3, 3.3578000148E-4, -4.5751798898E-3, 5.3632198833E-3, 7.8046298586E-3, -7.9501636326E-2, -\n                0.1554141641, -7.9501636326E-2, 7.8046298586E-3, 5.3632198833E-3, -4.5751798898E-3, 3.3578000148E-4, 2.2892199922E-3, 2.0687500946E-3],\n            [2.0898699295E-3, 1.9755100366E-3, -8.3368999185E-4, -5.1014497876E-3, 7.3032598011E-3, -9.3812197447E-3, -0.1554141641,\n                0.7303866148, -0.1554141641, -9.3812197447E-3, 7.3032598011E-3, -5.1014497876E-3, -8.3368999185E-4, 1.9755100366E-3, 2.0898699295E-3],\n            [2.0687500946E-3, 2.2892199922E-3, 3.3578000148E-4, -4.5751798898E-3, 5.3632198833E-3, 7.8046298586E-3, -7.9501636326E-2, -\n                0.1554141641, -7.9501636326E-2, 7.8046298586E-3, 5.3632198833E-3, -4.5751798898E-3, 3.3578000148E-4, 2.2892199922E-3, 2.0687500946E-3],\n            [1.9235999789E-3, 2.5626500137E-3, 2.1145299543E-3, -3.1295299996E-3, -2.7556000277E-3, 1.3962360099E-2, 7.8046298586E-3, -\n                9.3812197447E-3, 7.8046298586E-3, 1.3962360099E-2, -2.7556000277E-3, -3.1295299996E-3, 2.1145299543E-3, 2.5626500137E-3, 1.9235999789E-3],\n            [1.5450799838E-3, 2.1750701126E-3, 2.2253401112E-3, -4.9274001503E-4, -6.3222697936E-3, -2.7556000277E-3, 5.3632198833E-3, 7.3032598011E-3,\n                5.3632198833E-3, -2.7556000277E-3, -6.3222697936E-3, -4.9274001503E-4, 2.2253401112E-3, 2.1750701126E-3, 1.5450799838E-3],\n            [8.8387000142E-4, 1.5874400269E-3, 1.4050999889E-3, 1.2960999738E-3, -4.9274001503E-4, -3.1295299996E-3, -4.5751798898E-3, -\n                5.1014497876E-3, -4.5751798898E-3, -3.1295299996E-3, -4.9274001503E-4, 1.2960999738E-3, 1.4050999889E-3, 1.5874400269E-3, 8.8387000142E-4],\n            [-3.7829999201E-5, 7.7435001731E-4, 1.1793200392E-3, 1.4050999889E-3, 2.2253401112E-3, 2.1145299543E-3, 3.3578000148E-4, -\n                8.3368999185E-4, 3.3578000148E-4, 2.1145299543E-3, 2.2253401112E-3, 1.4050999889E-3, 1.1793200392E-3, 7.7435001731E-4, -3.7829999201E-5],\n            [-6.2596000498E-4, -3.2734998967E-4, 7.7435001731E-4, 1.5874400269E-3, 2.1750701126E-3, 2.5626500137E-3, 2.2892199922E-3, 1.9755100366E-3,\n                2.2892199922E-3, 2.5626500137E-3, 2.1750701126E-3, 1.5874400269E-3, 7.7435001731E-4, -3.2734998967E-4, -6.2596000498E-4],\n            [-4.0483998600E-4, -6.2596000498E-4, -3.7829999201E-5, 8.8387000142E-4, 1.5450799838E-3, 1.9235999789E-3, 2.0687500946E-3, 2.0898699295E-3, 2.0687500946E-3, 1.9235999789E-3, 1.5450799838E-3, 8.8387000142E-4, -3.7829999201E-5, -6.2596000498E-4, -4.0483998600E-4]]\n        ).reshape(1, 1, 15, 15)\n        filters[\"b\"] = []\n        filters[\"b\"].append(torch.tensor(\n            [-8.1125000725E-4, 4.4451598078E-3, 1.2316980399E-2, 1.3955879956E-2,  1.4179450460E-2, 1.3955879956E-2, 1.2316980399E-2, 4.4451598078E-3, -8.1125000725E-4,\n             3.9103501476E-3, 4.4565401040E-3, -5.8724298142E-3, -2.8760801069E-3, 8.5267601535E-3, -\n             2.8760801069E-3, -5.8724298142E-3, 4.4565401040E-3, 3.9103501476E-3,\n 
            1.3462699717E-3, -3.7740699481E-3, 8.2581602037E-3, 3.9442278445E-2, 5.3605638444E-2, 3.9442278445E-2, 8.2581602037E-3, -\n             3.7740699481E-3, 1.3462699717E-3,\n             7.4700999539E-4, -3.6522001028E-4, -2.2522680461E-2, -0.1105690673, -\n             0.1768419296, -0.1105690673, -2.2522680461E-2, -3.6522001028E-4, 7.4700999539E-4,\n             0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000,\n             -7.4700999539E-4, 3.6522001028E-4, 2.2522680461E-2, 0.1105690673, 0.1768419296, 0.1105690673, 2.2522680461E-2, 3.6522001028E-4, -7.4700999539E-4,\n             -1.3462699717E-3, 3.7740699481E-3, -8.2581602037E-3, -3.9442278445E-2, -\n             5.3605638444E-2, -3.9442278445E-2, -\n             8.2581602037E-3, 3.7740699481E-3, -1.3462699717E-3,\n             -3.9103501476E-3, -4.4565401040E-3, 5.8724298142E-3, 2.8760801069E-3, -\n             8.5267601535E-3, 2.8760801069E-3, 5.8724298142E-3, -\n             4.4565401040E-3, -3.9103501476E-3,\n             8.1125000725E-4, -4.4451598078E-3, -1.2316980399E-2, -1.3955879956E-2, -1.4179450460E-2, -1.3955879956E-2, -1.2316980399E-2, -4.4451598078E-3, 8.1125000725E-4]\n        ).reshape(1, 1, 9, 9).permute(0, 1, 3, 2))\n        filters[\"b\"].append(torch.tensor(\n            [0.0000000000, -8.2846998703E-4, -5.7109999034E-5, 4.0110000555E-5, 4.6670897864E-3, 8.0871898681E-3, 1.4807609841E-2, 8.6204400286E-3, -3.1221499667E-3,\n             8.2846998703E-4, 0.0000000000, -9.7479997203E-4, -6.9718998857E-3, -\n             2.0865600090E-3, 2.3298799060E-3, -\n             4.4814897701E-3, 1.4917500317E-2, 8.6204400286E-3,\n             5.7109999034E-5, 9.7479997203E-4, 0.0000000000, -1.2145539746E-2, -\n             2.4427289143E-2, 5.0797060132E-2, 3.2785870135E-2, -\n             4.4814897701E-3, 1.4807609841E-2,\n             -4.0110000555E-5, 6.9718998857E-3, 1.2145539746E-2, 0.0000000000, -\n             0.1510555595, -8.2495503128E-2, 5.0797060132E-2, 2.3298799060E-3, 8.0871898681E-3,\n             -4.6670897864E-3, 2.0865600090E-3, 2.4427289143E-2, 0.1510555595, 0.0000000000, -\n             0.1510555595, -2.4427289143E-2, -2.0865600090E-3, 4.6670897864E-3,\n             -8.0871898681E-3, -2.3298799060E-3, -5.0797060132E-2, 8.2495503128E-2, 0.1510555595, 0.0000000000, -\n             1.2145539746E-2, -6.9718998857E-3, 4.0110000555E-5,\n             -1.4807609841E-2, 4.4814897701E-3, -3.2785870135E-2, -\n             5.0797060132E-2, 2.4427289143E-2, 1.2145539746E-2, 0.0000000000, -\n             9.7479997203E-4, -5.7109999034E-5,\n             -8.6204400286E-3, -1.4917500317E-2, 4.4814897701E-3, -\n             2.3298799060E-3, 2.0865600090E-3, 6.9718998857E-3, 9.7479997203E-4, 0.0000000000, -8.2846998703E-4,\n             3.1221499667E-3, -8.6204400286E-3, -1.4807609841E-2, -8.0871898681E-3, -4.6670897864E-3, -4.0110000555E-5, 5.7109999034E-5, 8.2846998703E-4, 0.0000000000]\n        ).reshape(1, 1, 9, 9).permute(0, 1, 3, 2))\n        filters[\"b\"].append(torch.tensor(\n            [8.1125000725E-4, -3.9103501476E-3, -1.3462699717E-3, -7.4700999539E-4, 0.0000000000, 7.4700999539E-4, 1.3462699717E-3, 3.9103501476E-3, -8.1125000725E-4,\n             -4.4451598078E-3, -4.4565401040E-3, 3.7740699481E-3, 3.6522001028E-4, 0.0000000000, -\n             3.6522001028E-4, -3.7740699481E-3, 4.4565401040E-3, 4.4451598078E-3,\n             -1.2316980399E-2, 5.8724298142E-3, -8.2581602037E-3, 2.2522680461E-2, 0.0000000000, -\n             
2.2522680461E-2, 8.2581602037E-3, -5.8724298142E-3, 1.2316980399E-2,\n             -1.3955879956E-2, 2.8760801069E-3, -3.9442278445E-2, 0.1105690673, 0.0000000000, -\n             0.1105690673, 3.9442278445E-2, -2.8760801069E-3, 1.3955879956E-2,\n             -1.4179450460E-2, -8.5267601535E-3, -5.3605638444E-2, 0.1768419296, 0.0000000000, -\n             0.1768419296, 5.3605638444E-2, 8.5267601535E-3, 1.4179450460E-2,\n             -1.3955879956E-2, 2.8760801069E-3, -3.9442278445E-2, 0.1105690673, 0.0000000000, -\n             0.1105690673, 3.9442278445E-2, -2.8760801069E-3, 1.3955879956E-2,\n             -1.2316980399E-2, 5.8724298142E-3, -8.2581602037E-3, 2.2522680461E-2, 0.0000000000, -\n             2.2522680461E-2, 8.2581602037E-3, -5.8724298142E-3, 1.2316980399E-2,\n             -4.4451598078E-3, -4.4565401040E-3, 3.7740699481E-3, 3.6522001028E-4, 0.0000000000, -\n             3.6522001028E-4, -3.7740699481E-3, 4.4565401040E-3, 4.4451598078E-3,\n             8.1125000725E-4, -3.9103501476E-3, -1.3462699717E-3, -7.4700999539E-4, 0.0000000000, 7.4700999539E-4, 1.3462699717E-3, 3.9103501476E-3, -8.1125000725E-4]\n        ).reshape(1, 1, 9, 9).permute(0, 1, 3, 2))\n        filters[\"b\"].append(torch.tensor(\n            [3.1221499667E-3, -8.6204400286E-3, -1.4807609841E-2, -8.0871898681E-3, -4.6670897864E-3, -4.0110000555E-5, 5.7109999034E-5, 8.2846998703E-4, 0.0000000000,\n             -8.6204400286E-3, -1.4917500317E-2, 4.4814897701E-3, -\n             2.3298799060E-3, 2.0865600090E-3, 6.9718998857E-3, 9.7479997203E-4, -\n             0.0000000000, -8.2846998703E-4,\n             -1.4807609841E-2, 4.4814897701E-3, -3.2785870135E-2, -\n             5.0797060132E-2, 2.4427289143E-2, 1.2145539746E-2, 0.0000000000, -\n             9.7479997203E-4, -5.7109999034E-5,\n             -8.0871898681E-3, -2.3298799060E-3, -5.0797060132E-2, 8.2495503128E-2, 0.1510555595, -\n             0.0000000000, -1.2145539746E-2, -6.9718998857E-3, 4.0110000555E-5,\n             -4.6670897864E-3, 2.0865600090E-3, 2.4427289143E-2, 0.1510555595, 0.0000000000, -\n             0.1510555595, -2.4427289143E-2, -2.0865600090E-3, 4.6670897864E-3,\n             -4.0110000555E-5, 6.9718998857E-3, 1.2145539746E-2, 0.0000000000, -\n             0.1510555595, -8.2495503128E-2, 5.0797060132E-2, 2.3298799060E-3, 8.0871898681E-3,\n             5.7109999034E-5, 9.7479997203E-4, -0.0000000000, -1.2145539746E-2, -\n             2.4427289143E-2, 5.0797060132E-2, 3.2785870135E-2, -\n             4.4814897701E-3, 1.4807609841E-2,\n             8.2846998703E-4, -0.0000000000, -9.7479997203E-4, -6.9718998857E-3, -\n             2.0865600090E-3, 2.3298799060E-3, -\n             4.4814897701E-3, 1.4917500317E-2, 8.6204400286E-3,\n             0.0000000000, -8.2846998703E-4, -5.7109999034E-5, 4.0110000555E-5, 4.6670897864E-3, 8.0871898681E-3, 1.4807609841E-2, 8.6204400286E-3, -3.1221499667E-3]\n        ).reshape(1, 1, 9, 9).permute(0, 1, 3, 2))\n\n    elif n_orientations == 6:\n        filters[\"l\"] = 2 * torch.tensor([\n            [0.00085404, -0.00244917, -0.00387812, -0.00944432, -\n                0.00962054, -0.00944432, -0.00387812, -0.00244917, 0.00085404],\n            [-0.00244917, -0.00523281, -0.00661117, 0.00410600, 0.01002988,\n                0.00410600, -0.00661117, -0.00523281, -0.00244917],\n            [-0.00387812, -0.00661117, 0.01396746, 0.03277038, 0.03981393,\n                0.03277038, 0.01396746, -0.00661117, -0.00387812],\n            [-0.00944432, 0.00410600, 0.03277038, 0.06426333, 0.08169618,\n           
     0.06426333, 0.03277038, 0.00410600, -0.00944432],\n            [-0.00962054, 0.01002988, 0.03981393, 0.08169618, 0.10096540,\n                0.08169618, 0.03981393, 0.01002988, -0.00962054],\n            [-0.00944432, 0.00410600, 0.03277038, 0.06426333, 0.08169618,\n                0.06426333, 0.03277038, 0.00410600, -0.00944432],\n            [-0.00387812, -0.00661117, 0.01396746, 0.03277038, 0.03981393,\n                0.03277038, 0.01396746, -0.00661117, -0.00387812],\n            [-0.00244917, -0.00523281, -0.00661117, 0.00410600, 0.01002988,\n                0.00410600, -0.00661117, -0.00523281, -0.00244917],\n            [0.00085404, -0.00244917, -0.00387812, -0.00944432, -0.00962054, -0.00944432, -0.00387812, -0.00244917, 0.00085404]]\n        ).reshape(1, 1, 9, 9)\n        filters[\"l0\"] = torch.tensor([\n            [0.00341614, -0.01551246, -0.03848215, -0.01551246, 0.00341614],\n            [-0.01551246, 0.05586982, 0.15925570, 0.05586982, -0.01551246],\n            [-0.03848215, 0.15925570, 0.40304148, 0.15925570, -0.03848215],\n            [-0.01551246, 0.05586982, 0.15925570, 0.05586982, -0.01551246],\n            [0.00341614, -0.01551246, -0.03848215, -0.01551246, 0.00341614]]\n        ).reshape(1, 1, 5, 5)\n        filters[\"h0\"] = torch.tensor([\n            [-0.00033429, -0.00113093, -0.00171484, -0.00133542, -\n                0.00080639, -0.00133542, -0.00171484, -0.00113093, -0.00033429],\n            [-0.00113093, -0.00350017, -0.00243812, 0.00631653, 0.01261227,\n                0.00631653, -0.00243812, -0.00350017, -0.00113093],\n            [-0.00171484, -0.00243812, -0.00290081, -0.00673482, -\n                0.00981051, -0.00673482, -0.00290081, -0.00243812, -0.00171484],\n            [-0.00133542, 0.00631653, -0.00673482, -0.07027679, -\n                0.11435863, -0.07027679, -0.00673482, 0.00631653, -0.00133542],\n            [-0.00080639, 0.01261227, -0.00981051, -0.11435863,\n                0.81380200, -0.11435863, -0.00981051, 0.01261227, -0.00080639],\n            [-0.00133542, 0.00631653, -0.00673482, -0.07027679, -\n                0.11435863, -0.07027679, -0.00673482, 0.00631653, -0.00133542],\n            [-0.00171484, -0.00243812, -0.00290081, -0.00673482, -\n                0.00981051, -0.00673482, -0.00290081, -0.00243812, -0.00171484],\n            [-0.00113093, -0.00350017, -0.00243812, 0.00631653, 0.01261227,\n                0.00631653, -0.00243812, -0.00350017, -0.00113093],\n            [-0.00033429, -0.00113093, -0.00171484, -0.00133542, -0.00080639, -0.00133542, -0.00171484, -0.00113093, -0.00033429]]\n        ).reshape(1, 1, 9, 9)\n        filters[\"b\"] = []\n        filters[\"b\"].append(torch.tensor([\n            0.00277643, 0.00496194, 0.01026699, 0.01455399, 0.01026699, 0.00496194, 0.00277643,\n            -0.00986904, -0.00893064, 0.01189859, 0.02755155, 0.01189859, -0.00893064, -0.00986904,\n            -0.01021852, -0.03075356, -0.08226445, -\n            0.11732297, -0.08226445, -0.03075356, -0.01021852,\n            0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,\n            0.01021852, 0.03075356, 0.08226445, 0.11732297, 0.08226445, 0.03075356, 0.01021852,\n            0.00986904, 0.00893064, -0.01189859, -\n            0.02755155, -0.01189859, 0.00893064, 0.00986904,\n            -0.00277643, -0.00496194, -0.01026699, -0.01455399, -0.01026699, -0.00496194, -0.00277643]\n        ).reshape(1, 1, 7, 7).permute(0, 1, 3, 2))\n        filters[\"b\"].append(torch.tensor([\n            
-0.00343249, -0.00640815, -0.00073141, 0.01124321, 0.00182078, 0.00285723, 0.01166982,\n            -0.00358461, -0.01977507, -0.04084211, -\n            0.00228219, 0.03930573, 0.01161195, 0.00128000,\n            0.01047717, 0.01486305, -0.04819057, -\n            0.12227230, -0.05394139, 0.00853965, -0.00459034,\n            0.00790407, 0.04435647, 0.09454202, -0.00000000, -\n            0.09454202, -0.04435647, -0.00790407,\n            0.00459034, -0.00853965, 0.05394139, 0.12227230, 0.04819057, -0.01486305, -0.01047717,\n            -0.00128000, -0.01161195, -0.03930573, 0.00228219, 0.04084211, 0.01977507, 0.00358461,\n            -0.01166982, -0.00285723, -0.00182078, -0.01124321, 0.00073141, 0.00640815, 0.00343249]\n        ).reshape(1, 1, 7, 7).permute(0, 1, 3, 2))\n        filters[\"b\"].append(torch.tensor([\n            0.00343249, 0.00358461, -0.01047717, -\n            0.00790407, -0.00459034, 0.00128000, 0.01166982,\n            0.00640815, 0.01977507, -0.01486305, -\n            0.04435647, 0.00853965, 0.01161195, 0.00285723,\n            0.00073141, 0.04084211, 0.04819057, -\n            0.09454202, -0.05394139, 0.03930573, 0.00182078,\n            -0.01124321, 0.00228219, 0.12227230, -\n            0.00000000, -0.12227230, -0.00228219, 0.01124321,\n            -0.00182078, -0.03930573, 0.05394139, 0.09454202, -\n            0.04819057, -0.04084211, -0.00073141,\n            -0.00285723, -0.01161195, -0.00853965, 0.04435647, 0.01486305, -0.01977507, -0.00640815,\n            -0.01166982, -0.00128000, 0.00459034, 0.00790407, 0.01047717, -0.00358461, -0.00343249]\n        ).reshape(1, 1, 7, 7).permute(0, 1, 3, 2))\n        filters[\"b\"].append(torch.tensor(\n            [-0.00277643, 0.00986904, 0.01021852, -0.00000000, -0.01021852, -0.00986904, 0.00277643,\n             -0.00496194, 0.00893064, 0.03075356, -\n             0.00000000, -0.03075356, -0.00893064, 0.00496194,\n             -0.01026699, -0.01189859, 0.08226445, -\n             0.00000000, -0.08226445, 0.01189859, 0.01026699,\n             -0.01455399, -0.02755155, 0.11732297, -\n             0.00000000, -0.11732297, 0.02755155, 0.01455399,\n             -0.01026699, -0.01189859, 0.08226445, -\n             0.00000000, -0.08226445, 0.01189859, 0.01026699,\n             -0.00496194, 0.00893064, 0.03075356, -\n             0.00000000, -0.03075356, -0.00893064, 0.00496194,\n             -0.00277643, 0.00986904, 0.01021852, -0.00000000, -0.01021852, -0.00986904, 0.00277643]\n        ).reshape(1, 1, 7, 7).permute(0, 1, 3, 2))\n        filters[\"b\"].append(torch.tensor([\n            -0.01166982, -0.00128000, 0.00459034, 0.00790407, 0.01047717, -0.00358461, -0.00343249,\n            -0.00285723, -0.01161195, -0.00853965, 0.04435647, 0.01486305, -0.01977507, -0.00640815,\n            -0.00182078, -0.03930573, 0.05394139, 0.09454202, -\n            0.04819057, -0.04084211, -0.00073141,\n            -0.01124321, 0.00228219, 0.12227230, -\n            0.00000000, -0.12227230, -0.00228219, 0.01124321,\n            0.00073141, 0.04084211, 0.04819057, -\n            0.09454202, -0.05394139, 0.03930573, 0.00182078,\n            0.00640815, 0.01977507, -0.01486305, -\n            0.04435647, 0.00853965, 0.01161195, 0.00285723,\n            0.00343249, 0.00358461, -0.01047717, -0.00790407, -0.00459034, 0.00128000, 0.01166982]\n        ).reshape(1, 1, 7, 7).permute(0, 1, 3, 2))\n        filters[\"b\"].append(torch.tensor([\n            -0.01166982, -0.00285723, -0.00182078, -\n            0.01124321, 0.00073141, 0.00640815, 
0.00343249,\n            -0.00128000, -0.01161195, -0.03930573, 0.00228219, 0.04084211, 0.01977507, 0.00358461,\n            0.00459034, -0.00853965, 0.05394139, 0.12227230, 0.04819057, -0.01486305, -0.01047717,\n            0.00790407, 0.04435647, 0.09454202, -0.00000000, -\n            0.09454202, -0.04435647, -0.00790407,\n            0.01047717, 0.01486305, -0.04819057, -\n            0.12227230, -0.05394139, 0.00853965, -0.00459034,\n            -0.00358461, -0.01977507, -0.04084211, -\n            0.00228219, 0.03930573, 0.01161195, 0.00128000,\n            -0.00343249, -0.00640815, -0.00073141, 0.01124321, 0.00182078, 0.00285723, 0.01166982]\n        ).reshape(1, 1, 7, 7).permute(0, 1, 3, 2))\n\n    else:\n        raise Exception(\n            \"Steerable filters not implemented for %d orientations\" % n_orientations)\n\n    if filter_type == \"trained\":\n        if size == 5:\n            # TODO maybe also train h0 and l0 filters\n            filters = crop_steerable_pyramid_filters(filters, 5)\n            filters[\"b\"][0] = torch.tensor([\n                [-0.0356752239, -0.0223877281, -0.0009542659,\n                    0.0244821459, 0.0322226137],\n                [-0.0593218654,  0.1245803162, -\n                    0.0023863907, -0.1230178699, 0.0589442067],\n                [-0.0281576272,  0.2976626456, -\n                    0.0020888755, -0.2953369915, 0.0284542721],\n                [-0.0586092323,  0.1251581162, -\n                    0.0024624448, -0.1227868199, 0.0587830991],\n                [-0.0327464789, -0.0223652460, -\n                    0.0042342511,  0.0245472137, 0.0359398536]\n            ]).reshape(1, 1, 5, 5)\n            filters[\"b\"][1] = torch.tensor([\n                [3.9758663625e-02,  6.0679119080e-02,  3.0146904290e-02,\n                    6.1198268086e-02,  3.6218870431e-02],\n                [2.3255519569e-02, -1.2505133450e-01, -\n                    2.9738345742e-01, -1.2518258393e-01,  2.3592948914e-02],\n                [-1.3602430699e-03, -1.2058277935e-04,  2.6399988565e-04, -\n                    2.3791544663e-04,  1.8450465286e-03],\n                [-2.1563466638e-02,  1.2572696805e-01,  2.9745018482e-01,\n                    1.2458638102e-01, -2.3847281933e-02],\n                [-3.7941932678e-02, -6.1060950160e-02, -\n                    2.9489086941e-02, -6.0411967337e-02, -3.8459088653e-02]\n            ]).reshape(1, 1, 5, 5)\n\n            # Below filters were optimised on 09/02/2021\n            # 20K iterations with multiple images at more scales.\n            filters[\"b\"][0] = torch.tensor([\n                [-4.5508436859e-02, -2.1767273545e-02, -1.9399923622e-04,\n                    2.1200872958e-02,  4.5475799590e-02],\n                [-6.3554823399e-02,  1.2832683325e-01, -\n                    5.3858719184e-05, -1.2809979916e-01,  6.3842624426e-02],\n                [-3.4809380770e-02,  2.9954621196e-01,  2.9066693969e-05, -\n                    2.9957753420e-01,  3.4806568176e-02],\n                [-6.3934154809e-02,  1.2806062400e-01,  9.0917674243e-05, -\n                    1.2832444906e-01,  6.3572973013e-02],\n                [-4.5492250472e-02, -2.1125273779e-02,  4.2229349492e-04,\n                    2.1804777905e-02,  4.5236673206e-02]\n            ]).reshape(1, 1, 5, 5)\n            filters[\"b\"][1] = torch.tensor([\n                [4.8947390169e-02,  6.3575074077e-02,  3.4955859184e-02,\n                    6.4085893333e-02,  4.9838040024e-02],\n                [2.2061849013e-02, -1.2936264277e-01, 
-\n                    3.0093491077e-01, -1.2997294962e-01,  2.0597217605e-02],\n                [-5.1290717238e-05, -1.7305796064e-05,  2.0256420612e-05, -\n                    1.1864109547e-04,  7.3973249528e-05],\n                [-2.0749464631e-02,  1.2988376617e-01,  3.0080935359e-01,\n                    1.2921217084e-01, -2.2159902379e-02],\n                [-4.9614857882e-02, -6.4021714032e-02, -\n                    3.4676689655e-02, -6.3446544111e-02, -4.8282280564e-02]\n            ]).reshape(1, 1, 5, 5)\n\n            # Trained on 17/02/2021 to match fourier pyramid in spatial domain\n            filters[\"b\"][0] = torch.tensor([\n                [3.3370e-02,  9.3934e-02, -3.5810e-04, -9.4038e-02, -3.3115e-02],\n                [1.7716e-01,  3.9378e-01,  6.8461e-05, -3.9343e-01, -1.7685e-01],\n                [2.9213e-01,  6.1042e-01,  7.0654e-04, -6.0939e-01, -2.9177e-01],\n                [1.7684e-01,  3.9392e-01,  1.0517e-03, -3.9268e-01, -1.7668e-01],\n                [3.3000e-02,  9.4029e-02,  7.3565e-04, -9.3366e-02, -3.3008e-02]\n            ]).reshape(1, 1, 5, 5) * 0.1\n\n            filters[\"b\"][1] = torch.tensor([\n                [0.0331,  0.1763,  0.2907,  0.1753,  0.0325],\n                [0.0941,  0.3932,  0.6079,  0.3904,  0.0922],\n                [0.0008,  0.0009, -0.0010, -0.0025, -0.0015],\n                [-0.0929, -0.3919, -0.6097, -0.3944, -0.0946],\n                [-0.0328, -0.1760, -0.2915, -0.1768, -0.0333]\n            ]).reshape(1, 1, 5, 5) * 0.1\n\n        else:\n            raise Exception(\n                \"Trained filters not implemented for size %d\" % size)\n\n    if filter_type == \"cropped\":\n        filters = crop_steerable_pyramid_filters(filters, size)\n\n    return filters\n
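A minimal usage sketch (not part of the documented API above): it assumes odak is installed and that get_steerable_pyramid_filters can be imported from odak.learn.perception, as the documentation anchor suggests; the random test image and the choice of the 'full' filters are purely illustrative.

import torch
from odak.learn.perception import get_steerable_pyramid_filters

# 'full' returns the original NYU filters; size only matters for 'cropped' and 'trained'.
filters = get_steerable_pyramid_filters(size = 9, n_orientations = 2, filter_type = 'full')

image = torch.rand(1, 1, 64, 64)  # single-channel test image, [batch x channel x height x width]

def convolve(kernel, field):
    # Zero-pad so the filtered output keeps the input resolution.
    return torch.nn.functional.conv2d(field, kernel, padding = kernel.shape[-1] // 2)

highpass_residual = convolve(filters['h0'], image)           # "h0" highpass residual
oriented_bands = [convolve(b, image) for b in filters['b']]  # one band per orientation
lowpass = convolve(filters['l'], image)                      # "l" lowpass for the next level

In a full pyramid the lowpass output would then be downsampled by two and filtered again to form the next, coarser level.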
"},{"location":"odak/learn_perception/#odak.learn.perception.hsv_to_rgb","title":"hsv_to_rgb(image)","text":"

Definition to convert HSV color space to RGB color space. Mostly inspired by https://kornia.readthedocs.io/en/latest/_modules/kornia/color/hsv.html

Parameters:

  • image –
              Input image in HSV color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.

Returns:

  • image_rgb (tensor) –

    Output image in RGB color space [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def hsv_to_rgb(image):\n\n    \"\"\"\n    Definition to convert HSV space to  RGB color space. Mostly inspired from : https://kornia.readthedocs.io/en/latest/_modules/kornia/color/hsv.html\n\n    Parameters\n    ----------\n    image           : torch.tensor\n                      Input image in HSV color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.\n\n    Returns\n    -------\n    image_rgb       : torch.tensor\n                      Output image in  RGB  color space [k x 3 x m x n] or [1 x 3 x m x n].\n    \"\"\"\n    if len(image.shape) == 3:\n        image = image.unsqueeze(0)\n    h = image[..., 0, :, :] / (2 * math.pi)\n    s = image[..., 1, :, :]\n    v = image[..., 2, :, :]\n    hi = torch.floor(h * 6) % 6\n    f = ((h * 6) % 6) - hi\n    one = torch.tensor(1.0)\n    p = v * (one - s)\n    q = v * (one - f * s)\n    t = v * (one - (one - f) * s)\n    hi = hi.long()\n    indices = torch.stack([hi, hi + 6, hi + 12], dim=-3)\n    image_rgb = torch.stack((v, q, p, p, t, v, t, v, v, q, p, p, p, p, t, v, v, q), dim=-3)\n    image_rgb = torch.gather(image_rgb, -3, indices)\n    return image_rgb\n
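A minimal usage sketch, assuming hsv_to_rgb is importable from odak.learn.perception as the anchor suggests; the constant hue field is illustrative. Note that the source above divides the hue channel by 2π, so hue is expected in radians.

import math
import torch
from odak.learn.perception import hsv_to_rgb

hue = torch.full((64, 64), math.pi / 3)  # 60 degrees, expressed in radians
saturation = torch.ones(64, 64)
value = torch.ones(64, 64)
hsv_image = torch.stack((hue, saturation, value), dim = 0)  # [3 x m x n]
rgb_image = hsv_to_rgb(hsv_image)                           # [1 x 3 x m x n], here a uniform yellow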
"},{"location":"odak/learn_perception/#odak.learn.perception.lab_to_srgb","title":"lab_to_srgb(image)","text":"

Definition to convert LAB color space to sRGB color space.

Parameters:

  • image –
              Input image in LAB color space [3 x m x n].

Returns:

  • image_srgb (tensor) –

    Output image in sRGB color space [3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def lab_to_srgb(image):\n    \"\"\"\n    Definition to convert LAB space to SRGB color space. \n\n    Parameters\n    ----------\n    image           : torch.tensor\n                      Input image in LAB color space[3 x m x n]\n    Returns\n    -------\n    image_srgb     : torch.tensor\n                      Output image in SRGB color space [3 x m x n].\n    \"\"\"\n\n    if image.shape[-1] == 3:\n        input_color = image.permute(2, 0, 1)  # C(H*W)\n    else:\n        input_color = image\n    # lab ---> xyz\n    reference_illuminant = torch.tensor([[[0.950428545]], [[1.000000000]], [[1.088900371]]], dtype=torch.float32)\n    y = (input_color[0:1, :, :] + 16) / 116\n    a =  input_color[1:2, :, :] / 500\n    b =  input_color[2:3, :, :] / 200\n    x = y + a\n    z = y - b\n    xyz = torch.cat((x, y, z), 0)\n    delta = 6 / 29\n    factor = 3 * delta * delta\n    xyz = torch.where(xyz > delta,  xyz ** 3, factor * (xyz - 4 / 29))\n    xyz_color = xyz * reference_illuminant\n    # xyz ---> linear rgb\n    a11 = 3.241003275\n    a12 = -1.537398934\n    a13 = -0.498615861\n    a21 = -0.969224334\n    a22 = 1.875930071\n    a23 = 0.041554224\n    a31 = 0.055639423\n    a32 = -0.204011202\n    a33 = 1.057148933\n    A = torch.tensor([[a11, a12, a13],\n                  [a21, a22, a23],\n                  [a31, a32, a33]], dtype=torch.float32)\n\n    xyz_color = xyz_color.permute(2, 0, 1) # C(H*W)\n    linear_rgb_color = torch.matmul(A, xyz_color)\n    linear_rgb_color = linear_rgb_color.permute(1, 2, 0)\n    # linear rgb ---> srgb\n    limit = 0.0031308\n    image_srgb = torch.where(linear_rgb_color > limit, 1.055 * (linear_rgb_color ** (1.0 / 2.4)) - 0.055, 12.92 * linear_rgb_color)\n    return image_srgb\n
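A minimal usage sketch with an illustrative, constant CIELAB image; the import path follows the documentation anchor and is an assumption.

import torch
from odak.learn.perception import lab_to_srgb

lab_image = torch.zeros(3, 32, 32)   # [3 x m x n] CIELAB image
lab_image[0] = 50.0                  # lightness L*, typically in [0, 100]
lab_image[1] = 20.0                  # a* channel
lab_image[2] = -30.0                 # b* channel
srgb_image = lab_to_srgb(lab_image)  # [3 x m x n] image in sRGB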
"},{"location":"odak/learn_perception/#odak.learn.perception.linear_rgb_to_rgb","title":"linear_rgb_to_rgb(image, threshold=0.0031308)","text":"

Definition to convert linear RGB images to RGB color space. Mostly inspired by https://kornia.readthedocs.io/en/latest/_modules/kornia/color/rgb.html

Parameters:

  • image –
              Input image in linear RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.
  • threshold –
              Threshold used in calculations.

Returns:

  • image_linear (tensor) –

    Output image in RGB color space [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def linear_rgb_to_rgb(image, threshold = 0.0031308):\n    \"\"\"\n    Definition to convert linear RGB images to RGB color space. Mostly inspired from: https://kornia.readthedocs.io/en/latest/_modules/kornia/color/rgb.html\n\n    Parameters\n    ----------\n    image           : torch.tensor\n                      Input image in linear RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.\n    threshold       : float\n                      Threshold used in calculations.\n\n    Returns\n    -------\n    image_linear    : torch.tensor\n                      Output image in RGB color space [k x 3 x m x n] or [1 x 3 x m x n].\n    \"\"\"\n    if len(image.shape) == 3:\n        image = image.unsqueeze(0)\n    image_linear =  torch.where(image > threshold, 1.055 * torch.pow(image.clamp(min=threshold), 1 / 2.4) - 0.055, 12.92 * image)\n    return image_linear\n
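A minimal usage sketch; the random linear RGB image is illustrative and the import path follows the documentation anchor.

import torch
from odak.learn.perception import linear_rgb_to_rgb

linear_image = torch.rand(3, 128, 128)        # linear RGB values in [0, 1], [3 x m x n]
srgb_image = linear_rgb_to_rgb(linear_image)  # gamma-encoded output, [1 x 3 x m x n]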
"},{"location":"odak/learn_perception/#odak.learn.perception.linear_rgb_to_xyz","title":"linear_rgb_to_xyz(image)","text":"

Definition to convert linear RGB color space to CIE XYZ color space. Mostly inspired by Rochester IT Color Conversion Algorithms (https://www.cs.rit.edu/~ncs/color/)

Parameters:

  • image –
              Input image in linear RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.

Returns:

  • image_xyz (tensor) –

    Output image in XYZ (CIE 1931) color space [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def linear_rgb_to_xyz(image):\n    \"\"\"\n    Definition to convert RGB space to CIE XYZ color space. Mostly inspired from : Rochester IT Color Conversion Algorithms (https://www.cs.rit.edu/~ncs/color/)\n\n    Parameters\n    ----------\n    image           : torch.tensor\n                      Input image in linear RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.\n\n    Returns\n    -------\n    image_xyz       : torch.tensor\n                      Output image in XYZ (CIE 1931) color space [k x 3 x m x n] or [1 x 3 x m x n].\n    \"\"\"\n    if len(image.shape) == 3:\n        image = image.unsqueeze(0)\n    a11 = 0.412453\n    a12 = 0.357580\n    a13 = 0.180423\n    a21 = 0.212671\n    a22 = 0.715160\n    a23 = 0.072169\n    a31 = 0.019334\n    a32 = 0.119193\n    a33 = 0.950227\n    M = torch.tensor([[a11, a12, a13], \n                      [a21, a22, a23],\n                      [a31, a32, a33]])\n    size = image.size()\n    image = image.reshape(size[0], size[1], size[2]*size[3])  # NC(HW)\n    image_xyz = torch.matmul(M, image)\n    image_xyz = image_xyz.reshape(size[0], size[1], size[2], size[3])\n    return image_xyz\n
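A minimal usage sketch with a random, illustrative input; the import path follows the documentation anchor.

import torch
from odak.learn.perception import linear_rgb_to_xyz

linear_image = torch.rand(1, 3, 64, 64)       # [k x 3 x m x n], values in [0, 1]
image_xyz = linear_rgb_to_xyz(linear_image)   # CIE 1931 XYZ image, same shape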
"},{"location":"odak/learn_perception/#odak.learn.perception.make_3d_location_map","title":"make_3d_location_map(image_pixel_size, real_image_width=0.3, real_viewing_distance=0.6)","text":"

Makes a map of the real 3D location that each pixel in an image corresponds to, when displayed to a user on a flat screen. Assumes the viewpoint is located at the centre of the image, and the screen is perpendicular to the viewing direction.

Parameters:

  • image_pixel_size –
                        The size of the image in pixels, as a tuple of form (height, width).
  • real_image_width –
                        The real width of the image as displayed. Units are not important, as long as they are the same as those used for real_viewing_distance.
  • real_viewing_distance –
                        The real distance from the user's viewpoint to the screen.

Returns:

  • map (tensor) –

    The computed 3D location map, of size 3xWxH.

Source code in odak/learn/perception/foveation.py
def make_3d_location_map(image_pixel_size, real_image_width=0.3, real_viewing_distance=0.6):\n    \"\"\" \n    Makes a map of the real 3D location that each pixel in an image corresponds to, when displayed to\n    a user on a flat screen. Assumes the viewpoint is located at the centre of the image, and the screen is \n    perpendicular to the viewing direction.\n\n    Parameters\n    ----------\n\n    image_pixel_size        : tuple of ints \n                                The size of the image in pixels, as a tuple of form (height, width)\n    real_image_width        : float\n                                The real width of the image as displayed. Units not important, as long as they\n                                are the same as those used for real_viewing_distance\n    real_viewing_distance   : float \n                                The real distance from the user's viewpoint to the screen.\n\n    Returns\n    -------\n\n    map                     : torch.tensor\n                                The computed 3D location map, of size 3xWxH.\n    \"\"\"\n    real_image_height = (real_image_width /\n                         image_pixel_size[-1]) * image_pixel_size[-2]\n    x_coords = torch.linspace(-0.5, 0.5, image_pixel_size[-1])*real_image_width\n    x_coords = x_coords[None, None, :].repeat(1, image_pixel_size[-2], 1)\n    y_coords = torch.linspace(-0.5, 0.5,\n                              image_pixel_size[-2])*real_image_height\n    y_coords = y_coords[None, :, None].repeat(1, 1, image_pixel_size[-1])\n    z_coords = torch.ones(\n        (1, image_pixel_size[-2], image_pixel_size[-1])) * real_viewing_distance\n\n    return torch.cat([x_coords, y_coords, z_coords], dim=0)\n
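A minimal usage sketch with illustrative display parameters (a 0.3 m wide screen viewed from 0.6 m, matching the documented defaults); the import path follows the documentation anchor.

from odak.learn.perception import make_3d_location_map

# Pixel locations for a 1920 x 1080 image shown on the screen described above.
location_map = make_3d_location_map(
    image_pixel_size = (1080, 1920),
    real_image_width = 0.3,
    real_viewing_distance = 0.6
)
# The three channels hold the x, y and z coordinates (here in metres) of every displayed pixel.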
"},{"location":"odak/learn_perception/#odak.learn.perception.make_eccentricity_distance_maps","title":"make_eccentricity_distance_maps(gaze_location, image_pixel_size, real_image_width=0.3, real_viewing_distance=0.6)","text":"

Makes a map of the eccentricity of each pixel in an image for a given fixation point, when displayed to a user on a flat screen. Assumes the viewpoint is located at the centre of the image, and the screen is perpendicular to the viewing direction. Output in radians.

Parameters:

  • gaze_location \u2013
                        User's gaze (fixation point) in the image. Should be given as a tuple with normalized\n                    image coordinates (ranging from 0 to 1)\n
  • image_pixel_size \u2013
                        The size of the image in pixels, as a tuple of form (height, width)\n
  • real_image_width \u2013
                        The real width of the image as displayed. Units not important, as long as they\n                    are the same as those used for real_viewing_distance\n
  • real_viewing_distance \u2013
                        The real distance from the user's viewpoint to the screen.\n

Returns:

  • eccentricity_map ( tensor ) \u2013

    The computed eccentricity map, of size HxW.

  • distance_map ( tensor ) \u2013

    The computed distance map, of size HxW.

Source code in odak/learn/perception/foveation.py
def make_eccentricity_distance_maps(gaze_location, image_pixel_size, real_image_width=0.3, real_viewing_distance=0.6):\n    \"\"\" \n    Makes a map of the eccentricity of each pixel in an image for a given fixation point, when displayed to\n    a user on a flat screen. Assumes the viewpoint is located at the centre of the image, and the screen is \n    perpendicular to the viewing direction. Output in radians.\n\n    Parameters\n    ----------\n\n    gaze_location           : tuple of floats\n                                User's gaze (fixation point) in the image. Should be given as a tuple with normalized\n                                image coordinates (ranging from 0 to 1)\n    image_pixel_size        : tuple of ints\n                                The size of the image in pixels, as a tuple of form (height, width)\n    real_image_width        : float\n                                The real width of the image as displayed. Units not important, as long as they\n                                are the same as those used for real_viewing_distance\n    real_viewing_distance   : float\n                                The real distance from the user's viewpoint to the screen.\n\n    Returns\n    -------\n\n    eccentricity_map        : torch.tensor\n                                The computed eccentricity map, of size WxH.\n    distance_map            : torch.tensor\n                                The computed distance map, of size WxH.\n    \"\"\"\n    real_image_height = (real_image_width /\n                         image_pixel_size[-1]) * image_pixel_size[-2]\n    location_map = make_3d_location_map(\n        image_pixel_size, real_image_width, real_viewing_distance)\n    distance_map = torch.sqrt(torch.sum(location_map*location_map, dim=0))\n    direction_map = location_map / distance_map\n\n    gaze_location_3d = torch.tensor([\n        (gaze_location[0]*2 - 1)*real_image_width*0.5,\n        (gaze_location[1]*2 - 1)*real_image_height*0.5,\n        real_viewing_distance])\n    gaze_dir = gaze_location_3d / \\\n        torch.sqrt(torch.sum(gaze_location_3d * gaze_location_3d))\n    gaze_dir = gaze_dir[:, None, None]\n\n    dot_prod_map = torch.sum(gaze_dir * direction_map, dim=0)\n    dot_prod_map = torch.clamp(dot_prod_map, min=-1.0, max=1.0)\n    eccentricity_map = torch.acos(dot_prod_map)\n\n    return eccentricity_map, distance_map\n
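A brief usage sketch (import path assumed from the anchor name); fixating the image centre gives near-zero eccentricity around the middle of the image:
import torch\nfrom odak.learn.perception import make_eccentricity_distance_maps  # import path assumed from the anchor name\n\neccentricity_map, distance_map = make_eccentricity_distance_maps(\n    gaze_location = (0.5, 0.5),  # fixation at the image centre, normalized coordinates\n    image_pixel_size = (600, 800)\n    )\nprint(eccentricity_map.shape, distance_map.shape)  # both torch.Size([600, 800]); eccentricity is in radians\n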
"},{"location":"odak/learn_perception/#odak.learn.perception.make_equi_pooling_size_map_lod","title":"make_equi_pooling_size_map_lod(gaze_angles, image_pixel_size, alpha=0.3, mode='quadratic')","text":"

This function is similar to make_equi_pooling_size_map_pixels, but instead returns a map of LOD levels to sample from to achieve the correct pooling region areas.

Parameters:

  • gaze_angles \u2013
                    Gaze direction expressed as angles, in radians.\n
  • image_pixel_size \u2013
                    Dimensions of the image in pixels, as a tuple of (height, width)\n
  • alpha \u2013
                    Parameter controlling extent of foveation\n
  • mode \u2013
                    Foveation mode (how pooling size varies with eccentricity). Should be \"quadratic\" or \"linear\"\n

Returns:

  • pooling_size_map ( tensor ) \u2013

    The computed pooling size map, of size HxW.

Source code in odak/learn/perception/foveation.py
def make_equi_pooling_size_map_lod(gaze_angles, image_pixel_size, alpha=0.3, mode=\"quadratic\"):\n    \"\"\" \n    This function is similar to make_equi_pooling_size_map_pixels, but instead returns a map of LOD levels to sample from\n    to achieve the correct pooling region areas.\n\n    Parameters\n    ----------\n\n    gaze_angles         : tuple of 2 floats\n                            Gaze direction expressed as angles, in radians.\n    image_pixel_size    : tuple of 2 ints\n                            Dimensions of the image in pixels, as a tuple of (height, width)\n    alpha               : float\n                            Parameter controlling extent of foveation\n    mode                : str\n                            Foveation mode (how pooling size varies with eccentricity). Should be \"quadratic\" or \"linear\"\n\n    Returns\n    -------\n\n    pooling_size_map        : torch.tensor\n                                The computed pooling size map, of size HxW.\n    \"\"\"\n    pooling_pixel = make_equi_pooling_size_map_pixels(gaze_angles, image_pixel_size, alpha, mode)\n    pooling_lod = torch.log2(1e-6+pooling_pixel)\n    pooling_lod[pooling_lod < 0] = 0\n    return pooling_lod\n
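A brief usage sketch (import path assumed from the anchor name); looking straight ahead in an equirectangular frame:
import torch\nfrom odak.learn.perception import make_equi_pooling_size_map_lod  # import path assumed from the anchor name\n\npooling_lod = make_equi_pooling_size_map_lod(\n    gaze_angles = (0.0, 0.0),  # (yaw, pitch) in radians, looking straight ahead\n    image_pixel_size = (256, 512)\n    )\nprint(pooling_lod.shape)  # torch.Size([256, 512]); per-pixel LOD/mip level, 0 near the gaze direction\n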
"},{"location":"odak/learn_perception/#odak.learn.perception.make_equi_pooling_size_map_pixels","title":"make_equi_pooling_size_map_pixels(gaze_angles, image_pixel_size, alpha=0.3, mode='quadratic')","text":"

This function makes a map of pooling sizes in pixels, similarly to make_pooling_size_map_pixels, but works on 360 equirectangular images. Input images are assumed to be in equirectangular form - i.e. if you consider a 3D viewing setup where y is the vertical axis, the x location in the image corresponds to rotation around the y axis (yaw), ranging from -pi to pi. The y location in the image corresponds to pitch, ranging from -pi/2 to pi/2.

In this setup real_image_width and real_viewing_distance have no effect.

Note that rather than a 2D image gaze location in [0,1]^2, the gaze should be specified as gaze angles in [-pi,pi]x[-pi/2,pi/2] (yaw, then pitch).

Parameters:

  • gaze_angles \u2013
                    Gaze direction expressed as angles, in radians.\n
  • image_pixel_size \u2013
                    Dimensions of the image in pixels, as a tuple of (height, width)\n
  • alpha \u2013
                    Parameter controlling extent of foveation\n
  • mode \u2013
                    Foveation mode (how pooling size varies with eccentricity). Should be \"quadratic\" or \"linear\"\n
Source code in odak/learn/perception/foveation.py
def make_equi_pooling_size_map_pixels(gaze_angles, image_pixel_size, alpha=0.3, mode=\"quadratic\"):\n    \"\"\"\n    This function makes a map of pooling sizes in pixels, similarly to make_pooling_size_map_pixels, but works on 360 equirectangular images.\n    Input images are assumed to be in equirectangular form - i.e. if you consider a 3D viewing setup where y is the vertical axis, \n    the x location in the image corresponds to rotation around the y axis (yaw), ranging from -pi to pi. The y location in the image\n    corresponds to pitch, ranging from -pi/2 to pi/2.\n\n    In this setup real_image_width and real_viewing_distance have no effect.\n\n    Note that rather than a 2D image gaze location in [0,1]^2, the gaze should be specified as gaze angles in [-pi,pi]x[-pi/2,pi/2] (yaw, then pitch).\n\n    Parameters\n    ----------\n\n    gaze_angles         : tuple of 2 floats\n                            Gaze direction expressed as angles, in radians.\n    image_pixel_size    : tuple of 2 ints\n                            Dimensions of the image in pixels, as a tuple of (height, width)\n    alpha               : float\n                            Parameter controlling extent of foveation\n    mode                : str\n                            Foveation mode (how pooling size varies with eccentricity). Should be \"quadratic\" or \"linear\"\n    \"\"\"\n    view_direction = torch.tensor([math.sin(gaze_angles[0])*math.cos(gaze_angles[1]), math.sin(gaze_angles[1]), math.cos(gaze_angles[0])*math.cos(gaze_angles[1])])\n\n    yaw_angle_map = torch.linspace(-torch.pi, torch.pi, image_pixel_size[1])\n    yaw_angle_map = yaw_angle_map[None,:].repeat(image_pixel_size[0], 1)[None,...]\n    pitch_angle_map = torch.linspace(-torch.pi*0.5, torch.pi*0.5, image_pixel_size[0])\n    pitch_angle_map = pitch_angle_map[:,None].repeat(1, image_pixel_size[1])[None,...]\n\n    dir_map = torch.cat([torch.sin(yaw_angle_map)*torch.cos(pitch_angle_map), torch.sin(pitch_angle_map), torch.cos(yaw_angle_map)*torch.cos(pitch_angle_map)])\n\n    # Work out the pooling region diameter in radians\n    view_dot_dir = torch.sum(view_direction[:,None,None] * dir_map, dim=0)\n    eccentricity = torch.acos(view_dot_dir)\n    pooling_rad = alpha * eccentricity\n    if mode == \"quadratic\":\n        pooling_rad *= eccentricity\n\n    # The actual pooling region will be an ellipse in the equirectangular image - the length of the major & minor axes\n    # depend on the x & y resolution of the image. We find these two axis lengths (in pixels) and then the area of the ellipse\n    pixels_per_rad_x = image_pixel_size[1] / (2*torch.pi)\n    pixels_per_rad_y = image_pixel_size[0] / (torch.pi)\n    pooling_axis_x = pooling_rad * pixels_per_rad_x\n    pooling_axis_y = pooling_rad * pixels_per_rad_y\n    area = torch.pi * pooling_axis_x * pooling_axis_y * 0.25\n\n    # Now finally find the length of the side of a square of the same area.\n    size = torch.sqrt(torch.abs(area))\n    return size\n
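A brief usage sketch (import path assumed from the anchor name) for the pixel-sized variant:
import torch\nfrom odak.learn.perception import make_equi_pooling_size_map_pixels  # import path assumed from the anchor name\n\npooling_pixels = make_equi_pooling_size_map_pixels(\n    gaze_angles = (0.0, 0.0),\n    image_pixel_size = (256, 512),\n    alpha = 0.3,\n    mode = 'quadratic'\n    )\nprint(pooling_pixels.shape)  # torch.Size([256, 512]); side length, in pixels, of an equal-area square pooling region\n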
"},{"location":"odak/learn_perception/#odak.learn.perception.make_pooling_size_map_lod","title":"make_pooling_size_map_lod(gaze_location, image_pixel_size, alpha=0.3, real_image_width=0.3, real_viewing_distance=0.6, mode='quadratic')","text":"

This function is similar to make_pooling_size_map_pixels, but instead returns a map of LOD levels to sample from to achieve the correct pooling region areas.

Parameters:

  • gaze_location \u2013
                        User's gaze (fixation point) in the image. Should be given as a tuple with normalized\n                    image coordinates (ranging from 0 to 1)\n
  • image_pixel_size \u2013
                        The size of the image in pixels, as a tuple of form (height, width)\n
  • real_image_width \u2013
                        The real width of the image as displayed. Units not important, as long as they\n                    are the same as those used for real_viewing_distance\n
  • real_viewing_distance \u2013
                        The real distance from the user's viewpoint to the screen.\n

Returns:

  • pooling_size_map ( tensor ) \u2013

    The computed pooling size map, of size HxW.

Source code in odak/learn/perception/foveation.py
def make_pooling_size_map_lod(gaze_location, image_pixel_size, alpha=0.3, real_image_width=0.3, real_viewing_distance=0.6, mode=\"quadratic\"):\n    \"\"\" \n    This function is similar to make_pooling_size_map_pixels, but instead returns a map of LOD levels to sample from\n    to achieve the correct pooling region areas.\n\n    Parameters\n    ----------\n\n    gaze_location           : tuple of floats\n                                User's gaze (fixation point) in the image. Should be given as a tuple with normalized\n                                image coordinates (ranging from 0 to 1)\n    image_pixel_size        : tuple of ints\n                                The size of the image in pixels, as a tuple of form (height, width)\n    real_image_width        : float\n                                The real width of the image as displayed. Units not important, as long as they\n                                are the same as those used for real_viewing_distance\n    real_viewing_distance   : float\n                                The real distance from the user's viewpoint to the screen.\n\n    Returns\n    -------\n\n    pooling_size_map        : torch.tensor\n                                The computed pooling size map, of size WxH.\n    \"\"\"\n    pooling_pixel = make_pooling_size_map_pixels(\n        gaze_location, image_pixel_size, alpha, real_image_width, real_viewing_distance, mode)\n    pooling_lod = torch.log2(1e-6+pooling_pixel)\n    pooling_lod[pooling_lod < 0] = 0\n    return pooling_lod\n
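A brief usage sketch (import path assumed from the anchor name); the returned map can drive mip-level selection when sampling:
import torch\nfrom odak.learn.perception import make_pooling_size_map_lod  # import path assumed from the anchor name\n\npooling_lod = make_pooling_size_map_lod(\n    gaze_location = (0.5, 0.5),\n    image_pixel_size = (600, 800)\n    )\nprint(float(pooling_lod.min()), float(pooling_lod.max()))  # 0 at the fovea, growing with eccentricity\n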
"},{"location":"odak/learn_perception/#odak.learn.perception.make_pooling_size_map_pixels","title":"make_pooling_size_map_pixels(gaze_location, image_pixel_size, alpha=0.3, real_image_width=0.3, real_viewing_distance=0.6, mode='quadratic')","text":"

Makes a map of the pooling size associated with each pixel in an image for a given fixation point, when displayed to a user on a flat screen. Follows the idea that pooling size (in radians) should be directly proportional to eccentricity (also in radians).

Assumes the viewpoint is located at the centre of the image, and the screen is perpendicular to the viewing direction. Output is the width of the pooling region in pixels.

Parameters:

  • gaze_location \u2013
                        User's gaze (fixation point) in the image. Should be given as a tuple with normalized\n                    image coordinates (ranging from 0 to 1)\n
  • image_pixel_size \u2013
                        The size of the image in pixels, as a tuple of form (height, width)\n
  • real_image_width \u2013
                        The real width of the image as displayed. Units not important, as long as they\n                    are the same as those used for real_viewing_distance\n
  • real_viewing_distance \u2013
                        The real distance from the user's viewpoint to the screen.\n

Returns:

  • pooling_size_map ( tensor ) \u2013

    The computed pooling size map, of size HxW.

Source code in odak/learn/perception/foveation.py
def make_pooling_size_map_pixels(gaze_location, image_pixel_size, alpha=0.3, real_image_width=0.3, real_viewing_distance=0.6, mode=\"quadratic\"):\n    \"\"\" \n    Makes a map of the pooling size associated with each pixel in an image for a given fixation point, when displayed to\n    a user on a flat screen. Follows the idea that pooling size (in radians) should be directly proportional to eccentricity\n    (also in radians). \n\n    Assumes the viewpoint is located at the centre of the image, and the screen is \n    perpendicular to the viewing direction. Output is the width of the pooling region in pixels.\n\n    Parameters\n    ----------\n\n    gaze_location           : tuple of floats\n                                User's gaze (fixation point) in the image. Should be given as a tuple with normalized\n                                image coordinates (ranging from 0 to 1)\n    image_pixel_size        : tuple of ints\n                                The size of the image in pixels, as a tuple of form (height, width)\n    real_image_width        : float\n                                The real width of the image as displayed. Units not important, as long as they\n                                are the same as those used for real_viewing_distance\n    real_viewing_distance   : float\n                                The real distance from the user's viewpoint to the screen.\n\n    Returns\n    -------\n\n    pooling_size_map        : torch.tensor\n                                The computed pooling size map, of size WxH.\n    \"\"\"\n    eccentricity, distance_to_pixel = make_eccentricity_distance_maps(\n        gaze_location, image_pixel_size, real_image_width, real_viewing_distance)\n    eccentricity_centre, _ = make_eccentricity_distance_maps(\n        [0.5, 0.5], image_pixel_size, real_image_width, real_viewing_distance)\n    pooling_rad = alpha * eccentricity\n    if mode == \"quadratic\":\n        pooling_rad *= eccentricity\n    angle_min = eccentricity_centre - pooling_rad*0.5\n    angle_max = eccentricity_centre + pooling_rad*0.5\n    major_axis = (torch.tan(angle_max) - torch.tan(angle_min)) * \\\n        real_viewing_distance\n    minor_axis = 2 * distance_to_pixel * torch.tan(pooling_rad*0.5)\n    area = math.pi * major_axis * minor_axis * 0.25\n    # Should be +ve anyway, but check to ensure we don't take sqrt of negative number\n    area = torch.abs(area)\n    pooling_real = torch.sqrt(area)\n    pooling_pixel = (pooling_real / real_image_width) * image_pixel_size[1]\n    return pooling_pixel\n
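A brief usage sketch (import path assumed from the anchor name); the default viewing geometry from the signature is spelled out for clarity:
import torch\nfrom odak.learn.perception import make_pooling_size_map_pixels  # import path assumed from the anchor name\n\npooling_pixels = make_pooling_size_map_pixels(\n    gaze_location = (0.25, 0.75),  # normalized image coordinates\n    image_pixel_size = (600, 800),\n    alpha = 0.3,\n    real_image_width = 0.3,\n    real_viewing_distance = 0.6\n    )\nprint(pooling_pixels.shape)  # torch.Size([600, 800]); pooling-region width per pixel, in pixels\n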
"},{"location":"odak/learn_perception/#odak.learn.perception.make_radial_map","title":"make_radial_map(size, gaze)","text":"

Makes a simple radial map where each pixel contains distance in pixels from the chosen gaze location.

Parameters:

  • size \u2013
        Dimensions of the image\n
  • gaze \u2013
        User's gaze (fixation point) in the image. Should be given as a tuple with normalized\n    image coordinates (ranging from 0 to 1)\n
Source code in odak/learn/perception/foveation.py
def make_radial_map(size, gaze):\n    \"\"\" \n    Makes a simple radial map where each pixel contains distance in pixels from the chosen gaze location.\n\n    Parameters\n    ----------\n\n    size    : tuple of ints\n                Dimensions of the image\n    gaze    : tuple of floats\n                User's gaze (fixation point) in the image. Should be given as a tuple with normalized\n                image coordinates (ranging from 0 to 1)\n    \"\"\"\n    pix_gaze = [gaze[0]*size[0], gaze[1]*size[1]]\n    rows = torch.linspace(0, size[0], size[0])\n    rows = rows[:, None].repeat(1, size[1])\n    cols = torch.linspace(0, size[1], size[1])\n    cols = cols[None, :].repeat(size[0], 1)\n    dist_sq = torch.pow(rows - pix_gaze[0], 2) + \\\n        torch.pow(cols - pix_gaze[1], 2)\n    radii = torch.sqrt(dist_sq)\n    return radii/torch.max(radii)\n
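A brief usage sketch (import path assumed from the anchor name):
import torch\nfrom odak.learn.perception import make_radial_map  # import path assumed from the anchor name\n\nradii = make_radial_map(size = (600, 800), gaze = (0.5, 0.5))\nprint(radii.shape, float(radii.max()))  # torch.Size([600, 800]); distances normalized so the maximum is 1.0\n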
"},{"location":"odak/learn_perception/#odak.learn.perception.pad_image_for_pyramid","title":"pad_image_for_pyramid(image, n_pyramid_levels)","text":"

Pads an image to the extent necessary to compute a steerable pyramid of the input image. This involves padding so both height and width are divisible by 2**n_pyramid_levels. Uses reflection padding.

Parameters:

  • image \u2013

    Image to pad, in NCHW format

  • n_pyramid_levels \u2013

    Number of levels in the pyramid you plan to construct.

Source code in odak/learn/perception/spatial_steerable_pyramid.py
def pad_image_for_pyramid(image, n_pyramid_levels):\n    \"\"\"\n    Pads an image to the extent necessary to compute a steerable pyramid of the input image.\n    This involves padding so both height and width are divisible by 2**n_pyramid_levels.\n    Uses reflection padding.\n\n    Parameters\n    ----------\n\n    image: torch.tensor\n        Image to pad, in NCHW format\n    n_pyramid_levels: int\n        Number of levels in the pyramid you plan to construct.\n    \"\"\"\n    min_divisor = 2 ** n_pyramid_levels\n    height = image.size(2)\n    width = image.size(3)\n    required_height = math.ceil(height / min_divisor) * min_divisor\n    required_width = math.ceil(width / min_divisor) * min_divisor\n    if required_height > height or required_width > width:\n        # We need to pad! ReflectionPad2d expects (left, right, top, bottom).\n        pad = torch.nn.ReflectionPad2d(\n            (0, required_width-width, 0, required_height-height))\n        return pad(image)\n    return image\n
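A brief usage sketch (import path assumed from the anchor name); the input height is deliberately not divisible by 2**4:
import torch\nfrom odak.learn.perception import pad_image_for_pyramid  # import path assumed from the anchor name\n\nimage = torch.rand(1, 3, 250, 320)  # NCHW; height 250 is not divisible by 2 ** 4 = 16\npadded = pad_image_for_pyramid(image, n_pyramid_levels = 4)\nprint(padded.shape)  # torch.Size([1, 3, 256, 320]); both sides now divisible by 16\n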
"},{"location":"odak/learn_perception/#odak.learn.perception.rgb_2_ycrcb","title":"rgb_2_ycrcb(image)","text":"

Converts an image from RGB colourspace to YCrCb colourspace.

Parameters:

  • image \u2013
      Input image. Should be an RGB floating-point image with values in the range [0, 1]. Should be in CHW [3 x m x n] or NCHW [k x 3 x m x n] format.\n

Returns:

  • ycrcb ( tensor ) \u2013

    Image converted to YCrCb colourspace [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def rgb_2_ycrcb(image):\n    \"\"\"\n    Converts an image from RGB colourspace to YCrCb colourspace.\n\n    Parameters\n    ----------\n    image   : torch.tensor\n              Input image. Should be an RGB floating-point image with values in the range [0, 1]. Should be in NCHW format [3 x m x n] or [k x 3 x m x n].\n\n    Returns\n    -------\n\n    ycrcb   : torch.tensor\n              Image converted to YCrCb colourspace [k x 3 m x n] or [1 x 3 x m x n].\n    \"\"\"\n    if len(image.shape) == 3:\n       image = image.unsqueeze(0)\n    ycrcb = torch.zeros(image.size()).to(image.device)\n    ycrcb[:, 0, :, :] = 0.299 * image[:, 0, :, :] + 0.587 * \\\n        image[:, 1, :, :] + 0.114 * image[:, 2, :, :]\n    ycrcb[:, 1, :, :] = 0.5 + 0.713 * (image[:, 0, :, :] - ycrcb[:, 0, :, :])\n    ycrcb[:, 2, :, :] = 0.5 + 0.564 * (image[:, 2, :, :] - ycrcb[:, 0, :, :])\n    return ycrcb\n
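A brief usage sketch (import path assumed from the anchor name):
import torch\nfrom odak.learn.perception import rgb_2_ycrcb  # import path assumed from the anchor name\n\nimage_rgb = torch.rand(1, 3, 64, 64)  # RGB in [0, 1]\nimage_ycrcb = rgb_2_ycrcb(image_rgb)\nprint(image_ycrcb.shape)  # torch.Size([1, 3, 64, 64]); channel 0 is luma, channels 1 and 2 are Cr and Cb\n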
"},{"location":"odak/learn_perception/#odak.learn.perception.rgb_to_hsv","title":"rgb_to_hsv(image, eps=1e-08)","text":"

Definition to convert RGB space to HSV color space. Mostly inspired from : https://kornia.readthedocs.io/en/latest/_modules/kornia/color/hsv.html

Parameters:

  • image \u2013
              Input image in RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.\n

Returns:

  • image_hsv ( tensor ) \u2013

    Output image in HSV color space [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def rgb_to_hsv(image, eps: float = 1e-8):\n\n    \"\"\"\n    Definition to convert RGB space to HSV color space. Mostly inspired from : https://kornia.readthedocs.io/en/latest/_modules/kornia/color/hsv.html\n\n    Parameters\n    ----------\n    image           : torch.tensor\n                      Input image in HSV color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.\n\n    Returns\n    -------\n    image_hsv       : torch.tensor\n                      Output image in  RGB  color space [k x 3 x m x n] or [1 x 3 x m x n].\n    \"\"\"\n    if len(image.shape) == 3:\n        image = image.unsqueeze(0)\n    max_rgb, argmax_rgb = image.max(-3)\n    min_rgb, argmin_rgb = image.min(-3)\n    deltac = max_rgb - min_rgb\n    v = max_rgb\n    s = deltac / (max_rgb + eps)\n    deltac = torch.where(deltac == 0, torch.ones_like(deltac), deltac)\n    rc, gc, bc = torch.unbind((max_rgb.unsqueeze(-3) - image), dim=-3)\n    h1 = bc - gc\n    h2 = (rc - bc) + 2.0 * deltac\n    h3 = (gc - rc) + 4.0 * deltac\n    h = torch.stack((h1, h2, h3), dim=-3) / deltac.unsqueeze(-3)\n    h = torch.gather(h, dim=-3, index=argmax_rgb.unsqueeze(-3)).squeeze(-3)\n    h = (h / 6.0) % 1.0\n    h = 2.0 * math.pi * h \n    image_hsv = torch.stack((h, s, v), dim=-3)\n    return image_hsv\n
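A brief usage sketch (import path assumed from the anchor name); note that the hue channel is returned in radians:
import torch\nfrom odak.learn.perception import rgb_to_hsv  # import path assumed from the anchor name\n\nimage_rgb = torch.rand(1, 3, 64, 64)\nimage_hsv = rgb_to_hsv(image_rgb)\nprint(image_hsv.shape, float(image_hsv[:, 0].max()))  # hue lies in [0, 2 * pi), saturation and value in [0, 1]\n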
"},{"location":"odak/learn_perception/#odak.learn.perception.rgb_to_linear_rgb","title":"rgb_to_linear_rgb(image, threshold=0.0031308)","text":"

Definition to convert gamma-encoded (sRGB-like) RGB images to linear RGB color space. Mostly inspired from: https://kornia.readthedocs.io/en/latest/_modules/kornia/color/rgb.html

Parameters:

  • image \u2013
              Input image in RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.\n
  • threshold \u2013
              Threshold of the transfer function separating the low-value linear segment from the power-law segment.\n

Returns:

  • image_linear ( tensor ) \u2013

    Output image in linear RGB color space [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def rgb_to_linear_rgb(image, threshold = 0.0031308):\n    \"\"\"\n    Definition to convert RGB images to linear RGB color space. Mostly inspired from: https://kornia.readthedocs.io/en/latest/_modules/kornia/color/rgb.html\n\n    Parameters\n    ----------\n    image           : torch.tensor\n                      Input image in RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.\n    threshold       : float\n                      Threshold used in calculations.\n\n    Returns\n    -------\n    image_linear    : torch.tensor\n                      Output image in linear RGB color space [k x 3 x m x n] or [1 x 3 x m x n].\n    \"\"\"\n    if len(image.shape) == 3:\n        image = image.unsqueeze(0)\n    image_linear = torch.where(image > 0.04045, torch.pow(((image + 0.055) / 1.055), 2.4), image / 12.92)\n    return image_linear\n
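A brief sketch chaining this with linear_rgb_to_xyz above (import paths assumed from the anchor names):
import torch\nfrom odak.learn.perception import rgb_to_linear_rgb, linear_rgb_to_xyz  # import paths assumed from the anchor names\n\nimage_srgb = torch.rand(1, 3, 64, 64)  # gamma-encoded RGB in [0, 1]\nimage_linear = rgb_to_linear_rgb(image_srgb)\nimage_xyz = linear_rgb_to_xyz(image_linear)  # continue into CIE XYZ\nprint(image_xyz.shape)  # torch.Size([1, 3, 64, 64])\n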
"},{"location":"odak/learn_perception/#odak.learn.perception.srgb_to_lab","title":"srgb_to_lab(image)","text":"

Definition to convert SRGB space to LAB color space.

Parameters:

  • image \u2013
              Input image in SRGB color space [3 x m x n].\n

Returns:

  • image_lab ( tensor ) \u2013

    Output image in LAB color space [3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def srgb_to_lab(image):    \n    \"\"\"\n    Definition to convert SRGB space to LAB color space. \n\n    Parameters\n    ----------\n    image           : torch.tensor\n                      Input image in SRGB color space[3 x m x n]\n    Returns\n    -------\n    image_lab       : torch.tensor\n                      Output image in LAB color space [3 x m x n].\n    \"\"\"\n    if image.shape[-1] == 3:\n        input_color = image.permute(2, 0, 1)  # C(H*W)\n    else:\n        input_color = image\n    # rgb ---> linear rgb\n    limit = 0.04045        \n    # linear rgb ---> xyz\n    linrgb_color = torch.where(input_color > limit, torch.pow((input_color + 0.055) / 1.055, 2.4), input_color / 12.92)\n\n    a11 = 10135552 / 24577794\n    a12 = 8788810  / 24577794\n    a13 = 4435075  / 24577794\n    a21 = 2613072  / 12288897\n    a22 = 8788810  / 12288897\n    a23 = 887015   / 12288897\n    a31 = 1425312  / 73733382\n    a32 = 8788810  / 73733382\n    a33 = 70074185 / 73733382\n\n    A = torch.tensor([[a11, a12, a13],\n                    [a21, a22, a23],\n                    [a31, a32, a33]], dtype=torch.float32)\n\n    linrgb_color = linrgb_color.permute(2, 0, 1) # C(H*W)\n    xyz_color = torch.matmul(A, linrgb_color)\n    xyz_color = xyz_color.permute(1, 2, 0)\n    # xyz ---> lab\n    inv_reference_illuminant = torch.tensor([[[1.052156925]], [[1.000000000]], [[0.918357670]]], dtype=torch.float32)\n    input_color = xyz_color * inv_reference_illuminant\n    delta = 6 / 29\n    delta_square = delta * delta\n    delta_cube = delta * delta_square\n    factor = 1 / (3 * delta_square)\n\n    input_color = torch.where(input_color > delta_cube, torch.pow(input_color, 1 / 3), (factor * input_color + 4 / 29))\n\n    l = 116 * input_color[1:2, :, :] - 16\n    a = 500 * (input_color[0:1,:, :] - input_color[1:2, :, :])\n    b = 200 * (input_color[1:2, :, :] - input_color[2:3, :, :])\n\n    image_lab = torch.cat((l, a, b), 0)\n    return image_lab    \n
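A brief usage sketch (import path assumed from the anchor name):
import torch\nfrom odak.learn.perception import srgb_to_lab  # import path assumed from the anchor name\n\nimage_srgb = torch.rand(3, 64, 64)  # sRGB in [0, 1], channels-first layout\nimage_lab = srgb_to_lab(image_srgb)\nprint(image_lab.shape)  # torch.Size([3, 64, 64]); channels are L, a, b\n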
"},{"location":"odak/learn_perception/#odak.learn.perception.xyz_to_linear_rgb","title":"xyz_to_linear_rgb(image)","text":"

Definition to convert CIE XYZ space to linear RGB color space. Mostly inspired from : Rochester IT Color Conversion Algorithms (https://www.cs.rit.edu/~ncs/color/)

Parameters:

  • image \u2013
               Input image in XYZ (CIE 1931) color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.\n

Returns:

  • image_linear_rgb ( tensor ) \u2013

    Output image in linear RGB color space [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def xyz_to_linear_rgb(image):\n    \"\"\"\n    Definition to convert CIE XYZ space to linear RGB color space. Mostly inspired from : Rochester IT Color Conversion Algorithms (https://www.cs.rit.edu/~ncs/color/)\n\n    Parameters\n    ----------\n    image            : torch.tensor\n                       Input image in XYZ (CIE 1931) color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.\n\n    Returns\n    -------\n    image_linear_rgb : torch.tensor\n                       Output image in linear RGB  color space [k x 3 x m x n] or [1 x 3 x m x n].\n    \"\"\"\n    if len(image.shape) == 3:\n        image = image.unsqueeze(0)\n    a11 = 3.240479\n    a12 = -1.537150\n    a13 = -0.498535\n    a21 = -0.969256 \n    a22 = 1.875992 \n    a23 = 0.041556\n    a31 = 0.055648\n    a32 = -0.204043\n    a33 = 1.057311\n    M = torch.tensor([[a11, a12, a13], \n                      [a21, a22, a23],\n                      [a31, a32, a33]])\n    size = image.size()\n    image = image.reshape(size[0], size[1], size[2]*size[3])\n    image_linear_rgb = torch.matmul(M, image)\n    image_linear_rgb = image_linear_rgb.reshape(size[0], size[1], size[2], size[3])\n    return image_linear_rgb\n
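A brief round-trip sketch with linear_rgb_to_xyz (import paths assumed from the anchor names); the two conversion matrices are approximate inverses:
import torch\nfrom odak.learn.perception import linear_rgb_to_xyz, xyz_to_linear_rgb  # import paths assumed from the anchor names\n\nimage = torch.rand(1, 3, 64, 64)\nrecovered = xyz_to_linear_rgb(linear_rgb_to_xyz(image))\nprint(torch.allclose(recovered, image, atol = 1e-4))  # True, up to rounding of the matrix coefficients\n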
"},{"location":"odak/learn_perception/#odak.learn.perception.ycrcb_2_rgb","title":"ycrcb_2_rgb(image)","text":"

Converts an image from YCrCb colourspace to RGB colourspace.

Parameters:

  • image \u2013
      Input image. Should be a YCrCb floating-point image with values in the range [0, 1]. Should be in CHW [3 x m x n] or NCHW [k x 3 x m x n] format.\n

Returns:

  • rgb ( tensor ) \u2013

    Image converted to RGB colourspace [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def ycrcb_2_rgb(image):\n    \"\"\"\n    Converts an image from YCrCb colourspace to RGB colourspace.\n\n    Parameters\n    ----------\n    image   : torch.tensor\n              Input image. Should be a YCrCb floating-point image with values in the range [0, 1]. Should be in NCHW format [3 x m x n] or [k x 3 x m x n].\n\n    Returns\n    -------\n    rgb     : torch.tensor\n              Image converted to RGB colourspace [k x 3 m x n] or [1 x 3 x m x n].\n    \"\"\"\n    if len(image.shape) == 3:\n       image = image.unsqueeze(0)\n    rgb = torch.zeros(image.size(), device=image.device)\n    rgb[:, 0, :, :] = image[:, 0, :, :] + 1.403 * (image[:, 1, :, :] - 0.5)\n    rgb[:, 1, :, :] = image[:, 0, :, :] - 0.714 * \\\n        (image[:, 1, :, :] - 0.5) - 0.344 * (image[:, 2, :, :] - 0.5)\n    rgb[:, 2, :, :] = image[:, 0, :, :] + 1.773 * (image[:, 2, :, :] - 0.5)\n    return rgb\n
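A brief round-trip sketch with rgb_2_ycrcb (import paths assumed from the anchor names); the rounded coefficients make the round trip approximate:
import torch\nfrom odak.learn.perception import rgb_2_ycrcb, ycrcb_2_rgb  # import paths assumed from the anchor names\n\nimage_rgb = torch.rand(1, 3, 64, 64)\nrecovered = ycrcb_2_rgb(rgb_2_ycrcb(image_rgb))\nprint(torch.allclose(recovered, image_rgb, atol = 1e-2))  # True within a small tolerance\n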
"},{"location":"odak/learn_perception/#odak.learn.perception.blur_loss.BlurLoss","title":"BlurLoss","text":"

BlurLoss implements two different blur losses. When blur_source is set to False, it implements blur_match: the source input image is matched to a blurred version of the target image.

When blur_source is set to True, it implements blur_lowpass, matching the blurred version of the input image to the blurred target image. This tries to match only the low frequencies of the source input image to the low frequencies of the target.

The interface is similar to other pytorch loss functions, but note that the gaze location must be provided in addition to the source and target images.

Source code in odak/learn/perception/blur_loss.py
class BlurLoss():\n    \"\"\" \n\n    `BlurLoss` implements two different blur losses. When `blur_source` is set to `False`, it implements blur_match, trying to match the input image to the blurred target image. This tries to match the source input image to a blurred version of the target.\n\n    When `blur_source` is set to `True`, it implements blur_lowpass, matching the blurred version of the input image to the blurred target image. This tries to match only the low frequencies of the source input image to the low frequencies of the target.\n\n    The interface is similar to other `pytorch` loss functions, but note that the gaze location must be provided in addition to the source and target images.\n    \"\"\"\n\n\n    def __init__(self, device=torch.device(\"cpu\"),\n                 alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode=\"quadratic\", blur_source=False, equi=False):\n        \"\"\"\n        Parameters\n        ----------\n\n        alpha                   : float\n                                    parameter controlling foveation - larger values mean bigger pooling regions.\n        real_image_width        : float \n                                    The real width of the image as displayed to the user.\n                                    Units don't matter as long as they are the same as for real_viewing_distance.\n        real_viewing_distance   : float \n                                    The real distance of the observer's eyes to the image plane.\n                                    Units don't matter as long as they are the same as for real_image_width.\n        mode                    : str \n                                    Foveation mode, either \"quadratic\" or \"linear\". Controls how pooling regions grow\n                                    as you move away from the fovea. We got best results with \"quadratic\".\n        blur_source             : bool\n                                    If true, blurs the source image as well as the target before computing the loss.\n        equi                    : bool\n                                    If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular\n                                    format 360 image. The settings real_image_width and real_viewing distance are ignored.\n                                    The gaze argument is instead interpreted as gaze angles, and should be in the range\n                                    [-pi,pi]x[-pi/2,pi]\n        \"\"\"\n        self.target = None\n        self.device = device\n        self.alpha = alpha\n        self.real_image_width = real_image_width\n        self.real_viewing_distance = real_viewing_distance\n        self.mode = mode\n        self.blur = None\n        self.loss_func = torch.nn.MSELoss()\n        self.blur_source = blur_source\n        self.equi = equi\n\n    def blur_image(self, image, gaze):\n        if self.blur is None:\n            self.blur = RadiallyVaryingBlur()\n        return self.blur.blur(image, self.alpha, self.real_image_width, self.real_viewing_distance, gaze, self.mode, self.equi)\n\n    def __call__(self, image, target, gaze=[0.5, 0.5]):\n        \"\"\" \n        Calculates the Blur Loss.\n\n        Parameters\n        ----------\n        image               : torch.tensor\n                                Image to compute loss for. 
Should be an RGB image in NCHW format (4 dimensions)\n        target              : torch.tensor\n                                Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)\n        gaze                : list\n                                Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.\n\n        Returns\n        -------\n\n        loss                : torch.tensor\n                                The computed loss.\n        \"\"\"\n        check_loss_inputs(\"BlurLoss\", image, target)\n        blurred_target = self.blur_image(target, gaze)\n        if self.blur_source:\n            blurred_image = self.blur_image(image, gaze)\n            return self.loss_func(blurred_image, blurred_target)\n        else:\n            return self.loss_func(image, blurred_target)\n\n    def to(self, device):\n        self.device = device\n        return self\n
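A brief usage sketch (import path assumed from the module path above); random tensors stand in for real NCHW images:
import torch\nfrom odak.learn.perception.blur_loss import BlurLoss  # import path assumed from the module path above\n\nloss_function = BlurLoss(alpha = 0.2, real_image_width = 0.2, real_viewing_distance = 0.7, mode = 'quadratic')\nimage = torch.rand(1, 3, 128, 128)   # source image, NCHW\ntarget = torch.rand(1, 3, 128, 128)  # ground-truth image, NCHW\nloss = loss_function(image, target, gaze = [0.25, 0.5])\nprint(loss)\n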
"},{"location":"odak/learn_perception/#odak.learn.perception.blur_loss.BlurLoss.__call__","title":"__call__(image, target, gaze=[0.5, 0.5])","text":"

Calculates the Blur Loss.

Parameters:

  • image \u2013
                    Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)\n
  • target \u2013
                    Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)\n
  • gaze \u2013
                    Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.\n

Returns:

  • loss ( tensor ) \u2013

    The computed loss.

Source code in odak/learn/perception/blur_loss.py
def __call__(self, image, target, gaze=[0.5, 0.5]):\n    \"\"\" \n    Calculates the Blur Loss.\n\n    Parameters\n    ----------\n    image               : torch.tensor\n                            Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)\n    target              : torch.tensor\n                            Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)\n    gaze                : list\n                            Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.\n\n    Returns\n    -------\n\n    loss                : torch.tensor\n                            The computed loss.\n    \"\"\"\n    check_loss_inputs(\"BlurLoss\", image, target)\n    blurred_target = self.blur_image(target, gaze)\n    if self.blur_source:\n        blurred_image = self.blur_image(image, gaze)\n        return self.loss_func(blurred_image, blurred_target)\n    else:\n        return self.loss_func(image, blurred_target)\n
"},{"location":"odak/learn_perception/#odak.learn.perception.blur_loss.BlurLoss.__init__","title":"__init__(device=torch.device('cpu'), alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode='quadratic', blur_source=False, equi=False)","text":"

Parameters:

  • alpha \u2013
                        parameter controlling foveation - larger values mean bigger pooling regions.\n
  • real_image_width \u2013
                        The real width of the image as displayed to the user.\n                    Units don't matter as long as they are the same as for real_viewing_distance.\n
  • real_viewing_distance \u2013
                        The real distance of the observer's eyes to the image plane.\n                    Units don't matter as long as they are the same as for real_image_width.\n
  • mode \u2013
                        Foveation mode, either \"quadratic\" or \"linear\". Controls how pooling regions grow\n                    as you move away from the fovea. We got best results with \"quadratic\".\n
  • blur_source \u2013
                        If true, blurs the source image as well as the target before computing the loss.\n
  • equi \u2013
                        If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular\n                    format 360 image. The settings real_image_width and real_viewing_distance are ignored.\n                    The gaze argument is instead interpreted as gaze angles, and should be in the range\n                    [-pi,pi]x[-pi/2,pi/2]\n
Source code in odak/learn/perception/blur_loss.py
def __init__(self, device=torch.device(\"cpu\"),\n             alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode=\"quadratic\", blur_source=False, equi=False):\n    \"\"\"\n    Parameters\n    ----------\n\n    alpha                   : float\n                                parameter controlling foveation - larger values mean bigger pooling regions.\n    real_image_width        : float \n                                The real width of the image as displayed to the user.\n                                Units don't matter as long as they are the same as for real_viewing_distance.\n    real_viewing_distance   : float \n                                The real distance of the observer's eyes to the image plane.\n                                Units don't matter as long as they are the same as for real_image_width.\n    mode                    : str \n                                Foveation mode, either \"quadratic\" or \"linear\". Controls how pooling regions grow\n                                as you move away from the fovea. We got best results with \"quadratic\".\n    blur_source             : bool\n                                If true, blurs the source image as well as the target before computing the loss.\n    equi                    : bool\n                                If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular\n                                format 360 image. The settings real_image_width and real_viewing distance are ignored.\n                                The gaze argument is instead interpreted as gaze angles, and should be in the range\n                                [-pi,pi]x[-pi/2,pi]\n    \"\"\"\n    self.target = None\n    self.device = device\n    self.alpha = alpha\n    self.real_image_width = real_image_width\n    self.real_viewing_distance = real_viewing_distance\n    self.mode = mode\n    self.blur = None\n    self.loss_func = torch.nn.MSELoss()\n    self.blur_source = blur_source\n    self.equi = equi\n
"},{"location":"odak/learn_perception/#odak.learn.perception.color_conversion.display_color_hvs","title":"display_color_hvs","text":"Source code in odak/learn/perception/color_conversion.py
class display_color_hvs():\n\n    def __init__(\n                 self,\n                 resolution = [1920, 1080],\n                 distance_from_screen = 800,\n                 pixel_pitch = 0.311,\n                 read_spectrum = 'tensor',\n                 primaries_spectrum = torch.rand(3, 301),\n                 device = torch.device('cpu')):\n        '''\n        Parameters\n        ----------\n        resolution                  : list\n                                      Resolution of the display in pixels.\n        distance_from_screen        : int\n                                      Distance from the screen in mm.\n        pixel_pitch                 : float\n                                      Pixel pitch of the display in mm.\n        read_spectrum               : str\n                                      Spectrum of the display. Default is 'default' which is the spectrum of the Dell U2415 display.\n        device                      : torch.device\n                                      Device to run the code on. Default is None which means the code will run on CPU.\n\n        '''\n        self.device = device\n        self.read_spectrum = read_spectrum\n        self.primaries_spectrum = primaries_spectrum.to(self.device)\n        self.resolution = resolution\n        self.distance_from_screen = distance_from_screen\n        self.pixel_pitch = pixel_pitch\n        self.l_normalized, self.m_normalized, self.s_normalized = self.initialize_cones_normalized()\n        self.lms_tensor = self.construct_matrix_lms(\n                                                    self.l_normalized,\n                                                    self.m_normalized,\n                                                    self.s_normalized\n                                                   )   \n        self.primaries_tensor = self.construct_matrix_primaries(\n                                                                self.l_normalized,\n                                                                self.m_normalized,\n                                                                self.s_normalized\n                                                               )   \n        return\n\n\n    def __call__(self, input_image, ground_truth, gaze=None):\n        \"\"\"\n        Evaluating an input image against a target ground truth image for a given gaze of a viewer.\n        \"\"\"\n        lms_image_second = self.primaries_to_lms(input_image.to(self.device))\n        lms_ground_truth_second = self.primaries_to_lms(ground_truth.to(self.device))\n        lms_image_third = self.second_to_third_stage(lms_image_second)\n        lms_ground_truth_third = self.second_to_third_stage(lms_ground_truth_second)\n        loss_metamer_color = torch.mean((lms_ground_truth_third - lms_image_third) ** 2)\n        return loss_metamer_color\n\n\n    def initialize_cones_normalized(self):\n        \"\"\"\n        Internal function to initialize normalized L,M,S cones as normal distribution with given sigma, and mu values. 
\n\n        Returns\n        -------\n        l_cone_n                     : torch.tensor\n                                       Normalised L cone distribution.\n        m_cone_n                     : torch.tensor\n                                       Normalised M cone distribution.\n        s_cone_n                     : torch.tensor\n                                       Normalised S cone distribution.\n        \"\"\"\n        wavelength_range = torch.linspace(400, 700, steps = 301, device = self.device)\n        dist_l = 1 / (32.5 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 567.5) ** 2 / (2 * 32.5 ** 2))\n        dist_m = 1 / (27.5 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 545.0) ** 2 / (2 * 27.5 ** 2))\n        dist_s = 1 / (17.0 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 447.5) ** 2 / (2 * 17.0 ** 2))\n\n        l_cone_n = dist_l / dist_l.max()\n        m_cone_n = dist_m / dist_m.max()\n        s_cone_n = dist_s / dist_s.max()\n        return l_cone_n, m_cone_n, s_cone_n\n\n\n    def initialize_rgb_backlight_spectrum(self):\n        \"\"\"\n        Internal function to initialize baclight spectrum for color primaries. \n\n        Returns\n        -------\n        red_spectrum                 : torch.tensor\n                                       Normalised backlight spectrum for red color primary.\n        green_spectrum               : torch.tensor\n                                       Normalised backlight spectrum for green color primary.\n        blue_spectrum                : torch.tensor\n                                       Normalised backlight spectrum for blue color primary.\n        \"\"\"\n        wavelength_range = torch.linspace(400, 700, steps = 301, device = self.device)\n        red_spectrum = 1 / (14.5 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 650) ** 2 / (2 * 14.5 ** 2))\n        green_spectrum = 1 / (12 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 550) ** 2 / (2 * 12.0 ** 2))\n        blue_spectrum = 1 / (12 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 450) ** 2 / (2 * 12.0 ** 2))\n\n        red_spectrum = red_spectrum / red_spectrum.max()\n        green_spectrum = green_spectrum / green_spectrum.max()\n        blue_spectrum = blue_spectrum / blue_spectrum.max()\n\n        return red_spectrum, green_spectrum, blue_spectrum\n\n\n    def initialize_random_spectrum_normalized(self, dataset):\n        \"\"\"\n        Initialize normalized light spectrum via combination of 3 gaussian distribution curve fitting [L-BFGS]. 
\n\n        Parameters\n        ----------\n        dataset                                : torch.tensor \n                                                 spectrum value against wavelength \n        \"\"\"\n        dataset = torch.swapaxes(dataset, 0, 1)\n        x_spectrum = torch.linspace(400, 700, steps = 301) - 550\n        y_spectrum = torch.from_numpy(np_cpu.interp(x_spectrum, dataset[0].numpy(), dataset[1].numpy()))\n        max_spectrum = torch.max(y_spectrum)\n        y_spectrum /= max_spectrum\n\n        def gaussian(x, A = 1, sigma = 1, centre = 0): return A * \\\n            torch.exp(-(x - centre) ** 2 / (2 * sigma ** 2))\n\n        def function(x, weights): \n            return gaussian(x, *weights[:3]) + gaussian(x, *weights[3:6]) + gaussian(x, *weights[6:9])\n\n        weights = torch.tensor([1.0, 1.0, -0.2, 1.0, 1.0, 0.0, 1.0, 1.0, 0.2], requires_grad = True)\n        optimizer = torch.optim.LBFGS([weights], max_iter = 1000, lr = 0.1, line_search_fn = None)\n\n        def closure():\n            optimizer.zero_grad()\n            output = function(x_spectrum, weights)\n            loss = F.mse_loss(output, y_spectrum)\n            loss.backward()\n            return loss\n        optimizer.step(closure)\n        spectrum = function(x_spectrum, weights)\n        return spectrum.detach().to(self.device)\n\n\n    def display_spectrum_response(wavelength, function):\n        \"\"\"\n        Internal function to provide light spectrum response at particular wavelength\n\n        Parameters\n        ----------\n        wavelength                          : torch.tensor\n                                              Wavelength in nm [400...700]\n        function                            : torch.tensor\n                                              Display light spectrum distribution function\n\n        Returns\n        -------\n        ligth_response_dict                  : float\n                                               Display light spectrum response value\n        \"\"\"\n        wavelength = int(round(wavelength, 0))\n        if wavelength >= 400 and wavelength <= 700:\n            return function[wavelength - 400].item()\n        elif wavelength < 400:\n            return function[0].item()\n        else:\n            return function[300].item()\n\n\n    def cone_response_to_spectrum(self, cone_spectrum, light_spectrum):\n        \"\"\"\n        Internal function to calculate cone response at particular light spectrum. \n\n        Parameters\n        ----------\n        cone_spectrum                         : torch.tensor\n                                                Spectrum, Wavelength [2,300] tensor \n        light_spectrum                        : torch.tensor\n                                                Spectrum, Wavelength [2,300] tensor \n\n\n        Returns\n        -------\n        response_to_spectrum                  : float\n                                                Response of cone to light spectrum [1x1] \n        \"\"\"\n        response_to_spectrum = torch.mul(cone_spectrum, light_spectrum)\n        response_to_spectrum = torch.sum(response_to_spectrum)\n        return response_to_spectrum.item()\n\n\n    def construct_matrix_lms(self, l_response, m_response, s_response):\n        '''\n        Internal function to calculate cone  response at particular light spectrum. 
\n\n        Parameters\n        ----------\n        l_response                             : torch.tensor\n                                                 Cone response spectrum tensor (normalized response vs wavelength)\n        m_response                             : torch.tensor\n                                                 Cone response spectrum tensor (normalized response vs wavelength)\n        s_response                             : torch.tensor\n                                                 Cone response spectrum tensor (normalized response vs wavelength)\n\n\n\n        Returns\n        -------\n        lms_image_tensor                      : torch.tensor\n                                                3x3 LMSrgb tensor\n\n        '''\n        if self.read_spectrum == 'tensor':\n            logging.warning('Tensor primary spectrum is used')\n            logging.warning('The number of primaries used is {}'.format(self.primaries_spectrum.shape[0]))\n        else:\n            logging.warning(\"No Spectrum data is provided\")\n\n        self.lms_tensor = torch.zeros(self.primaries_spectrum.shape[0], 3).to(self.device)\n        for i in range(self.primaries_spectrum.shape[0]):\n            self.lms_tensor[i, 0] = self.cone_response_to_spectrum(l_response, self.primaries_spectrum[i])\n            self.lms_tensor[i, 1] = self.cone_response_to_spectrum(m_response, self.primaries_spectrum[i])\n            self.lms_tensor[i, 2] = self.cone_response_to_spectrum(s_response, self.primaries_spectrum[i]) \n        return self.lms_tensor    \n\n\n    def construct_matrix_primaries(self, l_response, m_response, s_response):\n        '''\n        Internal function to calculate cone  response at particular light spectrum. \n\n        Parameters\n        ----------\n        l_response                             : torch.tensor\n                                                 Cone response spectrum tensor (normalized response vs wavelength)\n        m_response                             : torch.tensor\n                                                 Cone response spectrum tensor (normalized response vs wavelength)\n        s_response                             : torch.tensor\n                                                 Cone response spectrum tensor (normalized response vs wavelength)\n\n\n\n        Returns\n        -------\n        lms_image_tensor                      : torch.tensor\n                                                3x3 LMSrgb tensor\n\n        '''\n        if self.read_spectrum == 'tensor':\n            logging.warning('Tensor primary spectrum is used')\n            logging.warning('The number of primaries used is {}'.format(self.primaries_spectrum.shape[0]))\n        else:\n            logging.warning(\"No Spectrum data is provided\")\n\n        self.primaries_tensor = torch.zeros(3, self.primaries_spectrum.shape[0]).to(self.device)\n        for i in range(self.primaries_spectrum.shape[0]):\n            self.primaries_tensor[0, i] = self.cone_response_to_spectrum(\n                                                                         l_response,\n                                                                         self.primaries_spectrum[i]\n                                                                        )\n            self.primaries_tensor[1, i] = self.cone_response_to_spectrum(\n                                                                         m_response,\n                                                                         
self.primaries_spectrum[i]\n                                                                        )\n            self.primaries_tensor[2, i] = self.cone_response_to_spectrum(\n                                                                         s_response,\n                                                                         self.primaries_spectrum[i]\n                                                                        ) \n        return self.primaries_tensor    \n\n\n    def primaries_to_lms(self, primaries):\n        \"\"\"\n        Internal function to convert primaries space to LMS space \n\n        Parameters\n        ----------\n        primaries                              : torch.tensor\n                                                 Primaries data to be transformed to LMS space [BxPHxW]\n\n\n        Returns\n        -------\n        lms_color                              : torch.tensor\n                                                 LMS data transformed from Primaries space [BxPxHxW]\n        \"\"\"                \n        primaries_flatten = primaries.reshape(primaries.shape[0], primaries.shape[1], 1, -1)\n        lms = self.lms_tensor.unsqueeze(0).unsqueeze(-1)\n        lms_color = torch.sum(primaries_flatten * lms, axis = 1).reshape(primaries.shape)\n        return lms_color\n\n\n    def lms_to_primaries(self, lms_color_tensor):\n        \"\"\"\n        Internal function to convert LMS image to primaries space\n\n        Parameters\n        ----------\n        lms_color_tensor                        : torch.tensor\n                                                  LMS data to be transformed to primaries space [Bx3xHxW]\n\n\n        Returns\n        -------\n        primaries                              : torch.tensor\n                                               : Primaries data transformed from LMS space [BxPxHxW]\n        \"\"\"\n        lms_color_tensor = lms_color_tensor.permute(0, 2, 3, 1).to(self.device)\n        lms_color_flatten = torch.flatten(lms_color_tensor, start_dim=0, end_dim=1)\n        unflatten = torch.nn.Unflatten(0, (lms_color_tensor.size(0), lms_color_tensor.size(1)))\n        converted_unflatten = torch.matmul(lms_color_flatten.double(), self.lms_tensor.pinverse().double())\n        primaries = unflatten(converted_unflatten)     \n        primaries = primaries.permute(0, 3, 1, 2)   \n        return primaries\n\n\n    def second_to_third_stage(self, lms_image):\n        '''\n        This function turns second stage [L,M,S] values into third stage [(M+S)-L, (L+S)-M, L+M+S], \n        See table 1 from Schmidt et al. \"Neurobiological hypothesis of color appearance and hue perception,\" Optics Express 2014.\n\n        Parameters\n        ----------\n        lms_image                             : torch.tensor\n                                                 Image data at LMS space (second stage)\n\n        Returns\n        -------\n        third_stage                            : torch.tensor\n                                                 Image data at LMS space (third stage)\n\n        '''\n        third_stage = torch.zeros_like(lms_image)\n        third_stage[:, 0] = (lms_image[:, 1] + lms_image[:, 2]) - lms_image[:, 1]\n        third_stage[:, 1] = (lms_image[:, 0] + lms_image[:, 2]) - lms_image[:, 1]\n        third_stage[:, 2] = lms_image[:, 0] + lms_image[:, 1]  + lms_image[:, 2]\n        return third_stage\n
"},{"location":"odak/learn_perception/#odak.learn.perception.color_conversion.display_color_hvs.__call__","title":"__call__(input_image, ground_truth, gaze=None)","text":"

Evaluating an input image against a target ground truth image for a given gaze of a viewer.

Source code in odak/learn/perception/color_conversion.py
def __call__(self, input_image, ground_truth, gaze=None):\n    \"\"\"\n    Evaluating an input image against a target ground truth image for a given gaze of a viewer.\n    \"\"\"\n    lms_image_second = self.primaries_to_lms(input_image.to(self.device))\n    lms_ground_truth_second = self.primaries_to_lms(ground_truth.to(self.device))\n    lms_image_third = self.second_to_third_stage(lms_image_second)\n    lms_ground_truth_third = self.second_to_third_stage(lms_ground_truth_second)\n    loss_metamer_color = torch.mean((lms_ground_truth_third - lms_image_third) ** 2)\n    return loss_metamer_color\n
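A minimal usage sketch (not part of the library documentation), assuming odak is installed and that display_color_hvs is importable from odak.learn.perception.color_conversion as the paths above suggest; image sizes and values are placeholders:

import torch
from odak.learn.perception.color_conversion import display_color_hvs

device = torch.device('cpu')
display = display_color_hvs(
                            read_spectrum = 'tensor',
                            primaries_spectrum = torch.rand(3, 301),
                            device = device
                           )
input_image = torch.rand(1, 3, 64, 64)   # candidate image in display primaries space [B x P x H x W]
ground_truth = torch.rand(1, 3, 64, 64)  # target image in display primaries space [B x P x H x W]
loss = display(input_image, ground_truth)
print(loss.item())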
"},{"location":"odak/learn_perception/#odak.learn.perception.color_conversion.display_color_hvs.__init__","title":"__init__(resolution=[1920, 1080], distance_from_screen=800, pixel_pitch=0.311, read_spectrum='tensor', primaries_spectrum=torch.rand(3, 301), device=torch.device('cpu'))","text":"

Parameters:

  • resolution \u2013
                          Resolution of the display in pixels.\n
  • distance_from_screen \u2013
                          Distance from the screen in mm.\n
  • pixel_pitch \u2013
                          Pixel pitch of the display in mm.\n
  • read_spectrum \u2013
                          Spectrum of the display. The default is 'tensor', meaning the spectra supplied in primaries_spectrum are used.\n
  • primaries_spectrum \u2013
                          Spectra of the display primaries, sampled at 301 wavelengths between 400 nm and 700 nm [number of primaries x 301].\n
  • device \u2013
                          Device to run the code on. The default is torch.device('cpu').\n
Source code in odak/learn/perception/color_conversion.py
def __init__(\n             self,\n             resolution = [1920, 1080],\n             distance_from_screen = 800,\n             pixel_pitch = 0.311,\n             read_spectrum = 'tensor',\n             primaries_spectrum = torch.rand(3, 301),\n             device = torch.device('cpu')):\n    '''\n    Parameters\n    ----------\n    resolution                  : list\n                                  Resolution of the display in pixels.\n    distance_from_screen        : int\n                                  Distance from the screen in mm.\n    pixel_pitch                 : float\n                                  Pixel pitch of the display in mm.\n    read_spectrum               : str\n                                  Spectrum of the display. Default is 'default' which is the spectrum of the Dell U2415 display.\n    device                      : torch.device\n                                  Device to run the code on. Default is None which means the code will run on CPU.\n\n    '''\n    self.device = device\n    self.read_spectrum = read_spectrum\n    self.primaries_spectrum = primaries_spectrum.to(self.device)\n    self.resolution = resolution\n    self.distance_from_screen = distance_from_screen\n    self.pixel_pitch = pixel_pitch\n    self.l_normalized, self.m_normalized, self.s_normalized = self.initialize_cones_normalized()\n    self.lms_tensor = self.construct_matrix_lms(\n                                                self.l_normalized,\n                                                self.m_normalized,\n                                                self.s_normalized\n                                               )   \n    self.primaries_tensor = self.construct_matrix_primaries(\n                                                            self.l_normalized,\n                                                            self.m_normalized,\n                                                            self.s_normalized\n                                                           )   \n    return\n
"},{"location":"odak/learn_perception/#odak.learn.perception.color_conversion.display_color_hvs.cone_response_to_spectrum","title":"cone_response_to_spectrum(cone_spectrum, light_spectrum)","text":"

Internal function to calculate the cone response to a particular light spectrum.

Parameters:

  • cone_spectrum \u2013
                                    Spectrum, Wavelength [2,300] tensor\n
  • light_spectrum \u2013
                                    Spectrum, Wavelength [2,300] tensor\n

Returns:

  • response_to_spectrum ( float ) \u2013

    Response of cone to light spectrum [1x1]

Source code in odak/learn/perception/color_conversion.py
def cone_response_to_spectrum(self, cone_spectrum, light_spectrum):\n    \"\"\"\n    Internal function to calculate cone response at particular light spectrum. \n\n    Parameters\n    ----------\n    cone_spectrum                         : torch.tensor\n                                            Spectrum, Wavelength [2,300] tensor \n    light_spectrum                        : torch.tensor\n                                            Spectrum, Wavelength [2,300] tensor \n\n\n    Returns\n    -------\n    response_to_spectrum                  : float\n                                            Response of cone to light spectrum [1x1] \n    \"\"\"\n    response_to_spectrum = torch.mul(cone_spectrum, light_spectrum)\n    response_to_spectrum = torch.sum(response_to_spectrum)\n    return response_to_spectrum.item()\n
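An illustrative call (an assumption, not taken from the docs): the normalized cone spectra initialized by the class and each primary spectrum are 301-sample tensors covering 400 nm to 700 nm.

import torch
from odak.learn.perception.color_conversion import display_color_hvs

display = display_color_hvs(primaries_spectrum = torch.rand(3, 301))
response = display.cone_response_to_spectrum(
                                             display.l_normalized,          # L cone sensitivity [301]
                                             display.primaries_spectrum[0]  # first primary spectrum [301]
                                            )
print(response)  # a Python float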
"},{"location":"odak/learn_perception/#odak.learn.perception.color_conversion.display_color_hvs.construct_matrix_lms","title":"construct_matrix_lms(l_response, m_response, s_response)","text":"

Internal function to construct the primaries-to-LMS conversion matrix from the given cone response spectra and the display primaries' spectra.

Parameters:

  • l_response \u2013
                                     Cone response spectrum tensor (normalized response vs wavelength)\n
  • m_response \u2013
                                     Cone response spectrum tensor (normalized response vs wavelength)\n
  • s_response \u2013
                                     Cone response spectrum tensor (normalized response vs wavelength)\n

Returns:

  • lms_image_tensor ( tensor ) \u2013

    Primaries-to-LMS conversion matrix, of size [number of primaries x 3] (3x3 for an RGB display).

Source code in odak/learn/perception/color_conversion.py
def construct_matrix_lms(self, l_response, m_response, s_response):\n    '''\n    Internal function to calculate cone  response at particular light spectrum. \n\n    Parameters\n    ----------\n    l_response                             : torch.tensor\n                                             Cone response spectrum tensor (normalized response vs wavelength)\n    m_response                             : torch.tensor\n                                             Cone response spectrum tensor (normalized response vs wavelength)\n    s_response                             : torch.tensor\n                                             Cone response spectrum tensor (normalized response vs wavelength)\n\n\n\n    Returns\n    -------\n    lms_image_tensor                      : torch.tensor\n                                            3x3 LMSrgb tensor\n\n    '''\n    if self.read_spectrum == 'tensor':\n        logging.warning('Tensor primary spectrum is used')\n        logging.warning('The number of primaries used is {}'.format(self.primaries_spectrum.shape[0]))\n    else:\n        logging.warning(\"No Spectrum data is provided\")\n\n    self.lms_tensor = torch.zeros(self.primaries_spectrum.shape[0], 3).to(self.device)\n    for i in range(self.primaries_spectrum.shape[0]):\n        self.lms_tensor[i, 0] = self.cone_response_to_spectrum(l_response, self.primaries_spectrum[i])\n        self.lms_tensor[i, 1] = self.cone_response_to_spectrum(m_response, self.primaries_spectrum[i])\n        self.lms_tensor[i, 2] = self.cone_response_to_spectrum(s_response, self.primaries_spectrum[i]) \n    return self.lms_tensor    \n
"},{"location":"odak/learn_perception/#odak.learn.perception.color_conversion.display_color_hvs.construct_matrix_primaries","title":"construct_matrix_primaries(l_response, m_response, s_response)","text":"

Internal function to construct the primaries conversion matrix (cone responses per display primary) from the given cone response spectra and the display primaries' spectra.

Parameters:

  • l_response \u2013
                                     Cone response spectrum tensor (normalized response vs wavelength)\n
  • m_response \u2013
                                     Cone response spectrum tensor (normalized response vs wavelength)\n
  • s_response \u2013
                                     Cone response spectrum tensor (normalized response vs wavelength)\n

Returns:

  • lms_image_tensor ( tensor ) \u2013

    Cone responses per display primary, of size [3 x number of primaries] (3x3 for an RGB display).

Source code in odak/learn/perception/color_conversion.py
def construct_matrix_primaries(self, l_response, m_response, s_response):\n    '''\n    Internal function to calculate cone  response at particular light spectrum. \n\n    Parameters\n    ----------\n    l_response                             : torch.tensor\n                                             Cone response spectrum tensor (normalized response vs wavelength)\n    m_response                             : torch.tensor\n                                             Cone response spectrum tensor (normalized response vs wavelength)\n    s_response                             : torch.tensor\n                                             Cone response spectrum tensor (normalized response vs wavelength)\n\n\n\n    Returns\n    -------\n    lms_image_tensor                      : torch.tensor\n                                            3x3 LMSrgb tensor\n\n    '''\n    if self.read_spectrum == 'tensor':\n        logging.warning('Tensor primary spectrum is used')\n        logging.warning('The number of primaries used is {}'.format(self.primaries_spectrum.shape[0]))\n    else:\n        logging.warning(\"No Spectrum data is provided\")\n\n    self.primaries_tensor = torch.zeros(3, self.primaries_spectrum.shape[0]).to(self.device)\n    for i in range(self.primaries_spectrum.shape[0]):\n        self.primaries_tensor[0, i] = self.cone_response_to_spectrum(\n                                                                     l_response,\n                                                                     self.primaries_spectrum[i]\n                                                                    )\n        self.primaries_tensor[1, i] = self.cone_response_to_spectrum(\n                                                                     m_response,\n                                                                     self.primaries_spectrum[i]\n                                                                    )\n        self.primaries_tensor[2, i] = self.cone_response_to_spectrum(\n                                                                     s_response,\n                                                                     self.primaries_spectrum[i]\n                                                                    ) \n    return self.primaries_tensor    \n
"},{"location":"odak/learn_perception/#odak.learn.perception.color_conversion.display_color_hvs.display_spectrum_response","title":"display_spectrum_response(wavelength, function)","text":"

Internal function to provide the light spectrum response at a particular wavelength.

Parameters:

  • wavelength \u2013
                                  Wavelength in nm [400...700]\n
  • function \u2013
                                  Display light spectrum distribution function\n

Returns:

  • light_response_dict ( float ) \u2013

    Display light spectrum response value

Source code in odak/learn/perception/color_conversion.py
def display_spectrum_response(wavelength, function):\n    \"\"\"\n    Internal function to provide light spectrum response at particular wavelength\n\n    Parameters\n    ----------\n    wavelength                          : torch.tensor\n                                          Wavelength in nm [400...700]\n    function                            : torch.tensor\n                                          Display light spectrum distribution function\n\n    Returns\n    -------\n    ligth_response_dict                  : float\n                                           Display light spectrum response value\n    \"\"\"\n    wavelength = int(round(wavelength, 0))\n    if wavelength >= 400 and wavelength <= 700:\n        return function[wavelength - 400].item()\n    elif wavelength < 400:\n        return function[0].item()\n    else:\n        return function[300].item()\n
"},{"location":"odak/learn_perception/#odak.learn.perception.color_conversion.display_color_hvs.initialize_cones_normalized","title":"initialize_cones_normalized()","text":"

Internal function to initialize the normalized L, M, S cone responses as normal distributions with given sigma and mu values.

Returns:

  • l_cone_n ( tensor ) \u2013

    Normalised L cone distribution.

  • m_cone_n ( tensor ) \u2013

    Normalised M cone distribution.

  • s_cone_n ( tensor ) \u2013

    Normalised S cone distribution.

Source code in odak/learn/perception/color_conversion.py
def initialize_cones_normalized(self):\n    \"\"\"\n    Internal function to initialize normalized L,M,S cones as normal distribution with given sigma, and mu values. \n\n    Returns\n    -------\n    l_cone_n                     : torch.tensor\n                                   Normalised L cone distribution.\n    m_cone_n                     : torch.tensor\n                                   Normalised M cone distribution.\n    s_cone_n                     : torch.tensor\n                                   Normalised S cone distribution.\n    \"\"\"\n    wavelength_range = torch.linspace(400, 700, steps = 301, device = self.device)\n    dist_l = 1 / (32.5 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 567.5) ** 2 / (2 * 32.5 ** 2))\n    dist_m = 1 / (27.5 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 545.0) ** 2 / (2 * 27.5 ** 2))\n    dist_s = 1 / (17.0 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 447.5) ** 2 / (2 * 17.0 ** 2))\n\n    l_cone_n = dist_l / dist_l.max()\n    m_cone_n = dist_m / dist_m.max()\n    s_cone_n = dist_s / dist_s.max()\n    return l_cone_n, m_cone_n, s_cone_n\n
"},{"location":"odak/learn_perception/#odak.learn.perception.color_conversion.display_color_hvs.initialize_random_spectrum_normalized","title":"initialize_random_spectrum_normalized(dataset)","text":"

Initialize a normalized light spectrum by fitting a combination of three Gaussian distributions to the provided data using L-BFGS.

Parameters:

  • dataset \u2013
                                     spectrum value against wavelength\n
Source code in odak/learn/perception/color_conversion.py
def initialize_random_spectrum_normalized(self, dataset):\n    \"\"\"\n    Initialize normalized light spectrum via combination of 3 gaussian distribution curve fitting [L-BFGS]. \n\n    Parameters\n    ----------\n    dataset                                : torch.tensor \n                                             spectrum value against wavelength \n    \"\"\"\n    dataset = torch.swapaxes(dataset, 0, 1)\n    x_spectrum = torch.linspace(400, 700, steps = 301) - 550\n    y_spectrum = torch.from_numpy(np_cpu.interp(x_spectrum, dataset[0].numpy(), dataset[1].numpy()))\n    max_spectrum = torch.max(y_spectrum)\n    y_spectrum /= max_spectrum\n\n    def gaussian(x, A = 1, sigma = 1, centre = 0): return A * \\\n        torch.exp(-(x - centre) ** 2 / (2 * sigma ** 2))\n\n    def function(x, weights): \n        return gaussian(x, *weights[:3]) + gaussian(x, *weights[3:6]) + gaussian(x, *weights[6:9])\n\n    weights = torch.tensor([1.0, 1.0, -0.2, 1.0, 1.0, 0.0, 1.0, 1.0, 0.2], requires_grad = True)\n    optimizer = torch.optim.LBFGS([weights], max_iter = 1000, lr = 0.1, line_search_fn = None)\n\n    def closure():\n        optimizer.zero_grad()\n        output = function(x_spectrum, weights)\n        loss = F.mse_loss(output, y_spectrum)\n        loss.backward()\n        return loss\n    optimizer.step(closure)\n    spectrum = function(x_spectrum, weights)\n    return spectrum.detach().to(self.device)\n
"},{"location":"odak/learn_perception/#odak.learn.perception.color_conversion.display_color_hvs.initialize_rgb_backlight_spectrum","title":"initialize_rgb_backlight_spectrum()","text":"

Internal function to initialize the backlight spectrum for the color primaries.

Returns:

  • red_spectrum ( tensor ) \u2013

    Normalised backlight spectrum for red color primary.

  • green_spectrum ( tensor ) \u2013

    Normalised backlight spectrum for green color primary.

  • blue_spectrum ( tensor ) \u2013

    Normalised backlight spectrum for blue color primary.

Source code in odak/learn/perception/color_conversion.py
def initialize_rgb_backlight_spectrum(self):\n    \"\"\"\n    Internal function to initialize baclight spectrum for color primaries. \n\n    Returns\n    -------\n    red_spectrum                 : torch.tensor\n                                   Normalised backlight spectrum for red color primary.\n    green_spectrum               : torch.tensor\n                                   Normalised backlight spectrum for green color primary.\n    blue_spectrum                : torch.tensor\n                                   Normalised backlight spectrum for blue color primary.\n    \"\"\"\n    wavelength_range = torch.linspace(400, 700, steps = 301, device = self.device)\n    red_spectrum = 1 / (14.5 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 650) ** 2 / (2 * 14.5 ** 2))\n    green_spectrum = 1 / (12 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 550) ** 2 / (2 * 12.0 ** 2))\n    blue_spectrum = 1 / (12 * (2 * torch.pi) ** 0.5) * torch.exp(-0.5 * (wavelength_range - 450) ** 2 / (2 * 12.0 ** 2))\n\n    red_spectrum = red_spectrum / red_spectrum.max()\n    green_spectrum = green_spectrum / green_spectrum.max()\n    blue_spectrum = blue_spectrum / blue_spectrum.max()\n\n    return red_spectrum, green_spectrum, blue_spectrum\n
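A sketch of one assumed workflow (not prescribed by the docs): replacing the random default primaries with the synthetic RGB backlight spectra generated here and rebuilding the conversion matrix.

import torch
from odak.learn.perception.color_conversion import display_color_hvs

display = display_color_hvs(device = torch.device('cpu'))
red, green, blue = display.initialize_rgb_backlight_spectrum()
display.primaries_spectrum = torch.stack((red, green, blue), dim = 0)  # [3 x 301]
display.lms_tensor = display.construct_matrix_lms(
                                                  display.l_normalized,
                                                  display.m_normalized,
                                                  display.s_normalized
                                                 )  # refresh the primaries-to-LMS matrix for the new spectra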
"},{"location":"odak/learn_perception/#odak.learn.perception.color_conversion.display_color_hvs.lms_to_primaries","title":"lms_to_primaries(lms_color_tensor)","text":"

Internal function to convert LMS image to primaries space

Parameters:

  • lms_color_tensor \u2013
                                      LMS data to be transformed to primaries space [Bx3xHxW]\n

Returns:

  • primaries ( tensor ) \u2013

    Primaries data transformed from LMS space [BxPxHxW]

Source code in odak/learn/perception/color_conversion.py
def lms_to_primaries(self, lms_color_tensor):\n    \"\"\"\n    Internal function to convert LMS image to primaries space\n\n    Parameters\n    ----------\n    lms_color_tensor                        : torch.tensor\n                                              LMS data to be transformed to primaries space [Bx3xHxW]\n\n\n    Returns\n    -------\n    primaries                              : torch.tensor\n                                           : Primaries data transformed from LMS space [BxPxHxW]\n    \"\"\"\n    lms_color_tensor = lms_color_tensor.permute(0, 2, 3, 1).to(self.device)\n    lms_color_flatten = torch.flatten(lms_color_tensor, start_dim=0, end_dim=1)\n    unflatten = torch.nn.Unflatten(0, (lms_color_tensor.size(0), lms_color_tensor.size(1)))\n    converted_unflatten = torch.matmul(lms_color_flatten.double(), self.lms_tensor.pinverse().double())\n    primaries = unflatten(converted_unflatten)     \n    primaries = primaries.permute(0, 3, 1, 2)   \n    return primaries\n
"},{"location":"odak/learn_perception/#odak.learn.perception.color_conversion.display_color_hvs.primaries_to_lms","title":"primaries_to_lms(primaries)","text":"

Internal function to convert primaries space to LMS space

Parameters:

  • primaries \u2013
                                     Primaries data to be transformed to LMS space [BxPxHxW]\n

Returns:

  • lms_color ( tensor ) \u2013

    LMS data transformed from Primaries space [BxPxHxW]

Source code in odak/learn/perception/color_conversion.py
def primaries_to_lms(self, primaries):\n    \"\"\"\n    Internal function to convert primaries space to LMS space \n\n    Parameters\n    ----------\n    primaries                              : torch.tensor\n                                             Primaries data to be transformed to LMS space [BxPHxW]\n\n\n    Returns\n    -------\n    lms_color                              : torch.tensor\n                                             LMS data transformed from Primaries space [BxPxHxW]\n    \"\"\"                \n    primaries_flatten = primaries.reshape(primaries.shape[0], primaries.shape[1], 1, -1)\n    lms = self.lms_tensor.unsqueeze(0).unsqueeze(-1)\n    lms_color = torch.sum(primaries_flatten * lms, axis = 1).reshape(primaries.shape)\n    return lms_color\n
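A round-trip sketch (assumed usage, placeholder sizes): converting an image in display primaries space to LMS with primaries_to_lms and back with lms_to_primaries; with three primaries the conversion matrix is square, so the reconstruction error should be small.

import torch
from odak.learn.perception.color_conversion import display_color_hvs

display = display_color_hvs(primaries_spectrum = torch.rand(3, 301))
image = torch.rand(1, 3, 32, 32)           # [B x P x H x W]
lms = display.primaries_to_lms(image)      # [B x 3 x H x W]
recovered = display.lms_to_primaries(lms)  # [B x P x H x W], double precision
print((recovered - image.double()).abs().max())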
"},{"location":"odak/learn_perception/#odak.learn.perception.color_conversion.display_color_hvs.second_to_third_stage","title":"second_to_third_stage(lms_image)","text":"

This function turns second stage [L,M,S] values into third stage [(M+S)-L, (L+S)-M, L+M+S], See table 1 from Schmidt et al. \"Neurobiological hypothesis of color appearance and hue perception,\" Optics Express 2014.

Parameters:

  • lms_image \u2013
                                     Image data at LMS space (second stage)\n

Returns:

  • third_stage ( tensor ) \u2013

    Image data at LMS space (third stage)

Source code in odak/learn/perception/color_conversion.py
def second_to_third_stage(self, lms_image):\n    '''\n    This function turns second stage [L,M,S] values into third stage [(M+S)-L, (L+S)-M, L+M+S], \n    See table 1 from Schmidt et al. \"Neurobiological hypothesis of color appearance and hue perception,\" Optics Express 2014.\n\n    Parameters\n    ----------\n    lms_image                             : torch.tensor\n                                             Image data at LMS space (second stage)\n\n    Returns\n    -------\n    third_stage                            : torch.tensor\n                                             Image data at LMS space (third stage)\n\n    '''\n    third_stage = torch.zeros_like(lms_image)\n    third_stage[:, 0] = (lms_image[:, 1] + lms_image[:, 2]) - lms_image[:, 1]\n    third_stage[:, 1] = (lms_image[:, 0] + lms_image[:, 2]) - lms_image[:, 1]\n    third_stage[:, 2] = lms_image[:, 0] + lms_image[:, 1]  + lms_image[:, 2]\n    return third_stage\n
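An assumed usage sketch: feeding an LMS image (for example, the output of primaries_to_lms) through the second-to-third stage transform.

import torch
from odak.learn.perception.color_conversion import display_color_hvs

display = display_color_hvs(primaries_spectrum = torch.rand(3, 301))
image = torch.rand(1, 3, 32, 32)                       # image in primaries space [B x P x H x W]
lms_second = display.primaries_to_lms(image)           # second stage [B x 3 x H x W]
lms_third = display.second_to_third_stage(lms_second)  # third stage [B x 3 x H x W]
print(lms_third.shape)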
"},{"location":"odak/learn_perception/#odak.learn.perception.color_conversion.color_map","title":"color_map(input_image, target_image, model='Lab Stats')","text":"

Internal function to map the color of an image to another image. Reference: Color transfer between images, Reinhard et al., 2001.

Parameters:

  • input_image \u2013
                  Input image in RGB color space [3 x m x n].\n
  • target_image \u2013
                  Target image in RGB color space [3 x m x n] providing the target colour statistics.\n

Returns:

  • mapped_image ( Tensor ) \u2013

    Input image with the color distribution of the target image [3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def color_map(input_image, target_image, model = 'Lab Stats'):\n    \"\"\"\n    Internal function to map the color of an image to another image.\n    Reference: Color transfer between images, Reinhard et al., 2001.\n\n    Parameters\n    ----------\n    input_image         : torch.Tensor\n                          Input image in RGB color space [3 x m x n].\n    target_image        : torch.Tensor\n\n    Returns\n    -------\n    mapped_image           : torch.Tensor\n                             Input image with the color the distribution of the target image [3 x m x n].\n    \"\"\"\n    if model == 'Lab Stats':\n        lab_input = srgb_to_lab(input_image)\n        lab_target = srgb_to_lab(target_image)\n        input_mean_L = torch.mean(lab_input[0, :, :])\n        input_mean_a = torch.mean(lab_input[1, :, :])\n        input_mean_b = torch.mean(lab_input[2, :, :])\n        input_std_L = torch.std(lab_input[0, :, :])\n        input_std_a = torch.std(lab_input[1, :, :])\n        input_std_b = torch.std(lab_input[2, :, :])\n        target_mean_L = torch.mean(lab_target[0, :, :])\n        target_mean_a = torch.mean(lab_target[1, :, :])\n        target_mean_b = torch.mean(lab_target[2, :, :])\n        target_std_L = torch.std(lab_target[0, :, :])\n        target_std_a = torch.std(lab_target[1, :, :])\n        target_std_b = torch.std(lab_target[2, :, :])\n        lab_input[0, :, :] = (lab_input[0, :, :] - input_mean_L) * (target_std_L / input_std_L) + target_mean_L\n        lab_input[1, :, :] = (lab_input[1, :, :] - input_mean_a) * (target_std_a / input_std_a) + target_mean_a\n        lab_input[2, :, :] = (lab_input[2, :, :] - input_mean_b) * (target_std_b / input_std_b) + target_mean_b\n        mapped_image = lab_to_srgb(lab_input.permute(1, 2, 0))\n        return mapped_image\n
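A minimal sketch (assumed usage, random placeholder images): transferring the colour statistics of one RGB image to another; both images are [3 x m x n] tensors with values in [0, 1].

import torch
from odak.learn.perception.color_conversion import color_map

input_image = torch.rand(3, 64, 64)   # image whose colours will be remapped
target_image = torch.rand(3, 64, 64)  # image providing the target colour statistics
mapped_image = color_map(input_image, target_image, model = 'Lab Stats')
print(mapped_image.shape)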
"},{"location":"odak/learn_perception/#odak.learn.perception.color_conversion.hsv_to_rgb","title":"hsv_to_rgb(image)","text":"

Definition to convert HSV space to RGB color space. Mostly inspired from : https://kornia.readthedocs.io/en/latest/_modules/kornia/color/hsv.html

Parameters:

  • image \u2013
              Input image in HSV color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.\n

Returns:

  • image_rgb ( tensor ) \u2013

    Output image in RGB color space [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def hsv_to_rgb(image):\n\n    \"\"\"\n    Definition to convert HSV space to  RGB color space. Mostly inspired from : https://kornia.readthedocs.io/en/latest/_modules/kornia/color/hsv.html\n\n    Parameters\n    ----------\n    image           : torch.tensor\n                      Input image in HSV color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.\n\n    Returns\n    -------\n    image_rgb       : torch.tensor\n                      Output image in  RGB  color space [k x 3 x m x n] or [1 x 3 x m x n].\n    \"\"\"\n    if len(image.shape) == 3:\n        image = image.unsqueeze(0)\n    h = image[..., 0, :, :] / (2 * math.pi)\n    s = image[..., 1, :, :]\n    v = image[..., 2, :, :]\n    hi = torch.floor(h * 6) % 6\n    f = ((h * 6) % 6) - hi\n    one = torch.tensor(1.0)\n    p = v * (one - s)\n    q = v * (one - f * s)\n    t = v * (one - (one - f) * s)\n    hi = hi.long()\n    indices = torch.stack([hi, hi + 6, hi + 12], dim=-3)\n    image_rgb = torch.stack((v, q, p, p, t, v, t, v, v, q, p, p, p, p, t, v, v, q), dim=-3)\n    image_rgb = torch.gather(image_rgb, -3, indices)\n    return image_rgb\n
"},{"location":"odak/learn_perception/#odak.learn.perception.color_conversion.lab_to_srgb","title":"lab_to_srgb(image)","text":"

Definition to convert LAB space to SRGB color space.

Parameters:

  • image \u2013
              Input image in LAB color space[3 x m x n]\n

Returns:

  • image_srgb ( tensor ) \u2013

    Output image in SRGB color space [3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def lab_to_srgb(image):\n    \"\"\"\n    Definition to convert LAB space to SRGB color space. \n\n    Parameters\n    ----------\n    image           : torch.tensor\n                      Input image in LAB color space[3 x m x n]\n    Returns\n    -------\n    image_srgb     : torch.tensor\n                      Output image in SRGB color space [3 x m x n].\n    \"\"\"\n\n    if image.shape[-1] == 3:\n        input_color = image.permute(2, 0, 1)  # C(H*W)\n    else:\n        input_color = image\n    # lab ---> xyz\n    reference_illuminant = torch.tensor([[[0.950428545]], [[1.000000000]], [[1.088900371]]], dtype=torch.float32)\n    y = (input_color[0:1, :, :] + 16) / 116\n    a =  input_color[1:2, :, :] / 500\n    b =  input_color[2:3, :, :] / 200\n    x = y + a\n    z = y - b\n    xyz = torch.cat((x, y, z), 0)\n    delta = 6 / 29\n    factor = 3 * delta * delta\n    xyz = torch.where(xyz > delta,  xyz ** 3, factor * (xyz - 4 / 29))\n    xyz_color = xyz * reference_illuminant\n    # xyz ---> linear rgb\n    a11 = 3.241003275\n    a12 = -1.537398934\n    a13 = -0.498615861\n    a21 = -0.969224334\n    a22 = 1.875930071\n    a23 = 0.041554224\n    a31 = 0.055639423\n    a32 = -0.204011202\n    a33 = 1.057148933\n    A = torch.tensor([[a11, a12, a13],\n                  [a21, a22, a23],\n                  [a31, a32, a33]], dtype=torch.float32)\n\n    xyz_color = xyz_color.permute(2, 0, 1) # C(H*W)\n    linear_rgb_color = torch.matmul(A, xyz_color)\n    linear_rgb_color = linear_rgb_color.permute(1, 2, 0)\n    # linear rgb ---> srgb\n    limit = 0.0031308\n    image_srgb = torch.where(linear_rgb_color > limit, 1.055 * (linear_rgb_color ** (1.0 / 2.4)) - 0.055, 12.92 * linear_rgb_color)\n    return image_srgb\n
"},{"location":"odak/learn_perception/#odak.learn.perception.color_conversion.linear_rgb_to_rgb","title":"linear_rgb_to_rgb(image, threshold=0.0031308)","text":"

Definition to convert linear RGB images to RGB color space. Mostly inspired from: https://kornia.readthedocs.io/en/latest/_modules/kornia/color/rgb.html

Parameters:

  • image \u2013
              Input image in linear RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.\n
  • threshold \u2013
              Threshold used in calculations.\n

Returns:

  • image_linear ( tensor ) \u2013

    Output image in RGB color space [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def linear_rgb_to_rgb(image, threshold = 0.0031308):\n    \"\"\"\n    Definition to convert linear RGB images to RGB color space. Mostly inspired from: https://kornia.readthedocs.io/en/latest/_modules/kornia/color/rgb.html\n\n    Parameters\n    ----------\n    image           : torch.tensor\n                      Input image in linear RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.\n    threshold       : float\n                      Threshold used in calculations.\n\n    Returns\n    -------\n    image_linear    : torch.tensor\n                      Output image in RGB color space [k x 3 x m x n] or [1 x 3 x m x n].\n    \"\"\"\n    if len(image.shape) == 3:\n        image = image.unsqueeze(0)\n    image_linear =  torch.where(image > threshold, 1.055 * torch.pow(image.clamp(min=threshold), 1 / 2.4) - 0.055, 12.92 * image)\n    return image_linear\n
"},{"location":"odak/learn_perception/#odak.learn.perception.color_conversion.linear_rgb_to_xyz","title":"linear_rgb_to_xyz(image)","text":"

Definition to convert RGB space to CIE XYZ color space. Mostly inspired from : Rochester IT Color Conversion Algorithms (https://www.cs.rit.edu/~ncs/color/)

Parameters:

  • image \u2013
              Input image in linear RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.\n

Returns:

  • image_xyz ( tensor ) \u2013

    Output image in XYZ (CIE 1931) color space [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def linear_rgb_to_xyz(image):\n    \"\"\"\n    Definition to convert RGB space to CIE XYZ color space. Mostly inspired from : Rochester IT Color Conversion Algorithms (https://www.cs.rit.edu/~ncs/color/)\n\n    Parameters\n    ----------\n    image           : torch.tensor\n                      Input image in linear RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.\n\n    Returns\n    -------\n    image_xyz       : torch.tensor\n                      Output image in XYZ (CIE 1931) color space [k x 3 x m x n] or [1 x 3 x m x n].\n    \"\"\"\n    if len(image.shape) == 3:\n        image = image.unsqueeze(0)\n    a11 = 0.412453\n    a12 = 0.357580\n    a13 = 0.180423\n    a21 = 0.212671\n    a22 = 0.715160\n    a23 = 0.072169\n    a31 = 0.019334\n    a32 = 0.119193\n    a33 = 0.950227\n    M = torch.tensor([[a11, a12, a13], \n                      [a21, a22, a23],\n                      [a31, a32, a33]])\n    size = image.size()\n    image = image.reshape(size[0], size[1], size[2]*size[3])  # NC(HW)\n    image_xyz = torch.matmul(M, image)\n    image_xyz = image_xyz.reshape(size[0], size[1], size[2], size[3])\n    return image_xyz\n
"},{"location":"odak/learn_perception/#odak.learn.perception.color_conversion.rgb_2_ycrcb","title":"rgb_2_ycrcb(image)","text":"

Converts an image from RGB colourspace to YCrCb colourspace.

Parameters:

  • image \u2013
      Input image. Should be an RGB floating-point image with values in the range [0, 1]. Should be in NCHW format [3 x m x n] or [k x 3 x m x n].\n

Returns:

  • ycrcb ( tensor ) \u2013

    Image converted to YCrCb colourspace [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def rgb_2_ycrcb(image):\n    \"\"\"\n    Converts an image from RGB colourspace to YCrCb colourspace.\n\n    Parameters\n    ----------\n    image   : torch.tensor\n              Input image. Should be an RGB floating-point image with values in the range [0, 1]. Should be in NCHW format [3 x m x n] or [k x 3 x m x n].\n\n    Returns\n    -------\n\n    ycrcb   : torch.tensor\n              Image converted to YCrCb colourspace [k x 3 m x n] or [1 x 3 x m x n].\n    \"\"\"\n    if len(image.shape) == 3:\n       image = image.unsqueeze(0)\n    ycrcb = torch.zeros(image.size()).to(image.device)\n    ycrcb[:, 0, :, :] = 0.299 * image[:, 0, :, :] + 0.587 * \\\n        image[:, 1, :, :] + 0.114 * image[:, 2, :, :]\n    ycrcb[:, 1, :, :] = 0.5 + 0.713 * (image[:, 0, :, :] - ycrcb[:, 0, :, :])\n    ycrcb[:, 2, :, :] = 0.5 + 0.564 * (image[:, 2, :, :] - ycrcb[:, 0, :, :])\n    return ycrcb\n
"},{"location":"odak/learn_perception/#odak.learn.perception.color_conversion.rgb_to_hsv","title":"rgb_to_hsv(image, eps=1e-08)","text":"

Definition to convert RGB space to HSV color space. Mostly inspired from : https://kornia.readthedocs.io/en/latest/_modules/kornia/color/hsv.html

Parameters:

  • image \u2013
              Input image in RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.\n

Returns:

  • image_hsv ( tensor ) \u2013

    Output image in HSV color space [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def rgb_to_hsv(image, eps: float = 1e-8):\n\n    \"\"\"\n    Definition to convert RGB space to HSV color space. Mostly inspired from : https://kornia.readthedocs.io/en/latest/_modules/kornia/color/hsv.html\n\n    Parameters\n    ----------\n    image           : torch.tensor\n                      Input image in HSV color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.\n\n    Returns\n    -------\n    image_hsv       : torch.tensor\n                      Output image in  RGB  color space [k x 3 x m x n] or [1 x 3 x m x n].\n    \"\"\"\n    if len(image.shape) == 3:\n        image = image.unsqueeze(0)\n    max_rgb, argmax_rgb = image.max(-3)\n    min_rgb, argmin_rgb = image.min(-3)\n    deltac = max_rgb - min_rgb\n    v = max_rgb\n    s = deltac / (max_rgb + eps)\n    deltac = torch.where(deltac == 0, torch.ones_like(deltac), deltac)\n    rc, gc, bc = torch.unbind((max_rgb.unsqueeze(-3) - image), dim=-3)\n    h1 = bc - gc\n    h2 = (rc - bc) + 2.0 * deltac\n    h3 = (gc - rc) + 4.0 * deltac\n    h = torch.stack((h1, h2, h3), dim=-3) / deltac.unsqueeze(-3)\n    h = torch.gather(h, dim=-3, index=argmax_rgb.unsqueeze(-3)).squeeze(-3)\n    h = (h / 6.0) % 1.0\n    h = 2.0 * math.pi * h \n    image_hsv = torch.stack((h, s, v), dim=-3)\n    return image_hsv\n
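A round-trip sketch (assumed usage) pairing rgb_to_hsv with hsv_to_rgb documented above; the hue channel is expressed in radians.

import torch
from odak.learn.perception.color_conversion import rgb_to_hsv, hsv_to_rgb

rgb = torch.rand(1, 3, 64, 64)   # RGB image(s), values in [0, 1]
hsv = rgb_to_hsv(rgb)            # [1 x 3 x 64 x 64], hue in radians
rgb_recovered = hsv_to_rgb(hsv)  # [1 x 3 x 64 x 64]
print((rgb - rgb_recovered).abs().max())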
"},{"location":"odak/learn_perception/#odak.learn.perception.color_conversion.rgb_to_linear_rgb","title":"rgb_to_linear_rgb(image, threshold=0.0031308)","text":"

Definition to convert RGB images to linear RGB color space. Mostly inspired from: https://kornia.readthedocs.io/en/latest/_modules/kornia/color/rgb.html

Parameters:

  • image \u2013
              Input image in RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.\n
  • threshold \u2013
              Threshold used in calculations.\n

Returns:

  • image_linear ( tensor ) \u2013

    Output image in linear RGB color space [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def rgb_to_linear_rgb(image, threshold = 0.0031308):\n    \"\"\"\n    Definition to convert RGB images to linear RGB color space. Mostly inspired from: https://kornia.readthedocs.io/en/latest/_modules/kornia/color/rgb.html\n\n    Parameters\n    ----------\n    image           : torch.tensor\n                      Input image in RGB color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.\n    threshold       : float\n                      Threshold used in calculations.\n\n    Returns\n    -------\n    image_linear    : torch.tensor\n                      Output image in linear RGB color space [k x 3 x m x n] or [1 x 3 x m x n].\n    \"\"\"\n    if len(image.shape) == 3:\n        image = image.unsqueeze(0)\n    image_linear = torch.where(image > 0.04045, torch.pow(((image + 0.055) / 1.055), 2.4), image / 12.92)\n    return image_linear\n
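A sketch (assumed usage) pairing rgb_to_linear_rgb with linear_rgb_to_rgb documented above; a single image comes back with a leading batch dimension.

import torch
from odak.learn.perception.color_conversion import rgb_to_linear_rgb, linear_rgb_to_rgb

srgb = torch.rand(3, 64, 64)                # sRGB-encoded image, values in [0, 1]
linear = rgb_to_linear_rgb(srgb)            # [1 x 3 x 64 x 64]
srgb_recovered = linear_rgb_to_rgb(linear)  # [1 x 3 x 64 x 64]
print((srgb - srgb_recovered.squeeze(0)).abs().max())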
"},{"location":"odak/learn_perception/#odak.learn.perception.color_conversion.srgb_to_lab","title":"srgb_to_lab(image)","text":"

Definition to convert SRGB space to LAB color space.

Parameters:

  • image \u2013
              Input image in SRGB color space[3 x m x n]\n

Returns:

  • image_lab ( tensor ) \u2013

    Output image in LAB color space [3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def srgb_to_lab(image):    \n    \"\"\"\n    Definition to convert SRGB space to LAB color space. \n\n    Parameters\n    ----------\n    image           : torch.tensor\n                      Input image in SRGB color space[3 x m x n]\n    Returns\n    -------\n    image_lab       : torch.tensor\n                      Output image in LAB color space [3 x m x n].\n    \"\"\"\n    if image.shape[-1] == 3:\n        input_color = image.permute(2, 0, 1)  # C(H*W)\n    else:\n        input_color = image\n    # rgb ---> linear rgb\n    limit = 0.04045        \n    # linear rgb ---> xyz\n    linrgb_color = torch.where(input_color > limit, torch.pow((input_color + 0.055) / 1.055, 2.4), input_color / 12.92)\n\n    a11 = 10135552 / 24577794\n    a12 = 8788810  / 24577794\n    a13 = 4435075  / 24577794\n    a21 = 2613072  / 12288897\n    a22 = 8788810  / 12288897\n    a23 = 887015   / 12288897\n    a31 = 1425312  / 73733382\n    a32 = 8788810  / 73733382\n    a33 = 70074185 / 73733382\n\n    A = torch.tensor([[a11, a12, a13],\n                    [a21, a22, a23],\n                    [a31, a32, a33]], dtype=torch.float32)\n\n    linrgb_color = linrgb_color.permute(2, 0, 1) # C(H*W)\n    xyz_color = torch.matmul(A, linrgb_color)\n    xyz_color = xyz_color.permute(1, 2, 0)\n    # xyz ---> lab\n    inv_reference_illuminant = torch.tensor([[[1.052156925]], [[1.000000000]], [[0.918357670]]], dtype=torch.float32)\n    input_color = xyz_color * inv_reference_illuminant\n    delta = 6 / 29\n    delta_square = delta * delta\n    delta_cube = delta * delta_square\n    factor = 1 / (3 * delta_square)\n\n    input_color = torch.where(input_color > delta_cube, torch.pow(input_color, 1 / 3), (factor * input_color + 4 / 29))\n\n    l = 116 * input_color[1:2, :, :] - 16\n    a = 500 * (input_color[0:1,:, :] - input_color[1:2, :, :])\n    b = 200 * (input_color[1:2, :, :] - input_color[2:3, :, :])\n\n    image_lab = torch.cat((l, a, b), 0)\n    return image_lab    \n
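A round-trip sketch (assumed usage) pairing srgb_to_lab with lab_to_srgb documented above; both operate on [3 x m x n] images.

import torch
from odak.learn.perception.color_conversion import srgb_to_lab, lab_to_srgb

srgb = torch.rand(3, 64, 64)       # sRGB image [3 x m x n], values in [0, 1]
lab = srgb_to_lab(srgb)            # CIE LAB image [3 x m x n]
srgb_recovered = lab_to_srgb(lab)  # [3 x m x n]
print((srgb - srgb_recovered).abs().max())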
"},{"location":"odak/learn_perception/#odak.learn.perception.color_conversion.xyz_to_linear_rgb","title":"xyz_to_linear_rgb(image)","text":"

Definition to convert CIE XYZ space to linear RGB color space. Mostly inspired from : Rochester IT Color Conversion Algorithms (https://www.cs.rit.edu/~ncs/color/)

Parameters:

  • image \u2013
               Input image in XYZ (CIE 1931) color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.\n

Returns:

  • image_linear_rgb ( tensor ) \u2013

    Output image in linear RGB color space [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def xyz_to_linear_rgb(image):\n    \"\"\"\n    Definition to convert CIE XYZ space to linear RGB color space. Mostly inspired from : Rochester IT Color Conversion Algorithms (https://www.cs.rit.edu/~ncs/color/)\n\n    Parameters\n    ----------\n    image            : torch.tensor\n                       Input image in XYZ (CIE 1931) color space [k x 3 x m x n] or [3 x m x n]. Image(s) must be normalized between zero and one.\n\n    Returns\n    -------\n    image_linear_rgb : torch.tensor\n                       Output image in linear RGB  color space [k x 3 x m x n] or [1 x 3 x m x n].\n    \"\"\"\n    if len(image.shape) == 3:\n        image = image.unsqueeze(0)\n    a11 = 3.240479\n    a12 = -1.537150\n    a13 = -0.498535\n    a21 = -0.969256 \n    a22 = 1.875992 \n    a23 = 0.041556\n    a31 = 0.055648\n    a32 = -0.204043\n    a33 = 1.057311\n    M = torch.tensor([[a11, a12, a13], \n                      [a21, a22, a23],\n                      [a31, a32, a33]])\n    size = image.size()\n    image = image.reshape(size[0], size[1], size[2]*size[3])\n    image_linear_rgb = torch.matmul(M, image)\n    image_linear_rgb = image_linear_rgb.reshape(size[0], size[1], size[2], size[3])\n    return image_linear_rgb\n
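A round-trip sketch (assumed usage) pairing linear_rgb_to_xyz with xyz_to_linear_rgb documented above.

import torch
from odak.learn.perception.color_conversion import linear_rgb_to_xyz, xyz_to_linear_rgb

linear_rgb = torch.rand(3, 64, 64)             # linear RGB image, values in [0, 1]
xyz = linear_rgb_to_xyz(linear_rgb)            # [1 x 3 x 64 x 64]
linear_rgb_recovered = xyz_to_linear_rgb(xyz)  # [1 x 3 x 64 x 64]
print((linear_rgb - linear_rgb_recovered.squeeze(0)).abs().max())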
"},{"location":"odak/learn_perception/#odak.learn.perception.color_conversion.ycrcb_2_rgb","title":"ycrcb_2_rgb(image)","text":"

Converts an image from YCrCb colourspace to RGB colourspace.

Parameters:

  • image \u2013
      Input image. Should be a YCrCb floating-point image with values in the range [0, 1]. Should be in NCHW format [3 x m x n] or [k x 3 x m x n].\n

Returns:

  • rgb ( tensor ) \u2013

    Image converted to RGB colourspace [k x 3 x m x n] or [1 x 3 x m x n].

Source code in odak/learn/perception/color_conversion.py
def ycrcb_2_rgb(image):\n    \"\"\"\n    Converts an image from YCrCb colourspace to RGB colourspace.\n\n    Parameters\n    ----------\n    image   : torch.tensor\n              Input image. Should be a YCrCb floating-point image with values in the range [0, 1]. Should be in NCHW format [3 x m x n] or [k x 3 x m x n].\n\n    Returns\n    -------\n    rgb     : torch.tensor\n              Image converted to RGB colourspace [k x 3 m x n] or [1 x 3 x m x n].\n    \"\"\"\n    if len(image.shape) == 3:\n       image = image.unsqueeze(0)\n    rgb = torch.zeros(image.size(), device=image.device)\n    rgb[:, 0, :, :] = image[:, 0, :, :] + 1.403 * (image[:, 1, :, :] - 0.5)\n    rgb[:, 1, :, :] = image[:, 0, :, :] - 0.714 * \\\n        (image[:, 1, :, :] - 0.5) - 0.344 * (image[:, 2, :, :] - 0.5)\n    rgb[:, 2, :, :] = image[:, 0, :, :] + 1.773 * (image[:, 2, :, :] - 0.5)\n    return rgb\n
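A round-trip sketch (assumed usage) pairing rgb_2_ycrcb with ycrcb_2_rgb documented above.

import torch
from odak.learn.perception.color_conversion import rgb_2_ycrcb, ycrcb_2_rgb

rgb = torch.rand(3, 64, 64)         # RGB image, values in [0, 1]
ycrcb = rgb_2_ycrcb(rgb)            # [1 x 3 x 64 x 64]
rgb_recovered = ycrcb_2_rgb(ycrcb)  # [1 x 3 x 64 x 64]
print((rgb - rgb_recovered.squeeze(0)).abs().max())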
"},{"location":"odak/learn_perception/#odak.learn.perception.foveation.make_3d_location_map","title":"make_3d_location_map(image_pixel_size, real_image_width=0.3, real_viewing_distance=0.6)","text":"

Makes a map of the real 3D location that each pixel in an image corresponds to, when displayed to a user on a flat screen. Assumes the viewpoint is located at the centre of the image, and the screen is perpendicular to the viewing direction.

Parameters:

  • image_pixel_size \u2013
                        The size of the image in pixels, as a tuple of form (height, width)\n
  • real_image_width \u2013
                        The real width of the image as displayed. Units not important, as long as they\n                    are the same as those used for real_viewing_distance\n
  • real_viewing_distance \u2013
                        The real distance from the user's viewpoint to the screen.\n

Returns:

  • map ( tensor ) \u2013

    The computed 3D location map, of size 3xWxH.

Source code in odak/learn/perception/foveation.py
def make_3d_location_map(image_pixel_size, real_image_width=0.3, real_viewing_distance=0.6):\n    \"\"\" \n    Makes a map of the real 3D location that each pixel in an image corresponds to, when displayed to\n    a user on a flat screen. Assumes the viewpoint is located at the centre of the image, and the screen is \n    perpendicular to the viewing direction.\n\n    Parameters\n    ----------\n\n    image_pixel_size        : tuple of ints \n                                The size of the image in pixels, as a tuple of form (height, width)\n    real_image_width        : float\n                                The real width of the image as displayed. Units not important, as long as they\n                                are the same as those used for real_viewing_distance\n    real_viewing_distance   : float \n                                The real distance from the user's viewpoint to the screen.\n\n    Returns\n    -------\n\n    map                     : torch.tensor\n                                The computed 3D location map, of size 3xWxH.\n    \"\"\"\n    real_image_height = (real_image_width /\n                         image_pixel_size[-1]) * image_pixel_size[-2]\n    x_coords = torch.linspace(-0.5, 0.5, image_pixel_size[-1])*real_image_width\n    x_coords = x_coords[None, None, :].repeat(1, image_pixel_size[-2], 1)\n    y_coords = torch.linspace(-0.5, 0.5,\n                              image_pixel_size[-2])*real_image_height\n    y_coords = y_coords[None, :, None].repeat(1, 1, image_pixel_size[-1])\n    z_coords = torch.ones(\n        (1, image_pixel_size[-2], image_pixel_size[-1])) * real_viewing_distance\n\n    return torch.cat([x_coords, y_coords, z_coords], dim=0)\n
"},{"location":"odak/learn_perception/#odak.learn.perception.foveation.make_eccentricity_distance_maps","title":"make_eccentricity_distance_maps(gaze_location, image_pixel_size, real_image_width=0.3, real_viewing_distance=0.6)","text":"

Makes a map of the eccentricity of each pixel in an image for a given fixation point, when displayed to a user on a flat screen. Assumes the viewpoint is located at the centre of the image, and the screen is perpendicular to the viewing direction. Output in radians.

Parameters:

  • gaze_location \u2013
                        User's gaze (fixation point) in the image. Should be given as a tuple with normalized\n                    image coordinates (ranging from 0 to 1)\n
  • image_pixel_size \u2013
                        The size of the image in pixels, as a tuple of form (height, width)\n
  • real_image_width \u2013
                        The real width of the image as displayed. Units not important, as long as they\n                    are the same as those used for real_viewing_distance\n
  • real_viewing_distance \u2013
                        The real distance from the user's viewpoint to the screen.\n

Returns:

  • eccentricity_map ( tensor ) \u2013

    The computed eccentricity map, of size WxH.

  • distance_map ( tensor ) \u2013

    The computed distance map, of size WxH.

Source code in odak/learn/perception/foveation.py
def make_eccentricity_distance_maps(gaze_location, image_pixel_size, real_image_width=0.3, real_viewing_distance=0.6):\n    \"\"\" \n    Makes a map of the eccentricity of each pixel in an image for a given fixation point, when displayed to\n    a user on a flat screen. Assumes the viewpoint is located at the centre of the image, and the screen is \n    perpendicular to the viewing direction. Output in radians.\n\n    Parameters\n    ----------\n\n    gaze_location           : tuple of floats\n                                User's gaze (fixation point) in the image. Should be given as a tuple with normalized\n                                image coordinates (ranging from 0 to 1)\n    image_pixel_size        : tuple of ints\n                                The size of the image in pixels, as a tuple of form (height, width)\n    real_image_width        : float\n                                The real width of the image as displayed. Units not important, as long as they\n                                are the same as those used for real_viewing_distance\n    real_viewing_distance   : float\n                                The real distance from the user's viewpoint to the screen.\n\n    Returns\n    -------\n\n    eccentricity_map        : torch.tensor\n                                The computed eccentricity map, of size WxH.\n    distance_map            : torch.tensor\n                                The computed distance map, of size WxH.\n    \"\"\"\n    real_image_height = (real_image_width /\n                         image_pixel_size[-1]) * image_pixel_size[-2]\n    location_map = make_3d_location_map(\n        image_pixel_size, real_image_width, real_viewing_distance)\n    distance_map = torch.sqrt(torch.sum(location_map*location_map, dim=0))\n    direction_map = location_map / distance_map\n\n    gaze_location_3d = torch.tensor([\n        (gaze_location[0]*2 - 1)*real_image_width*0.5,\n        (gaze_location[1]*2 - 1)*real_image_height*0.5,\n        real_viewing_distance])\n    gaze_dir = gaze_location_3d / \\\n        torch.sqrt(torch.sum(gaze_location_3d * gaze_location_3d))\n    gaze_dir = gaze_dir[:, None, None]\n\n    dot_prod_map = torch.sum(gaze_dir * direction_map, dim=0)\n    dot_prod_map = torch.clamp(dot_prod_map, min=-1.0, max=1.0)\n    eccentricity_map = torch.acos(dot_prod_map)\n\n    return eccentricity_map, distance_map\n
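A usage sketch (assumed parameter values) computing eccentricity and distance maps for a central gaze on a Full HD image viewed from 0.6 m, assuming the import path odak.learn.perception.foveation indicated above.

from odak.learn.perception.foveation import make_eccentricity_distance_maps

eccentricity_map, distance_map = make_eccentricity_distance_maps(
                                                                 gaze_location = [0.5, 0.5],       # gaze at the image centre
                                                                 image_pixel_size = (1080, 1920),  # (height, width)
                                                                 real_image_width = 0.3,
                                                                 real_viewing_distance = 0.6
                                                                )
print(eccentricity_map.shape, eccentricity_map.max())  # eccentricity is in radians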
"},{"location":"odak/learn_perception/#odak.learn.perception.foveation.make_equi_pooling_size_map_lod","title":"make_equi_pooling_size_map_lod(gaze_angles, image_pixel_size, alpha=0.3, mode='quadratic')","text":"

This function is similar to make_equi_pooling_size_map_pixels, but instead returns a map of LOD levels to sample from to achieve the correct pooling region areas.

Parameters:

  • gaze_angles \u2013
                    Gaze direction expressed as angles, in radians.\n
  • image_pixel_size \u2013
                    Dimensions of the image in pixels, as a tuple of (height, width)\n
  • alpha \u2013
                    Parameter controlling extent of foveation\n
  • mode \u2013
                    Foveation mode (how pooling size varies with eccentricity). Should be \"quadratic\" or \"linear\"\n

Returns:

  • pooling_size_map ( tensor ) \u2013

    The computed pooling size map, of size HxW.

Source code in odak/learn/perception/foveation.py
def make_equi_pooling_size_map_lod(gaze_angles, image_pixel_size, alpha=0.3, mode=\"quadratic\"):\n    \"\"\" \n    This function is similar to make_equi_pooling_size_map_pixels, but instead returns a map of LOD levels to sample from\n    to achieve the correct pooling region areas.\n\n    Parameters\n    ----------\n\n    gaze_angles         : tuple of 2 floats\n                            Gaze direction expressed as angles, in radians.\n    image_pixel_size    : tuple of 2 ints\n                            Dimensions of the image in pixels, as a tuple of (height, width)\n    alpha               : float\n                            Parameter controlling extent of foveation\n    mode                : str\n                            Foveation mode (how pooling size varies with eccentricity). Should be \"quadratic\" or \"linear\"\n\n    Returns\n    -------\n\n    pooling_size_map        : torch.tensor\n                                The computed pooling size map, of size HxW.\n    \"\"\"\n    pooling_pixel = make_equi_pooling_size_map_pixels(gaze_angles, image_pixel_size, alpha, mode)\n    import matplotlib.pyplot as plt\n    pooling_lod = torch.log2(1e-6+pooling_pixel)\n    pooling_lod[pooling_lod < 0] = 0\n    return pooling_lod\n
"},{"location":"odak/learn_perception/#odak.learn.perception.foveation.make_equi_pooling_size_map_pixels","title":"make_equi_pooling_size_map_pixels(gaze_angles, image_pixel_size, alpha=0.3, mode='quadratic')","text":"

This function makes a map of pooling sizes in pixels, similarly to make_pooling_size_map_pixels, but works on 360 equirectangular images. Input images are assumed to be in equirectangular form - i.e. if you consider a 3D viewing setup where y is the vertical axis, the x location in the image corresponds to rotation around the y axis (yaw), ranging from -pi to pi. The y location in the image corresponds to pitch, ranging from -pi/2 to pi/2.

In this setup real_image_width and real_viewing_distance have no effect.

Note that rather than a 2D image gaze location in [0,1]^2, the gaze should be specified as gaze angles in [-pi,pi]x[-pi/2,pi/2] (yaw, then pitch).

Parameters:

  • gaze_angles \u2013
                    Gaze direction expressed as angles, in radians.\n
  • image_pixel_size \u2013
                    Dimensions of the image in pixels, as a tuple of (height, width)\n
  • alpha \u2013
                    Parameter controlling extent of foveation\n
  • mode \u2013
                    Foveation mode (how pooling size varies with eccentricity). Should be \"quadratic\" or \"linear\"\n
Source code in odak/learn/perception/foveation.py
def make_equi_pooling_size_map_pixels(gaze_angles, image_pixel_size, alpha=0.3, mode=\"quadratic\"):\n    \"\"\"\n    This function makes a map of pooling sizes in pixels, similarly to make_pooling_size_map_pixels, but works on 360 equirectangular images.\n    Input images are assumed to be in equirectangular form - i.e. if you consider a 3D viewing setup where y is the vertical axis, \n    the x location in the image corresponds to rotation around the y axis (yaw), ranging from -pi to pi. The y location in the image\n    corresponds to pitch, ranging from -pi/2 to pi/2.\n\n    In this setup real_image_width and real_viewing_distance have no effect.\n\n    Note that rather than a 2D image gaze location in [0,1]^2, the gaze should be specified as gaze angles in [-pi,pi]x[-pi/2,pi/2] (yaw, then pitch).\n\n    Parameters\n    ----------\n\n    gaze_angles         : tuple of 2 floats\n                            Gaze direction expressed as angles, in radians.\n    image_pixel_size    : tuple of 2 ints\n                            Dimensions of the image in pixels, as a tuple of (height, width)\n    alpha               : float\n                            Parameter controlling extent of foveation\n    mode                : str\n                            Foveation mode (how pooling size varies with eccentricity). Should be \"quadratic\" or \"linear\"\n    \"\"\"\n    view_direction = torch.tensor([math.sin(gaze_angles[0])*math.cos(gaze_angles[1]), math.sin(gaze_angles[1]), math.cos(gaze_angles[0])*math.cos(gaze_angles[1])])\n\n    yaw_angle_map = torch.linspace(-torch.pi, torch.pi, image_pixel_size[1])\n    yaw_angle_map = yaw_angle_map[None,:].repeat(image_pixel_size[0], 1)[None,...]\n    pitch_angle_map = torch.linspace(-torch.pi*0.5, torch.pi*0.5, image_pixel_size[0])\n    pitch_angle_map = pitch_angle_map[:,None].repeat(1, image_pixel_size[1])[None,...]\n\n    dir_map = torch.cat([torch.sin(yaw_angle_map)*torch.cos(pitch_angle_map), torch.sin(pitch_angle_map), torch.cos(yaw_angle_map)*torch.cos(pitch_angle_map)])\n\n    # Work out the pooling region diameter in radians\n    view_dot_dir = torch.sum(view_direction[:,None,None] * dir_map, dim=0)\n    eccentricity = torch.acos(view_dot_dir)\n    pooling_rad = alpha * eccentricity\n    if mode == \"quadratic\":\n        pooling_rad *= eccentricity\n\n    # The actual pooling region will be an ellipse in the equirectangular image - the length of the major & minor axes\n    # depend on the x & y resolution of the image. We find these two axis lengths (in pixels) and then the area of the ellipse\n    pixels_per_rad_x = image_pixel_size[1] / (2*torch.pi)\n    pixels_per_rad_y = image_pixel_size[0] / (torch.pi)\n    pooling_axis_x = pooling_rad * pixels_per_rad_x\n    pooling_axis_y = pooling_rad * pixels_per_rad_y\n    area = torch.pi * pooling_axis_x * pooling_axis_y * 0.25\n\n    # Now finally find the length of the side of a square of the same area.\n    size = torch.sqrt(torch.abs(area))\n    return size\n
"},{"location":"odak/learn_perception/#odak.learn.perception.foveation.make_pooling_size_map_lod","title":"make_pooling_size_map_lod(gaze_location, image_pixel_size, alpha=0.3, real_image_width=0.3, real_viewing_distance=0.6, mode='quadratic')","text":"

This function is similar to make_pooling_size_map_pixels, but instead returns a map of LOD levels to sample from to achieve the correct pooling region areas.

Parameters:

  • gaze_location \u2013
                        User's gaze (fixation point) in the image. Should be given as a tuple with normalized\n                    image coordinates (ranging from 0 to 1)\n
  • image_pixel_size \u2013
                        The size of the image in pixels, as a tuple of form (height, width)\n
  • real_image_width \u2013
                        The real width of the image as displayed. Units not important, as long as they\n                    are the same as those used for real_viewing_distance\n
  • real_viewing_distance \u2013
                        The real distance from the user's viewpoint to the screen.\n
  • alpha \u2013
                        Parameter controlling the extent of foveation.\n
  • mode \u2013
                        Foveation mode (how pooling size varies with eccentricity). Should be \"quadratic\" or \"linear\".\n

Returns:

  • pooling_size_map ( tensor ) \u2013

    The computed pooling size map, of size WxH.

Source code in odak/learn/perception/foveation.py
def make_pooling_size_map_lod(gaze_location, image_pixel_size, alpha=0.3, real_image_width=0.3, real_viewing_distance=0.6, mode=\"quadratic\"):\n    \"\"\" \n    This function is similar to make_pooling_size_map_pixels, but instead returns a map of LOD levels to sample from\n    to achieve the correct pooling region areas.\n\n    Parameters\n    ----------\n\n    gaze_location           : tuple of floats\n                                User's gaze (fixation point) in the image. Should be given as a tuple with normalized\n                                image coordinates (ranging from 0 to 1)\n    image_pixel_size        : tuple of ints\n                                The size of the image in pixels, as a tuple of form (height, width)\n    real_image_width        : float\n                                The real width of the image as displayed. Units not important, as long as they\n                                are the same as those used for real_viewing_distance\n    real_viewing_distance   : float\n                                The real distance from the user's viewpoint to the screen.\n\n    Returns\n    -------\n\n    pooling_size_map        : torch.tensor\n                                The computed pooling size map, of size WxH.\n    \"\"\"\n    pooling_pixel = make_pooling_size_map_pixels(\n        gaze_location, image_pixel_size, alpha, real_image_width, real_viewing_distance, mode)\n    pooling_lod = torch.log2(1e-6+pooling_pixel)\n    pooling_lod[pooling_lod < 0] = 0\n    return pooling_lod\n
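A minimal usage sketch (assuming the import path odak.learn.perception.foveation; the display geometry values are illustrative):

import torch
from odak.learn.perception.foveation import make_pooling_size_map_lod

# Per-pixel LOD (mip level) to sample for a 1080p image viewed from 0.6 m,
# fixating at the image centre. Negative LODs are clamped to 0 inside the function.
lod_map = make_pooling_size_map_lod(
    gaze_location = (0.5, 0.5),
    image_pixel_size = (1080, 1920),
    alpha = 0.3,
    real_image_width = 0.3,
    real_viewing_distance = 0.6,
    mode = 'quadratic')
print(lod_map.min().item(), lod_map.max().item())  # 0 at the fovea, larger towards the periphery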
"},{"location":"odak/learn_perception/#odak.learn.perception.foveation.make_pooling_size_map_pixels","title":"make_pooling_size_map_pixels(gaze_location, image_pixel_size, alpha=0.3, real_image_width=0.3, real_viewing_distance=0.6, mode='quadratic')","text":"

Makes a map of the pooling size associated with each pixel in an image for a given fixation point, when displayed to a user on a flat screen. Follows the idea that pooling size (in radians) should be directly proportional to eccentricity (also in radians).

Assumes the viewpoint is located at the centre of the image, and the screen is perpendicular to the viewing direction. Output is the width of the pooling region in pixels.

Parameters:

  • gaze_location \u2013
                        User's gaze (fixation point) in the image. Should be given as a tuple with normalized\n                    image coordinates (ranging from 0 to 1)\n
  • image_pixel_size \u2013
                        The size of the image in pixels, as a tuple of form (height, width)\n
  • real_image_width \u2013
                        The real width of the image as displayed. Units not important, as long as they\n                    are the same as those used for real_viewing_distance\n
  • real_viewing_distance \u2013
                        The real distance from the user's viewpoint to the screen.\n
  • alpha \u2013
                        Parameter controlling the extent of foveation.\n
  • mode \u2013
                        Foveation mode (how pooling size varies with eccentricity). Should be \"quadratic\" or \"linear\".\n

Returns:

  • pooling_size_map ( tensor ) \u2013

    The computed pooling size map, of size WxH.

Source code in odak/learn/perception/foveation.py
def make_pooling_size_map_pixels(gaze_location, image_pixel_size, alpha=0.3, real_image_width=0.3, real_viewing_distance=0.6, mode=\"quadratic\"):\n    \"\"\" \n    Makes a map of the pooling size associated with each pixel in an image for a given fixation point, when displayed to\n    a user on a flat screen. Follows the idea that pooling size (in radians) should be directly proportional to eccentricity\n    (also in radians). \n\n    Assumes the viewpoint is located at the centre of the image, and the screen is \n    perpendicular to the viewing direction. Output is the width of the pooling region in pixels.\n\n    Parameters\n    ----------\n\n    gaze_location           : tuple of floats\n                                User's gaze (fixation point) in the image. Should be given as a tuple with normalized\n                                image coordinates (ranging from 0 to 1)\n    image_pixel_size        : tuple of ints\n                                The size of the image in pixels, as a tuple of form (height, width)\n    real_image_width        : float\n                                The real width of the image as displayed. Units not important, as long as they\n                                are the same as those used for real_viewing_distance\n    real_viewing_distance   : float\n                                The real distance from the user's viewpoint to the screen.\n\n    Returns\n    -------\n\n    pooling_size_map        : torch.tensor\n                                The computed pooling size map, of size WxH.\n    \"\"\"\n    eccentricity, distance_to_pixel = make_eccentricity_distance_maps(\n        gaze_location, image_pixel_size, real_image_width, real_viewing_distance)\n    eccentricity_centre, _ = make_eccentricity_distance_maps(\n        [0.5, 0.5], image_pixel_size, real_image_width, real_viewing_distance)\n    pooling_rad = alpha * eccentricity\n    if mode == \"quadratic\":\n        pooling_rad *= eccentricity\n    angle_min = eccentricity_centre - pooling_rad*0.5\n    angle_max = eccentricity_centre + pooling_rad*0.5\n    major_axis = (torch.tan(angle_max) - torch.tan(angle_min)) * \\\n        real_viewing_distance\n    minor_axis = 2 * distance_to_pixel * torch.tan(pooling_rad*0.5)\n    area = math.pi * major_axis * minor_axis * 0.25\n    # Should be +ve anyway, but check to ensure we don't take sqrt of negative number\n    area = torch.abs(area)\n    pooling_real = torch.sqrt(area)\n    pooling_pixel = (pooling_real / real_image_width) * image_pixel_size[1]\n    return pooling_pixel\n
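A minimal usage sketch (assuming the import path odak.learn.perception.foveation; the screen geometry is illustrative):

import torch
from odak.learn.perception.foveation import make_pooling_size_map_pixels

# Width of the pooling region, in pixels, for every pixel of a 1080p image
# displayed 0.3 m wide and viewed from 0.6 m, fixating at the image centre.
pooling_map = make_pooling_size_map_pixels(
    gaze_location = (0.5, 0.5),
    image_pixel_size = (1080, 1920),
    alpha = 0.3,
    real_image_width = 0.3,
    real_viewing_distance = 0.6,
    mode = 'quadratic')
print(pooling_map.min().item())  # close to zero near the fixation point, growing with eccentricity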
"},{"location":"odak/learn_perception/#odak.learn.perception.foveation.make_radial_map","title":"make_radial_map(size, gaze)","text":"

Makes a simple radial map where each pixel contains distance in pixels from the chosen gaze location.

Parameters:

  • size \u2013
        Dimensions of the image\n
  • gaze \u2013
        User's gaze (fixation point) in the image. Should be given as a tuple with normalized\n    image coordinates (ranging from 0 to 1)\n
Source code in odak/learn/perception/foveation.py
def make_radial_map(size, gaze):\n    \"\"\" \n    Makes a simple radial map where each pixel contains distance in pixels from the chosen gaze location.\n\n    Parameters\n    ----------\n\n    size    : tuple of ints\n                Dimensions of the image\n    gaze    : tuple of floats\n                User's gaze (fixation point) in the image. Should be given as a tuple with normalized\n                image coordinates (ranging from 0 to 1)\n    \"\"\"\n    pix_gaze = [gaze[0]*size[0], gaze[1]*size[1]]\n    rows = torch.linspace(0, size[0], size[0])\n    rows = rows[:, None].repeat(1, size[1])\n    cols = torch.linspace(0, size[1], size[1])\n    cols = cols[None, :].repeat(size[0], 1)\n    dist_sq = torch.pow(rows - pix_gaze[0], 2) + \\\n        torch.pow(cols - pix_gaze[1], 2)\n    radii = torch.sqrt(dist_sq)\n    return radii/torch.max(radii)\n
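A minimal usage sketch (assuming the import path odak.learn.perception.foveation):

import torch
from odak.learn.perception.foveation import make_radial_map

# Normalised distance from the gaze point: 0 at the fixation, 1 at the furthest pixel.
radial = make_radial_map(size = (256, 256), gaze = (0.5, 0.5))
print(radial.shape, radial.max().item())  # torch.Size([256, 256]) 1.0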
"},{"location":"odak/learn_perception/#odak.learn.perception.image_quality_losses.MSSSIM","title":"MSSSIM","text":"

Bases: Module

A class to calculate multi-scale structural similarity index of an image with respect to a ground truth image.

Source code in odak/learn/perception/image_quality_losses.py
class MSSSIM(nn.Module):\n    '''\n    A class to calculate multi-scale structural similarity index of an image with respect to a ground truth image.\n    '''\n\n    def __init__(self):\n        super(MSSSIM, self).__init__()\n\n    def forward(self, predictions, targets):\n        \"\"\"\n        Parameters\n        ----------\n        predictions : torch.tensor\n                      The predicted images.\n        targets     : torch.tensor\n                      The ground truth images.\n\n        Returns\n        -------\n        result      : torch.tensor \n                      The computed MS-SSIM value if successful, otherwise 0.0.\n        \"\"\"\n        try:\n            from torchmetrics.functional.image import multiscale_structural_similarity_index_measure\n            if len(predictions.shape) == 3:\n                predictions = predictions.unsqueeze(0)\n                targets = targets.unsqueeze(0)\n            l_MSSSIM = multiscale_structural_similarity_index_measure(predictions, targets, data_range = 1.0)\n            return l_MSSSIM  \n        except Exception as e:\n            logging.warning('MS-SSIM failed to compute.')\n            logging.warning(e)\n            return torch.tensor(0.0)\n
"},{"location":"odak/learn_perception/#odak.learn.perception.image_quality_losses.MSSSIM.forward","title":"forward(predictions, targets)","text":"

Parameters:

  • predictions (tensor) \u2013
          The predicted images.\n
  • targets \u2013
          The ground truth images.\n

Returns:

  • result ( tensor ) \u2013

    The computed MS-SSIM value if successful, otherwise 0.0.

Source code in odak/learn/perception/image_quality_losses.py
def forward(self, predictions, targets):\n    \"\"\"\n    Parameters\n    ----------\n    predictions : torch.tensor\n                  The predicted images.\n    targets     : torch.tensor\n                  The ground truth images.\n\n    Returns\n    -------\n    result      : torch.tensor \n                  The computed MS-SSIM value if successful, otherwise 0.0.\n    \"\"\"\n    try:\n        from torchmetrics.functional.image import multiscale_structural_similarity_index_measure\n        if len(predictions.shape) == 3:\n            predictions = predictions.unsqueeze(0)\n            targets = targets.unsqueeze(0)\n        l_MSSSIM = multiscale_structural_similarity_index_measure(predictions, targets, data_range = 1.0)\n        return l_MSSSIM  \n    except Exception as e:\n        logging.warning('MS-SSIM failed to compute.')\n        logging.warning(e)\n        return torch.tensor(0.0)\n
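A minimal usage sketch for MSSSIM (assuming the import path odak.learn.perception.image_quality_losses and that torchmetrics is installed; the tensors are illustrative):

import torch
from odak.learn.perception.image_quality_losses import MSSSIM

loss_fn = MSSSIM()
prediction = torch.rand(1, 3, 256, 256)                                     # NCHW, values in [0, 1]
target = (prediction + 0.05 * torch.randn_like(prediction)).clamp(0., 1.)   # slightly perturbed ground truth
ms_ssim = loss_fn(prediction, target)
print(float(ms_ssim))  # close to 1.0 for similar images; 0.0 is returned if the metric fails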
"},{"location":"odak/learn_perception/#odak.learn.perception.image_quality_losses.PSNR","title":"PSNR","text":"

Bases: Module

A class to calculate peak-signal-to-noise ratio of an image with respect to a ground truth image.

Source code in odak/learn/perception/image_quality_losses.py
class PSNR(nn.Module):\n    '''\n    A class to calculate peak-signal-to-noise ratio of an image with respect to a ground truth image.\n    '''\n\n    def __init__(self):\n        super(PSNR, self).__init__()\n\n    def forward(self, predictions, targets, peak_value = 1.0):\n        \"\"\"\n        A function to calculate peak-signal-to-noise ratio of an image with respect to a ground truth image.\n\n        Parameters\n        ----------\n        predictions   : torch.tensor\n                        Image to be tested.\n        targets       : torch.tensor\n                        Ground truth image.\n        peak_value    : float\n                        Peak value that given tensors could have.\n\n        Returns\n        -------\n        result        : torch.tensor\n                        Peak-signal-to-noise ratio.\n        \"\"\"\n        mse = torch.mean((targets - predictions) ** 2)\n        result = 20 * torch.log10(peak_value / torch.sqrt(mse))\n        return result\n
"},{"location":"odak/learn_perception/#odak.learn.perception.image_quality_losses.PSNR.forward","title":"forward(predictions, targets, peak_value=1.0)","text":"

A function to calculate peak-signal-to-noise ratio of an image with respect to a ground truth image.

Parameters:

  • predictions \u2013
            Image to be tested.\n
  • targets \u2013
            Ground truth image.\n
  • peak_value \u2013
            Peak value that given tensors could have.\n

Returns:

  • result ( tensor ) \u2013

    Peak-signal-to-noise ratio.

Source code in odak/learn/perception/image_quality_losses.py
def forward(self, predictions, targets, peak_value = 1.0):\n    \"\"\"\n    A function to calculate peak-signal-to-noise ratio of an image with respect to a ground truth image.\n\n    Parameters\n    ----------\n    predictions   : torch.tensor\n                    Image to be tested.\n    targets       : torch.tensor\n                    Ground truth image.\n    peak_value    : float\n                    Peak value that given tensors could have.\n\n    Returns\n    -------\n    result        : torch.tensor\n                    Peak-signal-to-noise ratio.\n    \"\"\"\n    mse = torch.mean((targets - predictions) ** 2)\n    result = 20 * torch.log10(peak_value / torch.sqrt(mse))\n    return result\n
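A minimal usage sketch for PSNR, together with the closed form it evaluates, 20 * log10(peak_value / sqrt(MSE)) (import path assumed):

import torch
from odak.learn.perception.image_quality_losses import PSNR

loss_fn = PSNR()
target = torch.rand(1, 3, 128, 128)
prediction = (target + 0.01 * torch.randn_like(target)).clamp(0., 1.)
result = loss_fn(prediction, target, peak_value = 1.0)
# Cross-check against the definition used inside forward().
mse = torch.mean((target - prediction) ** 2)
print(float(result), float(20 * torch.log10(1.0 / torch.sqrt(mse))))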
"},{"location":"odak/learn_perception/#odak.learn.perception.image_quality_losses.SSIM","title":"SSIM","text":"

Bases: Module

A class to calculate structural similarity index of an image with respect to a ground truth image.

Source code in odak/learn/perception/image_quality_losses.py
class SSIM(nn.Module):\n    '''\n    A class to calculate structural similarity index of an image with respect to a ground truth image.\n    '''\n\n    def __init__(self):\n        super(SSIM, self).__init__()\n\n    def forward(self, predictions, targets):\n        \"\"\"\n        Parameters\n        ----------\n        predictions : torch.tensor\n                      The predicted images.\n        targets     : torch.tensor\n                      The ground truth images.\n\n        Returns\n        -------\n        result      : torch.tensor \n                      The computed SSIM value if successful, otherwise 0.0.\n        \"\"\"\n        try:\n            from torchmetrics.functional.image import structural_similarity_index_measure\n            if len(predictions.shape) == 3:\n                predictions = predictions.unsqueeze(0)\n                targets = targets.unsqueeze(0)\n            l_SSIM = structural_similarity_index_measure(predictions, targets)\n            return l_SSIM\n        except Exception as e:\n            logging.warning('SSIM failed to compute.')\n            logging.warning(e)\n            return torch.tensor(0.0)\n
"},{"location":"odak/learn_perception/#odak.learn.perception.image_quality_losses.SSIM.forward","title":"forward(predictions, targets)","text":"

Parameters:

  • predictions (tensor) \u2013
          The predicted images.\n
  • targets \u2013
          The ground truth images.\n

Returns:

  • result ( tensor ) \u2013

    The computed SSIM value if successful, otherwise 0.0.

Source code in odak/learn/perception/image_quality_losses.py
def forward(self, predictions, targets):\n    \"\"\"\n    Parameters\n    ----------\n    predictions : torch.tensor\n                  The predicted images.\n    targets     : torch.tensor\n                  The ground truth images.\n\n    Returns\n    -------\n    result      : torch.tensor \n                  The computed SSIM value if successful, otherwise 0.0.\n    \"\"\"\n    try:\n        from torchmetrics.functional.image import structural_similarity_index_measure\n        if len(predictions.shape) == 3:\n            predictions = predictions.unsqueeze(0)\n            targets = targets.unsqueeze(0)\n        l_SSIM = structural_similarity_index_measure(predictions, targets)\n        return l_SSIM\n    except Exception as e:\n        logging.warning('SSIM failed to compute.')\n        logging.warning(e)\n        return torch.tensor(0.0)\n
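A minimal usage sketch for SSIM (import path and torchmetrics availability assumed):

import torch
from odak.learn.perception.image_quality_losses import SSIM

loss_fn = SSIM()
prediction = torch.rand(1, 3, 128, 128)                                     # NCHW, values in [0, 1]
target = (prediction + 0.05 * torch.randn_like(prediction)).clamp(0., 1.)
ssim = loss_fn(prediction, target)
print(float(ssim))  # close to 1.0 for similar images; 0.0 is returned if the metric fails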
"},{"location":"odak/learn_perception/#odak.learn.perception.learned_perceptual_losses.CVVDP","title":"CVVDP","text":"

Bases: Module

Source code in odak/learn/perception/learned_perceptual_losses.py
class CVVDP(nn.Module):\n    def __init__(self, device = torch.device('cpu')):\n        \"\"\"\n        Initializes the CVVDP model with a specified device.\n\n        Parameters\n        ----------\n        device   : torch.device\n                    The device (CPU/GPU) on which the computations will be performed. Defaults to CPU.\n        \"\"\"\n        super(CVVDP, self).__init__()\n        try:\n            import pycvvdp\n            self.cvvdp = pycvvdp.cvvdp(display_name = 'standard_4k', device = device)\n        except Exception as e:\n            logging.warning('ColorVideoVDP is missing, consider installing by running \"pip install -U git+https://github.com/gfxdisp/ColorVideoVDP\"')\n            logging.warning(e)\n\n\n    def forward(self, predictions, targets, dim_order = 'CHW'):\n        \"\"\"\n        Parameters\n        ----------\n        predictions   : torch.tensor\n                        The predicted images.\n        targets       : torch.tensor\n                        The ground truth images.\n        dim_order     : str\n                        The dimension order of the input images. Defaults to 'CHW' (channels, height, width).\n\n        Returns\n        -------\n        result        : torch.tensor\n                        The computed loss if successful, otherwise 0.0.\n        \"\"\"\n        try:\n            l_ColorVideoVDP = self.cvvdp.loss(predictions, targets, dim_order = dim_order)\n            return l_ColorVideoVDP\n        except Exception as e:\n            logging.warning('ColorVideoVDP failed to compute.')\n            logging.warning(e)\n            return torch.tensor(0.0)\n
"},{"location":"odak/learn_perception/#odak.learn.perception.learned_perceptual_losses.CVVDP.__init__","title":"__init__(device=torch.device('cpu'))","text":"

Initializes the CVVDP model with a specified device.

Parameters:

  • device \u2013
        The device (CPU/GPU) on which the computations will be performed. Defaults to CPU.\n
Source code in odak/learn/perception/learned_perceptual_losses.py
def __init__(self, device = torch.device('cpu')):\n    \"\"\"\n    Initializes the CVVDP model with a specified device.\n\n    Parameters\n    ----------\n    device   : torch.device\n                The device (CPU/GPU) on which the computations will be performed. Defaults to CPU.\n    \"\"\"\n    super(CVVDP, self).__init__()\n    try:\n        import pycvvdp\n        self.cvvdp = pycvvdp.cvvdp(display_name = 'standard_4k', device = device)\n    except Exception as e:\n        logging.warning('ColorVideoVDP is missing, consider installing by running \"pip install -U git+https://github.com/gfxdisp/ColorVideoVDP\"')\n        logging.warning(e)\n
"},{"location":"odak/learn_perception/#odak.learn.perception.learned_perceptual_losses.CVVDP.forward","title":"forward(predictions, targets, dim_order='CHW')","text":"

Parameters:

  • predictions \u2013
            The predicted images.\n
  • targets \u2013
            The ground truth images.\n
  • dim_order \u2013
            The dimension order of the input images. Defaults to 'CHW' (channels, height, width).\n

Returns:

  • result ( tensor ) \u2013

    The computed loss if successful, otherwise 0.0.

Source code in odak/learn/perception/learned_perceptual_losses.py
def forward(self, predictions, targets, dim_order = 'CHW'):\n    \"\"\"\n    Parameters\n    ----------\n    predictions   : torch.tensor\n                    The predicted images.\n    targets       : torch.tensor\n                    The ground truth images.\n    dim_order     : str\n                    The dimension order of the input images. Defaults to 'CHW' (channels, height, width).\n\n    Returns\n    -------\n    result        : torch.tensor\n                    The computed loss if successful, otherwise 0.0.\n    \"\"\"\n    try:\n        l_ColorVideoVDP = self.cvvdp.loss(predictions, targets, dim_order = dim_order)\n        return l_ColorVideoVDP\n    except Exception as e:\n        logging.warning('ColorVideoVDP failed to compute.')\n        logging.warning(e)\n        return torch.tensor(0.0)\n
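A minimal usage sketch for CVVDP (assuming the import path odak.learn.perception.learned_perceptual_losses and that pycvvdp is installed; tensor shapes follow the default dim_order of 'CHW'):

import torch
from odak.learn.perception.learned_perceptual_losses import CVVDP

loss_fn = CVVDP(device = torch.device('cpu'))
prediction = torch.rand(3, 256, 256)   # CHW
target = torch.rand(3, 256, 256)
score = loss_fn(prediction, target, dim_order = 'CHW')
print(float(score))  # 0.0 is returned if pycvvdp is missing or the metric fails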
"},{"location":"odak/learn_perception/#odak.learn.perception.learned_perceptual_losses.FVVDP","title":"FVVDP","text":"

Bases: Module

Source code in odak/learn/perception/learned_perceptual_losses.py
class FVVDP(nn.Module):\n    def __init__(self, device = torch.device('cpu')):\n        \"\"\"\n        Initializes the FVVDP model with a specified device.\n\n        Parameters\n        ----------\n        device   : torch.device\n                    The device (CPU/GPU) on which the computations will be performed. Defaults to CPU.\n        \"\"\"\n        super(FVVDP, self).__init__()\n        try:\n            import pyfvvdp\n            self.fvvdp = pyfvvdp.fvvdp(display_name = 'standard_4k', heatmap = 'none', device = device)\n        except Exception as e:\n            logging.warning('FovVideoVDP is missing, consider installing by running \"pip install pyfvvdp\"')\n            logging.warning(e)\n\n\n    def forward(self, predictions, targets, dim_order = 'CHW'):\n        \"\"\"\n        Parameters\n        ----------\n        predictions   : torch.tensor\n                        The predicted images.\n        targets       : torch.tensor\n                        The ground truth images.\n        dim_order     : str\n                        The dimension order of the input images. Defaults to 'CHW' (channels, height, width).\n\n        Returns\n        -------\n        result        : torch.tensor\n                          The computed loss if successful, otherwise 0.0.\n        \"\"\"\n        try:\n            l_FovVideoVDP = self.fvvdp.predict(predictions, targets, dim_order = dim_order)[0]\n            return l_FovVideoVDP\n        except Exception as e:\n            logging.warning('FovVideoVDP failed to compute.')\n            logging.warning(e)\n            return torch.tensor(0.0)\n
"},{"location":"odak/learn_perception/#odak.learn.perception.learned_perceptual_losses.FVVDP.__init__","title":"__init__(device=torch.device('cpu'))","text":"

Initializes the FVVDP model with a specified device.

Parameters:

  • device \u2013
        The device (CPU/GPU) on which the computations will be performed. Defaults to CPU.\n
Source code in odak/learn/perception/learned_perceptual_losses.py
def __init__(self, device = torch.device('cpu')):\n    \"\"\"\n    Initializes the FVVDP model with a specified device.\n\n    Parameters\n    ----------\n    device   : torch.device\n                The device (CPU/GPU) on which the computations will be performed. Defaults to CPU.\n    \"\"\"\n    super(FVVDP, self).__init__()\n    try:\n        import pyfvvdp\n        self.fvvdp = pyfvvdp.fvvdp(display_name = 'standard_4k', heatmap = 'none', device = device)\n    except Exception as e:\n        logging.warning('FovVideoVDP is missing, consider installing by running \"pip install pyfvvdp\"')\n        logging.warning(e)\n
"},{"location":"odak/learn_perception/#odak.learn.perception.learned_perceptual_losses.FVVDP.forward","title":"forward(predictions, targets, dim_order='CHW')","text":"

Parameters:

  • predictions \u2013
            The predicted images.\n
  • targets \u2013
            The ground truth images.\n
  • dim_order \u2013
            The dimension order of the input images. Defaults to 'CHW' (channels, height, width).\n

Returns:

  • result ( tensor ) \u2013

    The computed loss if successful, otherwise 0.0.

Source code in odak/learn/perception/learned_perceptual_losses.py
def forward(self, predictions, targets, dim_order = 'CHW'):\n    \"\"\"\n    Parameters\n    ----------\n    predictions   : torch.tensor\n                    The predicted images.\n    targets       : torch.tensor\n                    The ground truth images.\n    dim_order     : str\n                    The dimension order of the input images. Defaults to 'CHW' (channels, height, width).\n\n    Returns\n    -------\n    result        : torch.tensor\n                      The computed loss if successful, otherwise 0.0.\n    \"\"\"\n    try:\n        l_FovVideoVDP = self.fvvdp.predict(predictions, targets, dim_order = dim_order)[0]\n        return l_FovVideoVDP\n    except Exception as e:\n        logging.warning('FovVideoVDP failed to compute.')\n        logging.warning(e)\n        return torch.tensor(0.0)\n
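A minimal usage sketch for FVVDP (import path assumed; requires pyfvvdp, installable with pip install pyfvvdp):

import torch
from odak.learn.perception.learned_perceptual_losses import FVVDP

loss_fn = FVVDP(device = torch.device('cpu'))
prediction = torch.rand(3, 256, 256)   # CHW, matching the default dim_order
target = torch.rand(3, 256, 256)
score = loss_fn(prediction, target, dim_order = 'CHW')
print(float(score))  # 0.0 is returned if pyfvvdp is missing or the metric fails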
"},{"location":"odak/learn_perception/#odak.learn.perception.learned_perceptual_losses.LPIPS","title":"LPIPS","text":"

Bases: Module

Source code in odak/learn/perception/learned_perceptual_losses.py
class LPIPS(nn.Module):\n\n    def __init__(self):\n        \"\"\"\n        Initializes the LPIPS (Learned Perceptual Image Patch Similarity) model.\n\n        \"\"\"\n        super(LPIPS, self).__init__()\n        try:\n            import torchmetrics\n            self.lpips = torchmetrics.image.lpip.LearnedPerceptualImagePatchSimilarity(net_type = 'squeeze')\n        except Exception as e:\n            logging.warning('torchmetrics is missing, consider installing by running \"pip install torchmetrics\"')\n            logging.warning(e)\n\n\n    def forward(self, predictions, targets):\n        \"\"\"\n        Parameters\n        ----------\n        predictions   : torch.tensor\n                        The predicted images.\n        targets       : torch.tensor\n                        The ground truth images.\n\n        Returns\n        -------\n        result        : torch.tensor\n                        The computed loss if successful, otherwise 0.0.\n        \"\"\"\n        try:\n            lpips_image = predictions\n            lpips_target = targets\n            if len(lpips_image.shape) == 3:\n                lpips_image = lpips_image.unsqueeze(0)\n                lpips_target = lpips_target.unsqueeze(0)\n            if lpips_image.shape[1] == 1:\n                lpips_image = lpips_image.repeat(1, 3, 1, 1)\n                lpips_target = lpips_target.repeat(1, 3, 1, 1)\n            lpips_image = (lpips_image * 2 - 1).clamp(-1, 1)\n            lpips_target = (lpips_target * 2 - 1).clamp(-1, 1)\n            l_LPIPS = self.lpips(lpips_image, lpips_target)\n            return l_LPIPS\n        except Exception as e:\n            logging.warning('LPIPS failed to compute.')\n            logging.warning(e)\n            return torch.tensor(0.0)\n
"},{"location":"odak/learn_perception/#odak.learn.perception.learned_perceptual_losses.LPIPS.__init__","title":"__init__()","text":"

Initializes the LPIPS (Learned Perceptual Image Patch Similarity) model.

Source code in odak/learn/perception/learned_perceptual_losses.py
def __init__(self):\n    \"\"\"\n    Initializes the LPIPS (Learned Perceptual Image Patch Similarity) model.\n\n    \"\"\"\n    super(LPIPS, self).__init__()\n    try:\n        import torchmetrics\n        self.lpips = torchmetrics.image.lpip.LearnedPerceptualImagePatchSimilarity(net_type = 'squeeze')\n    except Exception as e:\n        logging.warning('torchmetrics is missing, consider installing by running \"pip install torchmetrics\"')\n        logging.warning(e)\n
"},{"location":"odak/learn_perception/#odak.learn.perception.learned_perceptual_losses.LPIPS.forward","title":"forward(predictions, targets)","text":"

Parameters:

  • predictions \u2013
            The predicted images.\n
  • targets \u2013
            The ground truth images.\n

Returns:

  • result ( tensor ) \u2013

    The computed loss if successful, otherwise 0.0.

Source code in odak/learn/perception/learned_perceptual_losses.py
def forward(self, predictions, targets):\n    \"\"\"\n    Parameters\n    ----------\n    predictions   : torch.tensor\n                    The predicted images.\n    targets       : torch.tensor\n                    The ground truth images.\n\n    Returns\n    -------\n    result        : torch.tensor\n                    The computed loss if successful, otherwise 0.0.\n    \"\"\"\n    try:\n        lpips_image = predictions\n        lpips_target = targets\n        if len(lpips_image.shape) == 3:\n            lpips_image = lpips_image.unsqueeze(0)\n            lpips_target = lpips_target.unsqueeze(0)\n        if lpips_image.shape[1] == 1:\n            lpips_image = lpips_image.repeat(1, 3, 1, 1)\n            lpips_target = lpips_target.repeat(1, 3, 1, 1)\n        lpips_image = (lpips_image * 2 - 1).clamp(-1, 1)\n        lpips_target = (lpips_target * 2 - 1).clamp(-1, 1)\n        l_LPIPS = self.lpips(lpips_image, lpips_target)\n        return l_LPIPS\n    except Exception as e:\n        logging.warning('LPIPS failed to compute.')\n        logging.warning(e)\n        return torch.tensor(0.0)\n
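A minimal usage sketch for LPIPS (import path assumed; requires torchmetrics). Inputs are expected in [0, 1]; forward() rescales them to [-1, 1] and repeats single-channel inputs to three channels before evaluating:

import torch
from odak.learn.perception.learned_perceptual_losses import LPIPS

loss_fn = LPIPS()
prediction = torch.rand(1, 3, 128, 128)   # NCHW, values in [0, 1]
target = torch.rand(1, 3, 128, 128)
distance = loss_fn(prediction, target)
print(float(distance))  # smaller means perceptually closer; 0.0 is returned if torchmetrics is missing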
"},{"location":"odak/learn_perception/#odak.learn.perception.metameric_loss.MetamericLoss","title":"MetamericLoss","text":"

The MetamericLoss class provides a perceptual loss function.

Rather than exactly match the source image to the target, it tries to ensure the source is a metamer to the target image.

Its interface is similar to other pytorch loss functions, but note that the gaze location must be provided in addition to the source and target images.

Source code in odak/learn/perception/metameric_loss.py
class MetamericLoss():\n    \"\"\"\n    The `MetamericLoss` class provides a perceptual loss function.\n\n    Rather than exactly match the source image to the target, it tries to ensure the source is a *metamer* to the target image.\n\n    Its interface is similar to other `pytorch` loss functions, but note that the gaze location must be provided in addition to the source and target images.\n    \"\"\"\n\n\n    def __init__(self, device=torch.device('cpu'), alpha=0.2, real_image_width=0.2,\n                 real_viewing_distance=0.7, n_pyramid_levels=5, mode=\"quadratic\",\n                 n_orientations=2, use_l2_foveal_loss=True, fovea_weight=20.0, use_radial_weight=False,\n                 use_fullres_l0=False, equi=False):\n        \"\"\"\n        Parameters\n        ----------\n\n        alpha                   : float\n                                    parameter controlling foveation - larger values mean bigger pooling regions.\n        real_image_width        : float \n                                    The real width of the image as displayed to the user.\n                                    Units don't matter as long as they are the same as for real_viewing_distance.\n        real_viewing_distance   : float \n                                    The real distance of the observer's eyes to the image plane.\n                                    Units don't matter as long as they are the same as for real_image_width.\n        n_pyramid_levels        : int \n                                    Number of levels of the steerable pyramid. Note that the image is padded\n                                    so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value\n                                    too high will slow down the calculation a lot.\n        mode                    : str \n                                    Foveation mode, either \"quadratic\" or \"linear\". Controls how pooling regions grow\n                                    as you move away from the fovea. We got best results with \"quadratic\".\n        n_orientations          : int \n                                    Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.\n                                    Increasing this will increase runtime.\n        use_l2_foveal_loss      : bool \n                                    If true, for all the pixels that have pooling size 1 pixel in the \n                                    largest scale will use direct L2 against target rather than pooling over pyramid levels.\n                                    In practice this gives better results when the loss is used for holography.\n        fovea_weight            : float \n                                    A weight to apply to the foveal region if use_l2_foveal_loss is set to True.\n        use_radial_weight       : bool \n                                    If True, will apply a radial weighting when calculating the difference between\n                                    the source and target stats maps. This weights stats closer to the fovea more than those\n                                    further away.\n        use_fullres_l0          : bool \n                                    If true, stats for the lowpass residual are replaced with blurred versions\n                                    of the full-resolution source and target images.\n        equi                    : bool\n                                    If true, run the loss in equirectangular mode. 
The input is assumed to be an equirectangular\n                                    format 360 image. The settings real_image_width and real_viewing distance are ignored.\n                                    The gaze argument is instead interpreted as gaze angles, and should be in the range\n                                    [-pi,pi]x[-pi/2,pi]\n        \"\"\"\n        self.target = None\n        self.device = device\n        self.pyramid_maker = None\n        self.alpha = alpha\n        self.real_image_width = real_image_width\n        self.real_viewing_distance = real_viewing_distance\n        self.blurs = None\n        self.n_pyramid_levels = n_pyramid_levels\n        self.n_orientations = n_orientations\n        self.mode = mode\n        self.use_l2_foveal_loss = use_l2_foveal_loss\n        self.fovea_weight = fovea_weight\n        self.use_radial_weight = use_radial_weight\n        self.use_fullres_l0 = use_fullres_l0\n        self.equi = equi\n        if self.use_fullres_l0 and self.use_l2_foveal_loss:\n            raise Exception(\n                \"Can't use use_fullres_l0 and use_l2_foveal_loss options together in MetamericLoss!\")\n\n    def calc_statsmaps(self, image, gaze=None, alpha=0.01, real_image_width=0.3,\n                       real_viewing_distance=0.6, mode=\"quadratic\", equi=False):\n\n        if self.pyramid_maker is None or \\\n                self.pyramid_maker.device != self.device or \\\n                len(self.pyramid_maker.band_filters) != self.n_orientations or\\\n                self.pyramid_maker.filt_h0.size(0) != image.size(1):\n            self.pyramid_maker = SpatialSteerablePyramid(\n                use_bilinear_downup=False, n_channels=image.size(1),\n                device=self.device, n_orientations=self.n_orientations, filter_type=\"cropped\", filter_size=5)\n\n        if self.blurs is None or len(self.blurs) != self.n_pyramid_levels:\n            self.blurs = [RadiallyVaryingBlur()\n                          for i in range(self.n_pyramid_levels)]\n\n        def find_stats(image_pyr_level, blur):\n            image_means = blur.blur(\n                image_pyr_level, alpha, real_image_width, real_viewing_distance, centre=gaze, mode=mode, equi=self.equi)\n            image_meansq = blur.blur(image_pyr_level*image_pyr_level, alpha,\n                                     real_image_width, real_viewing_distance, centre=gaze, mode=mode, equi=self.equi)\n\n            image_vars = image_meansq - (image_means*image_means)\n            image_vars[image_vars < 1e-7] = 1e-7\n            image_std = torch.sqrt(image_vars)\n            if torch.any(torch.isnan(image_means)):\n                print(image_means)\n                raise Exception(\"NaN in image means!\")\n            if torch.any(torch.isnan(image_std)):\n                print(image_std)\n                raise Exception(\"NaN in image stdevs!\")\n            if self.use_fullres_l0:\n                mask = blur.lod_map > 1e-6\n                mask = mask[None, None, ...]\n                if image_means.size(1) > 1:\n                    mask = mask.repeat(1, image_means.size(1), 1, 1)\n                matte = torch.zeros_like(image_means)\n                matte[mask] = 1.0\n                return image_means * matte, image_std * matte\n            return image_means, image_std\n        output_stats = []\n        image_pyramid = self.pyramid_maker.construct_pyramid(\n            image, self.n_pyramid_levels)\n        means, variances = find_stats(image_pyramid[0]['h'], self.blurs[0])\n        if 
self.use_l2_foveal_loss:\n            self.fovea_mask = torch.zeros(image.size(), device=image.device)\n            for i in range(self.fovea_mask.size(1)):\n                self.fovea_mask[0, i, ...] = 1.0 - \\\n                    (self.blurs[0].lod_map / torch.max(self.blurs[0].lod_map))\n                self.fovea_mask[0, i, self.blurs[0].lod_map < 1e-6] = 1.0\n            self.fovea_mask = torch.pow(self.fovea_mask, 10.0)\n            #self.fovea_mask     = torch.nn.functional.interpolate(self.fovea_mask, scale_factor=0.125, mode=\"area\")\n            #self.fovea_mask     = torch.nn.functional.interpolate(self.fovea_mask, size=(image.size(-2), image.size(-1)), mode=\"bilinear\")\n            periphery_mask = 1.0 - self.fovea_mask\n            self.periphery_mask = periphery_mask.clone()\n            output_stats.append(means * periphery_mask)\n            output_stats.append(variances * periphery_mask)\n        else:\n            output_stats.append(means)\n            output_stats.append(variances)\n\n        for l in range(0, len(image_pyramid)-1):\n            for o in range(len(image_pyramid[l]['b'])):\n                means, variances = find_stats(\n                    image_pyramid[l]['b'][o], self.blurs[l])\n                if self.use_l2_foveal_loss:\n                    output_stats.append(means * periphery_mask)\n                    output_stats.append(variances * periphery_mask)\n                else:\n                    output_stats.append(means)\n                    output_stats.append(variances)\n            if self.use_l2_foveal_loss:\n                periphery_mask = torch.nn.functional.interpolate(\n                    periphery_mask, scale_factor=0.5, mode=\"area\", recompute_scale_factor=False)\n\n        if self.use_l2_foveal_loss:\n            output_stats.append(image_pyramid[-1][\"l\"] * periphery_mask)\n        elif self.use_fullres_l0:\n            output_stats.append(self.blurs[0].blur(\n                image, alpha, real_image_width, real_viewing_distance, gaze, mode))\n        else:\n            output_stats.append(image_pyramid[-1][\"l\"])\n        return output_stats\n\n    def metameric_loss_stats(self, statsmap_a, statsmap_b, gaze):\n        loss = 0.0\n        for a, b in zip(statsmap_a, statsmap_b):\n            if self.use_radial_weight:\n                radii = make_radial_map(\n                    [a.size(-2), a.size(-1)], gaze).to(a.device)\n                weights = 1.1 - (radii * radii * radii * radii)\n                weights = weights[None, None, ...].repeat(1, a.size(1), 1, 1)\n                loss += torch.nn.MSELoss()(weights*a, weights*b)\n            else:\n                loss += torch.nn.MSELoss()(a, b)\n        loss /= len(statsmap_a)\n        return loss\n\n    def visualise_loss_map(self, image_stats):\n        loss_map = torch.zeros(image_stats[0].size()[-2:])\n        for i in range(len(image_stats)):\n            stats = image_stats[i]\n            target_stats = self.target_stats[i]\n            stat_mse_map = torch.sqrt(torch.pow(stats - target_stats, 2))\n            stat_mse_map = torch.nn.functional.interpolate(stat_mse_map, size=loss_map.size(\n            ), mode=\"bilinear\", align_corners=False, recompute_scale_factor=False)\n            loss_map += stat_mse_map[0, 0, ...]\n        self.loss_map = loss_map\n\n    def __call__(self, image, target, gaze=[0.5, 0.5], image_colorspace=\"RGB\", visualise_loss=False):\n        \"\"\" \n        Calculates the Metameric Loss.\n\n        Parameters\n        ----------\n        
image               : torch.tensor\n                                Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)\n        target              : torch.tensor\n                                Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)\n        image_colorspace    : str\n                                The current colorspace of your image and target. Ignored if input does not have 3 channels.\n                                accepted values: RGB, YCrCb.\n        gaze                : list\n                                Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.\n        visualise_loss      : bool\n                                Shows a heatmap indicating which parts of the image contributed most to the loss. \n\n        Returns\n        -------\n\n        loss                : torch.tensor\n                                The computed loss.\n        \"\"\"\n        check_loss_inputs(\"MetamericLoss\", image, target)\n        # Pad image and target if necessary\n        image = pad_image_for_pyramid(image, self.n_pyramid_levels)\n        target = pad_image_for_pyramid(target, self.n_pyramid_levels)\n        # If input is RGB, convert to YCrCb.\n        if image.size(1) == 3 and image_colorspace == \"RGB\":\n            image = rgb_2_ycrcb(image)\n            target = rgb_2_ycrcb(target)\n        if self.target is None:\n            self.target = torch.zeros(target.shape).to(target.device)\n        if type(target) == type(self.target):\n            if not torch.all(torch.eq(target, self.target)):\n                self.target = target.detach().clone()\n                self.target_stats = self.calc_statsmaps(\n                    self.target,\n                    gaze=gaze,\n                    alpha=self.alpha,\n                    real_image_width=self.real_image_width,\n                    real_viewing_distance=self.real_viewing_distance,\n                    mode=self.mode\n                )\n                self.target = target.detach().clone()\n            image_stats = self.calc_statsmaps(\n                image,\n                gaze=gaze,\n                alpha=self.alpha,\n                real_image_width=self.real_image_width,\n                real_viewing_distance=self.real_viewing_distance,\n                mode=self.mode\n            )\n            if visualise_loss:\n                self.visualise_loss_map(image_stats)\n            if self.use_l2_foveal_loss:\n                peripheral_loss = self.metameric_loss_stats(\n                    image_stats, self.target_stats, gaze)\n                foveal_loss = torch.nn.MSELoss()(self.fovea_mask*image, self.fovea_mask*target)\n                # New weighting - evenly weight fovea and periphery.\n                loss = peripheral_loss + self.fovea_weight * foveal_loss\n            else:\n                loss = self.metameric_loss_stats(\n                    image_stats, self.target_stats, gaze)\n            return loss\n        else:\n            raise Exception(\"Target of incorrect type\")\n\n    def to(self, device):\n        self.device = device\n        return self\n
"},{"location":"odak/learn_perception/#odak.learn.perception.metameric_loss.MetamericLoss.__call__","title":"__call__(image, target, gaze=[0.5, 0.5], image_colorspace='RGB', visualise_loss=False)","text":"

Calculates the Metameric Loss.

Parameters:

  • image \u2013
                    Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)\n
  • target \u2013
                    Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)\n
  • image_colorspace \u2013
                    The current colorspace of your image and target. Ignored if input does not have 3 channels.\n                accepted values: RGB, YCrCb.\n
  • gaze \u2013
                    Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.\n
  • visualise_loss \u2013
                    Shows a heatmap indicating which parts of the image contributed most to the loss.\n

Returns:

  • loss ( tensor ) \u2013

    The computed loss.

Source code in odak/learn/perception/metameric_loss.py
def __call__(self, image, target, gaze=[0.5, 0.5], image_colorspace=\"RGB\", visualise_loss=False):\n    \"\"\" \n    Calculates the Metameric Loss.\n\n    Parameters\n    ----------\n    image               : torch.tensor\n                            Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)\n    target              : torch.tensor\n                            Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)\n    image_colorspace    : str\n                            The current colorspace of your image and target. Ignored if input does not have 3 channels.\n                            accepted values: RGB, YCrCb.\n    gaze                : list\n                            Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.\n    visualise_loss      : bool\n                            Shows a heatmap indicating which parts of the image contributed most to the loss. \n\n    Returns\n    -------\n\n    loss                : torch.tensor\n                            The computed loss.\n    \"\"\"\n    check_loss_inputs(\"MetamericLoss\", image, target)\n    # Pad image and target if necessary\n    image = pad_image_for_pyramid(image, self.n_pyramid_levels)\n    target = pad_image_for_pyramid(target, self.n_pyramid_levels)\n    # If input is RGB, convert to YCrCb.\n    if image.size(1) == 3 and image_colorspace == \"RGB\":\n        image = rgb_2_ycrcb(image)\n        target = rgb_2_ycrcb(target)\n    if self.target is None:\n        self.target = torch.zeros(target.shape).to(target.device)\n    if type(target) == type(self.target):\n        if not torch.all(torch.eq(target, self.target)):\n            self.target = target.detach().clone()\n            self.target_stats = self.calc_statsmaps(\n                self.target,\n                gaze=gaze,\n                alpha=self.alpha,\n                real_image_width=self.real_image_width,\n                real_viewing_distance=self.real_viewing_distance,\n                mode=self.mode\n            )\n            self.target = target.detach().clone()\n        image_stats = self.calc_statsmaps(\n            image,\n            gaze=gaze,\n            alpha=self.alpha,\n            real_image_width=self.real_image_width,\n            real_viewing_distance=self.real_viewing_distance,\n            mode=self.mode\n        )\n        if visualise_loss:\n            self.visualise_loss_map(image_stats)\n        if self.use_l2_foveal_loss:\n            peripheral_loss = self.metameric_loss_stats(\n                image_stats, self.target_stats, gaze)\n            foveal_loss = torch.nn.MSELoss()(self.fovea_mask*image, self.fovea_mask*target)\n            # New weighting - evenly weight fovea and periphery.\n            loss = peripheral_loss + self.fovea_weight * foveal_loss\n        else:\n            loss = self.metameric_loss_stats(\n                image_stats, self.target_stats, gaze)\n        return loss\n    else:\n        raise Exception(\"Target of incorrect type\")\n
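A minimal usage sketch for MetamericLoss (assuming the import path odak.learn.perception.metameric_loss; image sizes and gaze are illustrative). Note that, unlike a standard pytorch loss, the gaze location is passed alongside the image and target:

import torch
from odak.learn.perception.metameric_loss import MetamericLoss

loss_fn = MetamericLoss(alpha = 0.2, real_image_width = 0.2, real_viewing_distance = 0.7,
                        n_pyramid_levels = 5, n_orientations = 2)
image = torch.rand(1, 3, 256, 256)    # NCHW RGB in [0, 1]
target = torch.rand(1, 3, 256, 256)
loss = loss_fn(image, target, gaze = [0.7, 0.3])   # gaze in normalised image coordinates
print(float(loss))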
"},{"location":"odak/learn_perception/#odak.learn.perception.metameric_loss.MetamericLoss.__init__","title":"__init__(device=torch.device('cpu'), alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, n_pyramid_levels=5, mode='quadratic', n_orientations=2, use_l2_foveal_loss=True, fovea_weight=20.0, use_radial_weight=False, use_fullres_l0=False, equi=False)","text":"

Parameters:

  • alpha \u2013
                        parameter controlling foveation - larger values mean bigger pooling regions.\n
  • real_image_width \u2013
                        The real width of the image as displayed to the user.\n                    Units don't matter as long as they are the same as for real_viewing_distance.\n
  • real_viewing_distance \u2013
                        The real distance of the observer's eyes to the image plane.\n                    Units don't matter as long as they are the same as for real_image_width.\n
  • n_pyramid_levels \u2013
                        Number of levels of the steerable pyramid. Note that the image is padded\n                    so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value\n                    too high will slow down the calculation a lot.\n
  • mode \u2013
                        Foveation mode, either \"quadratic\" or \"linear\". Controls how pooling regions grow\n                    as you move away from the fovea. We got best results with \"quadratic\".\n
  • n_orientations \u2013
                        Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.\n                    Increasing this will increase runtime.\n
  • use_l2_foveal_loss \u2013
                        If true, pixels whose pooling size is 1 pixel at the\n                    largest scale use a direct L2 loss against the target rather than pooling over pyramid levels.\n                    In practice this gives better results when the loss is used for holography.\n
  • fovea_weight \u2013
                        A weight to apply to the foveal region if use_l2_foveal_loss is set to True.\n
  • use_radial_weight \u2013
                        If True, will apply a radial weighting when calculating the difference between\n                    the source and target stats maps. This weights stats closer to the fovea more than those\n                    further away.\n
  • use_fullres_l0 \u2013
                        If true, stats for the lowpass residual are replaced with blurred versions\n                    of the full-resolution source and target images.\n
  • equi \u2013
                        If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular\n                    format 360 image. The settings real_image_width and real_viewing_distance are ignored.\n                    The gaze argument is instead interpreted as gaze angles, and should be in the range\n                    [-pi,pi]x[-pi/2,pi/2]\n
Source code in odak/learn/perception/metameric_loss.py
def __init__(self, device=torch.device('cpu'), alpha=0.2, real_image_width=0.2,\n             real_viewing_distance=0.7, n_pyramid_levels=5, mode=\"quadratic\",\n             n_orientations=2, use_l2_foveal_loss=True, fovea_weight=20.0, use_radial_weight=False,\n             use_fullres_l0=False, equi=False):\n    \"\"\"\n    Parameters\n    ----------\n\n    alpha                   : float\n                                parameter controlling foveation - larger values mean bigger pooling regions.\n    real_image_width        : float \n                                The real width of the image as displayed to the user.\n                                Units don't matter as long as they are the same as for real_viewing_distance.\n    real_viewing_distance   : float \n                                The real distance of the observer's eyes to the image plane.\n                                Units don't matter as long as they are the same as for real_image_width.\n    n_pyramid_levels        : int \n                                Number of levels of the steerable pyramid. Note that the image is padded\n                                so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value\n                                too high will slow down the calculation a lot.\n    mode                    : str \n                                Foveation mode, either \"quadratic\" or \"linear\". Controls how pooling regions grow\n                                as you move away from the fovea. We got best results with \"quadratic\".\n    n_orientations          : int \n                                Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.\n                                Increasing this will increase runtime.\n    use_l2_foveal_loss      : bool \n                                If true, for all the pixels that have pooling size 1 pixel in the \n                                largest scale will use direct L2 against target rather than pooling over pyramid levels.\n                                In practice this gives better results when the loss is used for holography.\n    fovea_weight            : float \n                                A weight to apply to the foveal region if use_l2_foveal_loss is set to True.\n    use_radial_weight       : bool \n                                If True, will apply a radial weighting when calculating the difference between\n                                the source and target stats maps. This weights stats closer to the fovea more than those\n                                further away.\n    use_fullres_l0          : bool \n                                If true, stats for the lowpass residual are replaced with blurred versions\n                                of the full-resolution source and target images.\n    equi                    : bool\n                                If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular\n                                format 360 image. 
The settings real_image_width and real_viewing distance are ignored.\n                                The gaze argument is instead interpreted as gaze angles, and should be in the range\n                                [-pi,pi]x[-pi/2,pi]\n    \"\"\"\n    self.target = None\n    self.device = device\n    self.pyramid_maker = None\n    self.alpha = alpha\n    self.real_image_width = real_image_width\n    self.real_viewing_distance = real_viewing_distance\n    self.blurs = None\n    self.n_pyramid_levels = n_pyramid_levels\n    self.n_orientations = n_orientations\n    self.mode = mode\n    self.use_l2_foveal_loss = use_l2_foveal_loss\n    self.fovea_weight = fovea_weight\n    self.use_radial_weight = use_radial_weight\n    self.use_fullres_l0 = use_fullres_l0\n    self.equi = equi\n    if self.use_fullres_l0 and self.use_l2_foveal_loss:\n        raise Exception(\n            \"Can't use use_fullres_l0 and use_l2_foveal_loss options together in MetamericLoss!\")\n
"},{"location":"odak/learn_perception/#odak.learn.perception.metameric_loss_uniform.MetamericLossUniform","title":"MetamericLossUniform","text":"

Measures metameric loss between a given image and a metamer of the given target image. This variant of the metameric loss is not foveated - it applies uniform pooling sizes to the whole input image.

Source code in odak/learn/perception/metameric_loss_uniform.py
class MetamericLossUniform():\n    \"\"\"\n    Measures metameric loss between a given image and a metamer of the given target image.\n    This variant of the metameric loss is not foveated - it applies uniform pooling sizes to the whole input image.\n    \"\"\"\n\n    def __init__(self, device=torch.device('cpu'), pooling_size=32, n_pyramid_levels=5, n_orientations=2):\n        \"\"\"\n\n        Parameters\n        ----------\n        pooling_size            : int\n                                  Pooling size, in pixels. For example 32 will pool over 32x32 blocks of the image.\n        n_pyramid_levels        : int \n                                  Number of levels of the steerable pyramid. Note that the image is padded\n                                  so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value\n                                  too high will slow down the calculation a lot.\n        n_orientations          : int \n                                  Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.\n                                  Increasing this will increase runtime.\n\n        \"\"\"\n        self.target = None\n        self.device = device\n        self.pyramid_maker = None\n        self.pooling_size = pooling_size\n        self.n_pyramid_levels = n_pyramid_levels\n        self.n_orientations = n_orientations\n\n    def calc_statsmaps(self, image, pooling_size):\n\n        if self.pyramid_maker is None or \\\n                self.pyramid_maker.device != self.device or \\\n                len(self.pyramid_maker.band_filters) != self.n_orientations or\\\n                self.pyramid_maker.filt_h0.size(0) != image.size(1):\n            self.pyramid_maker = SpatialSteerablePyramid(\n                use_bilinear_downup=False, n_channels=image.size(1),\n                device=self.device, n_orientations=self.n_orientations, filter_type=\"cropped\", filter_size=5)\n\n\n        def find_stats(image_pyr_level, pooling_size):\n            image_means = uniform_blur(image_pyr_level, pooling_size)\n            image_meansq = uniform_blur(image_pyr_level*image_pyr_level, pooling_size)\n            image_vars = image_meansq - (image_means*image_means)\n            image_vars[image_vars < 1e-7] = 1e-7\n            image_std = torch.sqrt(image_vars)\n            if torch.any(torch.isnan(image_means)):\n                print(image_means)\n                raise Exception(\"NaN in image means!\")\n            if torch.any(torch.isnan(image_std)):\n                print(image_std)\n                raise Exception(\"NaN in image stdevs!\")\n            return image_means, image_std\n\n        output_stats = []\n        image_pyramid = self.pyramid_maker.construct_pyramid(\n            image, self.n_pyramid_levels)\n        curr_pooling_size = pooling_size\n        means, variances = find_stats(image_pyramid[0]['h'], curr_pooling_size)\n        output_stats.append(means)\n        output_stats.append(variances)\n\n        for l in range(0, len(image_pyramid)-1):\n            for o in range(len(image_pyramid[l]['b'])):\n                means, variances = find_stats(\n                    image_pyramid[l]['b'][o], curr_pooling_size)\n                output_stats.append(means)\n                output_stats.append(variances)\n            curr_pooling_size /= 2\n\n        output_stats.append(image_pyramid[-1][\"l\"])\n        return output_stats\n\n    def metameric_loss_stats(self, statsmap_a, statsmap_b):\n        loss = 0.0\n        for 
a, b in zip(statsmap_a, statsmap_b):\n            loss += torch.nn.MSELoss()(a, b)\n        loss /= len(statsmap_a)\n        return loss\n\n    def visualise_loss_map(self, image_stats):\n        loss_map = torch.zeros(image_stats[0].size()[-2:])\n        for i in range(len(image_stats)):\n            stats = image_stats[i]\n            target_stats = self.target_stats[i]\n            stat_mse_map = torch.sqrt(torch.pow(stats - target_stats, 2))\n            stat_mse_map = torch.nn.functional.interpolate(stat_mse_map, size=loss_map.size(\n            ), mode=\"bilinear\", align_corners=False, recompute_scale_factor=False)\n            loss_map += stat_mse_map[0, 0, ...]\n        self.loss_map = loss_map\n\n    def __call__(self, image, target, image_colorspace=\"RGB\", visualise_loss=False):\n        \"\"\" \n        Calculates the Metameric Loss.\n\n        Parameters\n        ----------\n        image               : torch.tensor\n                                Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)\n        target              : torch.tensor\n                                Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)\n        image_colorspace    : str\n                                The current colorspace of your image and target. Ignored if input does not have 3 channels.\n                                accepted values: RGB, YCrCb.\n        visualise_loss      : bool\n                                Shows a heatmap indicating which parts of the image contributed most to the loss. \n\n        Returns\n        -------\n\n        loss                : torch.tensor\n                                The computed loss.\n        \"\"\"\n        check_loss_inputs(\"MetamericLossUniform\", image, target)\n        # Pad image and target if necessary\n        image = pad_image_for_pyramid(image, self.n_pyramid_levels)\n        target = pad_image_for_pyramid(target, self.n_pyramid_levels)\n        # If input is RGB, convert to YCrCb.\n        if image.size(1) == 3 and image_colorspace == \"RGB\":\n            image = rgb_2_ycrcb(image)\n            target = rgb_2_ycrcb(target)\n        if self.target is None:\n            self.target = torch.zeros(target.shape).to(target.device)\n        if type(target) == type(self.target):\n            if not torch.all(torch.eq(target, self.target)):\n                self.target = target.detach().clone()\n                self.target_stats = self.calc_statsmaps(self.target, self.pooling_size)\n                self.target = target.detach().clone()\n            image_stats = self.calc_statsmaps(image, self.pooling_size)\n\n            if visualise_loss:\n                self.visualise_loss_map(image_stats)\n            loss = self.metameric_loss_stats(\n                image_stats, self.target_stats)\n            return loss\n        else:\n            raise Exception(\"Target of incorrect type\")\n\n    def gen_metamer(self, image):\n        \"\"\" \n        Generates a metamer for an image, following the method in [this paper](https://dl.acm.org/doi/abs/10.1145/3450626.3459943)\n        This function can be used on its own to generate a metamer for a desired image.\n\n        Parameters\n        ----------\n        image   : torch.tensor\n                  Image to compute metamer for. 
Should be an RGB image in NCHW format (4 dimensions)\n\n        Returns\n        -------\n        metamer : torch.tensor\n                  The generated metamer image\n        \"\"\"\n        image = rgb_2_ycrcb(image)\n        image_size = image.size()\n        image = pad_image_for_pyramid(image, self.n_pyramid_levels)\n\n        target_stats = self.calc_statsmaps(\n            image, self.pooling_size)\n        target_means = target_stats[::2]\n        target_stdevs = target_stats[1::2]\n        torch.manual_seed(0)\n        noise_image = torch.rand_like(image)\n        noise_pyramid = self.pyramid_maker.construct_pyramid(\n            noise_image, self.n_pyramid_levels)\n        input_pyramid = self.pyramid_maker.construct_pyramid(\n            image, self.n_pyramid_levels)\n\n        def match_level(input_level, target_mean, target_std):\n            level = input_level.clone()\n            level -= torch.mean(level)\n            input_std = torch.sqrt(torch.mean(level * level))\n            eps = 1e-6\n            # Safeguard against divide by zero\n            input_std[input_std < eps] = eps\n            level /= input_std\n            level *= target_std\n            level += target_mean\n            return level\n\n        nbands = len(noise_pyramid[0][\"b\"])\n        noise_pyramid[0][\"h\"] = match_level(\n            noise_pyramid[0][\"h\"], target_means[0], target_stdevs[0])\n        for l in range(len(noise_pyramid)-1):\n            for b in range(nbands):\n                noise_pyramid[l][\"b\"][b] = match_level(\n                    noise_pyramid[l][\"b\"][b], target_means[1 + l * nbands + b], target_stdevs[1 + l * nbands + b])\n        noise_pyramid[-1][\"l\"] = input_pyramid[-1][\"l\"]\n\n        metamer = self.pyramid_maker.reconstruct_from_pyramid(\n            noise_pyramid)\n        metamer = ycrcb_2_rgb(metamer)\n        # Crop to remove any padding\n        metamer = metamer[:image_size[0], :image_size[1], :image_size[2], :image_size[3]]\n        return metamer\n\n    def to(self, device):\n        self.device = device\n        return self\n
"},{"location":"odak/learn_perception/#odak.learn.perception.metameric_loss_uniform.MetamericLossUniform.__call__","title":"__call__(image, target, image_colorspace='RGB', visualise_loss=False)","text":"

Calculates the Metameric Loss.

Parameters:

  • image – Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions).
  • target – Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions).
  • image_colorspace – The current colorspace of your image and target. Ignored if input does not have 3 channels. Accepted values: RGB, YCrCb.
  • visualise_loss – Shows a heatmap indicating which parts of the image contributed most to the loss.

Returns:

  • loss (tensor) – The computed loss.

Source code in odak/learn/perception/metameric_loss_uniform.py
def __call__(self, image, target, image_colorspace=\"RGB\", visualise_loss=False):\n    \"\"\" \n    Calculates the Metameric Loss.\n\n    Parameters\n    ----------\n    image               : torch.tensor\n                            Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)\n    target              : torch.tensor\n                            Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)\n    image_colorspace    : str\n                            The current colorspace of your image and target. Ignored if input does not have 3 channels.\n                            accepted values: RGB, YCrCb.\n    visualise_loss      : bool\n                            Shows a heatmap indicating which parts of the image contributed most to the loss. \n\n    Returns\n    -------\n\n    loss                : torch.tensor\n                            The computed loss.\n    \"\"\"\n    check_loss_inputs(\"MetamericLossUniform\", image, target)\n    # Pad image and target if necessary\n    image = pad_image_for_pyramid(image, self.n_pyramid_levels)\n    target = pad_image_for_pyramid(target, self.n_pyramid_levels)\n    # If input is RGB, convert to YCrCb.\n    if image.size(1) == 3 and image_colorspace == \"RGB\":\n        image = rgb_2_ycrcb(image)\n        target = rgb_2_ycrcb(target)\n    if self.target is None:\n        self.target = torch.zeros(target.shape).to(target.device)\n    if type(target) == type(self.target):\n        if not torch.all(torch.eq(target, self.target)):\n            self.target = target.detach().clone()\n            self.target_stats = self.calc_statsmaps(self.target, self.pooling_size)\n            self.target = target.detach().clone()\n        image_stats = self.calc_statsmaps(image, self.pooling_size)\n\n        if visualise_loss:\n            self.visualise_loss_map(image_stats)\n        loss = self.metameric_loss_stats(\n            image_stats, self.target_stats)\n        return loss\n    else:\n        raise Exception(\"Target of incorrect type\")\n
"},{"location":"odak/learn_perception/#odak.learn.perception.metameric_loss_uniform.MetamericLossUniform.__init__","title":"__init__(device=torch.device('cpu'), pooling_size=32, n_pyramid_levels=5, n_orientations=2)","text":"

Parameters:

  • pooling_size – Pooling size, in pixels. For example 32 will pool over 32x32 blocks of the image.
  • n_pyramid_levels – Number of levels of the steerable pyramid. Note that the image is padded so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value too high will slow down the calculation a lot.
  • n_orientations – Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6. Increasing this will increase runtime.
Source code in odak/learn/perception/metameric_loss_uniform.py
def __init__(self, device=torch.device('cpu'), pooling_size=32, n_pyramid_levels=5, n_orientations=2):\n    \"\"\"\n\n    Parameters\n    ----------\n    pooling_size            : int\n                              Pooling size, in pixels. For example 32 will pool over 32x32 blocks of the image.\n    n_pyramid_levels        : int \n                              Number of levels of the steerable pyramid. Note that the image is padded\n                              so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value\n                              too high will slow down the calculation a lot.\n    n_orientations          : int \n                              Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.\n                              Increasing this will increase runtime.\n\n    \"\"\"\n    self.target = None\n    self.device = device\n    self.pyramid_maker = None\n    self.pooling_size = pooling_size\n    self.n_pyramid_levels = n_pyramid_levels\n    self.n_orientations = n_orientations\n
"},{"location":"odak/learn_perception/#odak.learn.perception.metameric_loss_uniform.MetamericLossUniform.gen_metamer","title":"gen_metamer(image)","text":"

Generates a metamer for an image, following the method in this paper (https://dl.acm.org/doi/abs/10.1145/3450626.3459943). This function can be used on its own to generate a metamer for a desired image.

Parameters:

  • image – Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions).

Returns:

  • metamer (tensor) – The generated metamer image.

Source code in odak/learn/perception/metameric_loss_uniform.py
def gen_metamer(self, image):\n    \"\"\" \n    Generates a metamer for an image, following the method in [this paper](https://dl.acm.org/doi/abs/10.1145/3450626.3459943)\n    This function can be used on its own to generate a metamer for a desired image.\n\n    Parameters\n    ----------\n    image   : torch.tensor\n              Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions)\n\n    Returns\n    -------\n    metamer : torch.tensor\n              The generated metamer image\n    \"\"\"\n    image = rgb_2_ycrcb(image)\n    image_size = image.size()\n    image = pad_image_for_pyramid(image, self.n_pyramid_levels)\n\n    target_stats = self.calc_statsmaps(\n        image, self.pooling_size)\n    target_means = target_stats[::2]\n    target_stdevs = target_stats[1::2]\n    torch.manual_seed(0)\n    noise_image = torch.rand_like(image)\n    noise_pyramid = self.pyramid_maker.construct_pyramid(\n        noise_image, self.n_pyramid_levels)\n    input_pyramid = self.pyramid_maker.construct_pyramid(\n        image, self.n_pyramid_levels)\n\n    def match_level(input_level, target_mean, target_std):\n        level = input_level.clone()\n        level -= torch.mean(level)\n        input_std = torch.sqrt(torch.mean(level * level))\n        eps = 1e-6\n        # Safeguard against divide by zero\n        input_std[input_std < eps] = eps\n        level /= input_std\n        level *= target_std\n        level += target_mean\n        return level\n\n    nbands = len(noise_pyramid[0][\"b\"])\n    noise_pyramid[0][\"h\"] = match_level(\n        noise_pyramid[0][\"h\"], target_means[0], target_stdevs[0])\n    for l in range(len(noise_pyramid)-1):\n        for b in range(nbands):\n            noise_pyramid[l][\"b\"][b] = match_level(\n                noise_pyramid[l][\"b\"][b], target_means[1 + l * nbands + b], target_stdevs[1 + l * nbands + b])\n    noise_pyramid[-1][\"l\"] = input_pyramid[-1][\"l\"]\n\n    metamer = self.pyramid_maker.reconstruct_from_pyramid(\n        noise_pyramid)\n    metamer = ycrcb_2_rgb(metamer)\n    # Crop to remove any padding\n    metamer = metamer[:image_size[0], :image_size[1], :image_size[2], :image_size[3]]\n    return metamer\n
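As a sketch of using gen_metamer on its own (same assumed import path as above): the call below produces an image with matching pooled statistics, and the internal pyramid is built lazily on the first call, so no extra setup is needed.

```python
import torch
from odak.learn.perception import MetamericLossUniform  # assumed import path

loss_fn = MetamericLossUniform(pooling_size=32, n_pyramid_levels=5, n_orientations=2)
image = torch.rand(1, 3, 256, 256)   # stand-in for a real RGB image, NCHW

with torch.no_grad():
    metamer = loss_fn.gen_metamer(image)

print(metamer.shape)   # same spatial size as the input once padding is cropped away
```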
"},{"location":"odak/learn_perception/#odak.learn.perception.metamer_mse_loss.MetamerMSELoss","title":"MetamerMSELoss","text":"

The MetamerMSELoss class provides a perceptual loss function. This generates a metamer for the target image, and then optimises the source image to be the same as this target image metamer.

Please note this is different to MetamericLoss which optimises the source image to be any metamer of the target image.

Its interface is similar to other pytorch loss functions, but note that the gaze location must be provided in addition to the source and target images.

Source code in odak/learn/perception/metamer_mse_loss.py
class MetamerMSELoss():\n    \"\"\" \n    The `MetamerMSELoss` class provides a perceptual loss function. This generates a metamer for the target image, and then optimises the source image to be the same as this target image metamer.\n\n    Please note this is different to `MetamericLoss` which optimises the source image to be any metamer of the target image.\n\n    Its interface is similar to other `pytorch` loss functions, but note that the gaze location must be provided in addition to the source and target images.\n    \"\"\"\n\n\n    def __init__(self, device=torch.device(\"cpu\"),\n                 alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode=\"quadratic\",\n                 n_pyramid_levels=5, n_orientations=2, equi=False):\n        \"\"\"\n        Parameters\n        ----------\n        alpha                   : float\n                                    parameter controlling foveation - larger values mean bigger pooling regions.\n        real_image_width        : float \n                                    The real width of the image as displayed to the user.\n                                    Units don't matter as long as they are the same as for real_viewing_distance.\n        real_viewing_distance   : float \n                                    The real distance of the observer's eyes to the image plane.\n                                    Units don't matter as long as they are the same as for real_image_width.\n        n_pyramid_levels        : int \n                                    Number of levels of the steerable pyramid. Note that the image is padded\n                                    so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value\n                                    too high will slow down the calculation a lot.\n        mode                    : str \n                                    Foveation mode, either \"quadratic\" or \"linear\". Controls how pooling regions grow\n                                    as you move away from the fovea. We got best results with \"quadratic\".\n        n_orientations          : int \n                                    Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.\n                                    Increasing this will increase runtime.\n        equi                    : bool\n                                    If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular\n                                    format 360 image. 
The settings real_image_width and real_viewing distance are ignored.\n                                    The gaze argument is instead interpreted as gaze angles, and should be in the range\n                                    [-pi,pi]x[-pi/2,pi]\n        \"\"\"\n        self.target = None\n        self.target_metamer = None\n        self.metameric_loss = MetamericLoss(device=device, alpha=alpha, real_image_width=real_image_width,\n                                            real_viewing_distance=real_viewing_distance,\n                                            n_pyramid_levels=n_pyramid_levels, n_orientations=n_orientations, use_l2_foveal_loss=False, equi=equi)\n        self.loss_func = torch.nn.MSELoss()\n        self.noise = None\n\n    def gen_metamer(self, image, gaze):\n        \"\"\" \n        Generates a metamer for an image, following the method in [this paper](https://dl.acm.org/doi/abs/10.1145/3450626.3459943)\n        This function can be used on its own to generate a metamer for a desired image.\n\n        Parameters\n        ----------\n        image   : torch.tensor\n                Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions)\n        gaze    : list\n                Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.\n\n        Returns\n        -------\n\n        metamer : torch.tensor\n                The generated metamer image\n        \"\"\"\n        image = rgb_2_ycrcb(image)\n        image_size = image.size()\n        image = pad_image_for_pyramid(image, self.metameric_loss.n_pyramid_levels)\n\n        target_stats = self.metameric_loss.calc_statsmaps(\n            image, gaze=gaze, alpha=self.metameric_loss.alpha)\n        target_means = target_stats[::2]\n        target_stdevs = target_stats[1::2]\n        if self.noise is None or self.noise.size() != image.size():\n            torch.manual_seed(0)\n            noise_image = torch.rand_like(image)\n        noise_pyramid = self.metameric_loss.pyramid_maker.construct_pyramid(\n            noise_image, self.metameric_loss.n_pyramid_levels)\n        input_pyramid = self.metameric_loss.pyramid_maker.construct_pyramid(\n            image, self.metameric_loss.n_pyramid_levels)\n\n        def match_level(input_level, target_mean, target_std):\n            level = input_level.clone()\n            level -= torch.mean(level)\n            input_std = torch.sqrt(torch.mean(level * level))\n            eps = 1e-6\n            # Safeguard against divide by zero\n            input_std[input_std < eps] = eps\n            level /= input_std\n            level *= target_std\n            level += target_mean\n            return level\n\n        nbands = len(noise_pyramid[0][\"b\"])\n        noise_pyramid[0][\"h\"] = match_level(\n            noise_pyramid[0][\"h\"], target_means[0], target_stdevs[0])\n        for l in range(len(noise_pyramid)-1):\n            for b in range(nbands):\n                noise_pyramid[l][\"b\"][b] = match_level(\n                    noise_pyramid[l][\"b\"][b], target_means[1 + l * nbands + b], target_stdevs[1 + l * nbands + b])\n        noise_pyramid[-1][\"l\"] = input_pyramid[-1][\"l\"]\n\n        metamer = self.metameric_loss.pyramid_maker.reconstruct_from_pyramid(\n            noise_pyramid)\n        metamer = ycrcb_2_rgb(metamer)\n        # Crop to remove any padding\n        metamer = metamer[:image_size[0], :image_size[1], :image_size[2], :image_size[3]]\n        return metamer\n\n    def 
__call__(self, image, target, gaze=[0.5, 0.5]):\n        \"\"\" \n        Calculates the Metamer MSE Loss.\n\n        Parameters\n        ----------\n        image   : torch.tensor\n                Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)\n        target  : torch.tensor\n                Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)\n        gaze    : list\n                Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.\n\n        Returns\n        -------\n\n        loss                : torch.tensor\n                                The computed loss.\n        \"\"\"\n        check_loss_inputs(\"MetamerMSELoss\", image, target)\n        # Pad image and target if necessary\n        image = pad_image_for_pyramid(image, self.metameric_loss.n_pyramid_levels)\n        target = pad_image_for_pyramid(target, self.metameric_loss.n_pyramid_levels)\n\n        if target is not self.target or self.target is None:\n            self.target_metamer = self.gen_metamer(target, gaze)\n            self.target = target\n\n        return self.loss_func(image, self.target_metamer)\n\n    def to(self, device):\n        self.metameric_loss = self.metameric_loss.to(device)\n        return self\n
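A usage sketch under the same assumptions (import path, random stand-in images): the gaze location is passed alongside the source and target images, and the target metamer is regenerated only when the target changes.

```python
import torch
from odak.learn.perception import MetamerMSELoss  # assumed import path

loss_fn = MetamerMSELoss(alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7)

image = torch.rand(1, 3, 256, 256, requires_grad=True)   # source image being optimised
target = torch.rand(1, 3, 256, 256)                      # ground truth image

loss = loss_fn(image, target, gaze=[0.5, 0.5])   # gaze at the image centre
loss.backward()
print(loss.item())
```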
"},{"location":"odak/learn_perception/#odak.learn.perception.metamer_mse_loss.MetamerMSELoss.__call__","title":"__call__(image, target, gaze=[0.5, 0.5])","text":"

Calculates the Metamer MSE Loss.

Parameters:

  • image – Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions).
  • target – Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions).
  • gaze – Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.

Returns:

  • loss (tensor) – The computed loss.

Source code in odak/learn/perception/metamer_mse_loss.py
def __call__(self, image, target, gaze=[0.5, 0.5]):\n    \"\"\" \n    Calculates the Metamer MSE Loss.\n\n    Parameters\n    ----------\n    image   : torch.tensor\n            Image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)\n    target  : torch.tensor\n            Ground truth target image to compute loss for. Should be an RGB image in NCHW format (4 dimensions)\n    gaze    : list\n            Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.\n\n    Returns\n    -------\n\n    loss                : torch.tensor\n                            The computed loss.\n    \"\"\"\n    check_loss_inputs(\"MetamerMSELoss\", image, target)\n    # Pad image and target if necessary\n    image = pad_image_for_pyramid(image, self.metameric_loss.n_pyramid_levels)\n    target = pad_image_for_pyramid(target, self.metameric_loss.n_pyramid_levels)\n\n    if target is not self.target or self.target is None:\n        self.target_metamer = self.gen_metamer(target, gaze)\n        self.target = target\n\n    return self.loss_func(image, self.target_metamer)\n
"},{"location":"odak/learn_perception/#odak.learn.perception.metamer_mse_loss.MetamerMSELoss.__init__","title":"__init__(device=torch.device('cpu'), alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode='quadratic', n_pyramid_levels=5, n_orientations=2, equi=False)","text":"

Parameters:

  • alpha – Parameter controlling foveation - larger values mean bigger pooling regions.
  • real_image_width – The real width of the image as displayed to the user. Units don't matter as long as they are the same as for real_viewing_distance.
  • real_viewing_distance – The real distance of the observer's eyes to the image plane. Units don't matter as long as they are the same as for real_image_width.
  • n_pyramid_levels – Number of levels of the steerable pyramid. Note that the image is padded so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value too high will slow down the calculation a lot.
  • mode – Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow as you move away from the fovea. We got best results with "quadratic".
  • n_orientations – Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6. Increasing this will increase runtime.
  • equi – If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular format 360 image. The settings real_image_width and real_viewing_distance are ignored. The gaze argument is instead interpreted as gaze angles, and should be in the range [-pi,pi]x[-pi/2,pi/2].
Source code in odak/learn/perception/metamer_mse_loss.py
def __init__(self, device=torch.device(\"cpu\"),\n             alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, mode=\"quadratic\",\n             n_pyramid_levels=5, n_orientations=2, equi=False):\n    \"\"\"\n    Parameters\n    ----------\n    alpha                   : float\n                                parameter controlling foveation - larger values mean bigger pooling regions.\n    real_image_width        : float \n                                The real width of the image as displayed to the user.\n                                Units don't matter as long as they are the same as for real_viewing_distance.\n    real_viewing_distance   : float \n                                The real distance of the observer's eyes to the image plane.\n                                Units don't matter as long as they are the same as for real_image_width.\n    n_pyramid_levels        : int \n                                Number of levels of the steerable pyramid. Note that the image is padded\n                                so that both height and width are multiples of 2^(n_pyramid_levels), so setting this value\n                                too high will slow down the calculation a lot.\n    mode                    : str \n                                Foveation mode, either \"quadratic\" or \"linear\". Controls how pooling regions grow\n                                as you move away from the fovea. We got best results with \"quadratic\".\n    n_orientations          : int \n                                Number of orientations in the steerable pyramid. Can be 1, 2, 4 or 6.\n                                Increasing this will increase runtime.\n    equi                    : bool\n                                If true, run the loss in equirectangular mode. The input is assumed to be an equirectangular\n                                format 360 image. The settings real_image_width and real_viewing distance are ignored.\n                                The gaze argument is instead interpreted as gaze angles, and should be in the range\n                                [-pi,pi]x[-pi/2,pi]\n    \"\"\"\n    self.target = None\n    self.target_metamer = None\n    self.metameric_loss = MetamericLoss(device=device, alpha=alpha, real_image_width=real_image_width,\n                                        real_viewing_distance=real_viewing_distance,\n                                        n_pyramid_levels=n_pyramid_levels, n_orientations=n_orientations, use_l2_foveal_loss=False, equi=equi)\n    self.loss_func = torch.nn.MSELoss()\n    self.noise = None\n
"},{"location":"odak/learn_perception/#odak.learn.perception.metamer_mse_loss.MetamerMSELoss.gen_metamer","title":"gen_metamer(image, gaze)","text":"

Generates a metamer for an image, following the method in this paper (https://dl.acm.org/doi/abs/10.1145/3450626.3459943). This function can be used on its own to generate a metamer for a desired image.

Parameters:

  • image – Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions).
  • gaze – Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.

Returns:

  • metamer (tensor) – The generated metamer image.

Source code in odak/learn/perception/metamer_mse_loss.py
def gen_metamer(self, image, gaze):\n    \"\"\" \n    Generates a metamer for an image, following the method in [this paper](https://dl.acm.org/doi/abs/10.1145/3450626.3459943)\n    This function can be used on its own to generate a metamer for a desired image.\n\n    Parameters\n    ----------\n    image   : torch.tensor\n            Image to compute metamer for. Should be an RGB image in NCHW format (4 dimensions)\n    gaze    : list\n            Gaze location in the image, in normalized image coordinates (range [0, 1]) relative to the top left of the image.\n\n    Returns\n    -------\n\n    metamer : torch.tensor\n            The generated metamer image\n    \"\"\"\n    image = rgb_2_ycrcb(image)\n    image_size = image.size()\n    image = pad_image_for_pyramid(image, self.metameric_loss.n_pyramid_levels)\n\n    target_stats = self.metameric_loss.calc_statsmaps(\n        image, gaze=gaze, alpha=self.metameric_loss.alpha)\n    target_means = target_stats[::2]\n    target_stdevs = target_stats[1::2]\n    if self.noise is None or self.noise.size() != image.size():\n        torch.manual_seed(0)\n        noise_image = torch.rand_like(image)\n    noise_pyramid = self.metameric_loss.pyramid_maker.construct_pyramid(\n        noise_image, self.metameric_loss.n_pyramid_levels)\n    input_pyramid = self.metameric_loss.pyramid_maker.construct_pyramid(\n        image, self.metameric_loss.n_pyramid_levels)\n\n    def match_level(input_level, target_mean, target_std):\n        level = input_level.clone()\n        level -= torch.mean(level)\n        input_std = torch.sqrt(torch.mean(level * level))\n        eps = 1e-6\n        # Safeguard against divide by zero\n        input_std[input_std < eps] = eps\n        level /= input_std\n        level *= target_std\n        level += target_mean\n        return level\n\n    nbands = len(noise_pyramid[0][\"b\"])\n    noise_pyramid[0][\"h\"] = match_level(\n        noise_pyramid[0][\"h\"], target_means[0], target_stdevs[0])\n    for l in range(len(noise_pyramid)-1):\n        for b in range(nbands):\n            noise_pyramid[l][\"b\"][b] = match_level(\n                noise_pyramid[l][\"b\"][b], target_means[1 + l * nbands + b], target_stdevs[1 + l * nbands + b])\n    noise_pyramid[-1][\"l\"] = input_pyramid[-1][\"l\"]\n\n    metamer = self.metameric_loss.pyramid_maker.reconstruct_from_pyramid(\n        noise_pyramid)\n    metamer = ycrcb_2_rgb(metamer)\n    # Crop to remove any padding\n    metamer = metamer[:image_size[0], :image_size[1], :image_size[2], :image_size[3]]\n    return metamer\n
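As a sketch (same assumptions as above), gen_metamer can be called directly to inspect foveated metamers for several gaze locations:

```python
import torch
from odak.learn.perception import MetamerMSELoss  # assumed import path

loss_fn = MetamerMSELoss()
image = torch.rand(1, 3, 256, 256)   # stand-in for a real RGB photograph

with torch.no_grad():
    metamers = [loss_fn.gen_metamer(image, gaze=g)
                for g in ([0.25, 0.5], [0.5, 0.5], [0.75, 0.5])]
print([m.shape for m in metamers])
```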
"},{"location":"odak/learn_perception/#odak.learn.perception.radially_varying_blur.RadiallyVaryingBlur","title":"RadiallyVaryingBlur","text":"

The RadiallyVaryingBlur class provides a way to apply a radially varying blur to an image. Given a gaze location and information about the image and foveation, it applies a blur that will achieve the proper pooling size. The pooling size is chosen to appear the same at a range of display sizes and viewing distances, for a given alpha parameter value. For more information on how the pooling sizes are computed, please see link coming soon.

The blur is accelerated by generating and sampling from MIP maps of the input image.

This class caches the foveation information. This means that if it is run repeatedly with the same foveation parameters, gaze location and image size (e.g. in an optimisation loop) it won't recalculate the pooling maps.

If you are repeatedly applying blur to images of different sizes (e.g. a pyramid), use one instance of this class per image size for best performance.

Source code in odak/learn/perception/radially_varying_blur.py
class RadiallyVaryingBlur():\n    \"\"\" \n\n    The `RadiallyVaryingBlur` class provides a way to apply a radially varying blur to an image. Given a gaze location and information about the image and foveation, it applies a blur that will achieve the proper pooling size. The pooling size is chosen to appear the same at a range of display sizes and viewing distances, for a given `alpha` parameter value. For more information on how the pooling sizes are computed, please see [link coming soon]().\n\n    The blur is accelerated by generating and sampling from MIP maps of the input image.\n\n    This class caches the foveation information. This means that if it is run repeatedly with the same foveation parameters, gaze location and image size (e.g. in an optimisation loop) it won't recalculate the pooling maps.\n\n    If you are repeatedly applying blur to images of different sizes (e.g. a pyramid) for best performance use one instance of this class per image size.\n\n    \"\"\"\n\n    def __init__(self):\n        self.lod_map = None\n        self.equi = None\n\n    def blur(self, image, alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, centre=None, mode=\"quadratic\", equi=False):\n        \"\"\"\n        Apply the radially varying blur to an image.\n\n        Parameters\n        ----------\n\n        image                   : torch.tensor\n                                    The image to blur, in NCHW format.\n        alpha                   : float\n                                    parameter controlling foveation - larger values mean bigger pooling regions.\n        real_image_width        : float \n                                    The real width of the image as displayed to the user.\n                                    Units don't matter as long as they are the same as for real_viewing_distance.\n                                    Ignored in equirectangular mode (equi==True)\n        real_viewing_distance   : float \n                                    The real distance of the observer's eyes to the image plane.\n                                    Units don't matter as long as they are the same as for real_image_width.\n                                    Ignored in equirectangular mode (equi==True)\n        centre                  : tuple of floats\n                                    The centre of the radially varying blur (the gaze location).\n                                    Should be a tuple of floats containing normalised image coordinates in range [0,1]\n                                    In equirectangular mode this should be yaw & pitch angles in [-pi,pi]x[-pi/2,pi/2]\n        mode                    : str \n                                    Foveation mode, either \"quadratic\" or \"linear\". Controls how pooling regions grow\n                                    as you move away from the fovea. We got best results with \"quadratic\".\n        equi                    : bool\n                                    If true, run the blur function in equirectangular mode. The input is assumed to be an equirectangular\n                                    format 360 image. 
The settings real_image_width and real_viewing distance are ignored.\n                                    The centre argument is instead interpreted as gaze angles, and should be in the range\n                                    [-pi,pi]x[-pi/2,pi]\n\n        Returns\n        -------\n\n        output                  : torch.tensor\n                                    The blurred image\n        \"\"\"\n        size = (image.size(-2), image.size(-1))\n\n        # LOD map caching\n        if self.lod_map is None or\\\n                self.size != size or\\\n                self.n_channels != image.size(1) or\\\n                self.alpha != alpha or\\\n                self.real_image_width != real_image_width or\\\n                self.real_viewing_distance != real_viewing_distance or\\\n                self.centre != centre or\\\n                self.mode != mode or\\\n                self.equi != equi:\n            if not equi:\n                self.lod_map = make_pooling_size_map_lod(\n                    centre, (image.size(-2), image.size(-1)), alpha, real_image_width, real_viewing_distance, mode)\n            else:\n                self.lod_map = make_equi_pooling_size_map_lod(\n                    centre, (image.size(-2), image.size(-1)), alpha, mode)\n            self.size = size\n            self.n_channels = image.size(1)\n            self.alpha = alpha\n            self.real_image_width = real_image_width\n            self.real_viewing_distance = real_viewing_distance\n            self.centre = centre\n            self.lod_map = self.lod_map.to(image.device)\n            self.lod_fraction = torch.fmod(self.lod_map, 1.0)\n            self.lod_fraction = self.lod_fraction[None, None, ...].repeat(\n                1, image.size(1), 1, 1)\n            self.mode = mode\n            self.equi = equi\n\n        if self.lod_map.device != image.device:\n            self.lod_map = self.lod_map.to(image.device)\n        if self.lod_fraction.device != image.device:\n            self.lod_fraction = self.lod_fraction.to(image.device)\n\n        mipmap = [image]\n        while mipmap[-1].size(-1) > 1 and mipmap[-1].size(-2) > 1:\n            mipmap.append(torch.nn.functional.interpolate(\n                mipmap[-1], scale_factor=0.5, mode=\"area\", recompute_scale_factor=False))\n        if mipmap[-1].size(-1) == 2:\n            final_mip = torch.mean(mipmap[-1], axis=-1)[..., None]\n            mipmap.append(final_mip)\n        if mipmap[-1].size(-2) == 2:\n            final_mip = torch.mean(mipmap[-2], axis=-2)[..., None, :]\n            mipmap.append(final_mip)\n\n        for l in range(len(mipmap)):\n            if l == len(mipmap)-1:\n                mipmap[l] = mipmap[l] * \\\n                    torch.ones(image.size(), device=image.device)\n            else:\n                for l2 in range(l-1, -1, -1):\n                    mipmap[l] = torch.nn.functional.interpolate(mipmap[l], size=(\n                        image.size(-2), image.size(-1)), mode=\"bilinear\", align_corners=False, recompute_scale_factor=False)\n\n        output = torch.zeros(image.size(), device=image.device)\n        for l in range(len(mipmap)):\n            if l == 0:\n                mask = self.lod_map < (l+1)\n            elif l == len(mipmap)-1:\n                mask = self.lod_map >= l\n            else:\n                mask = torch.logical_and(\n                    self.lod_map >= l, self.lod_map < (l+1))\n\n            if l == len(mipmap)-1:\n                blended_levels = mipmap[l]\n            else:\n     
           blended_levels = (1 - self.lod_fraction) * \\\n                    mipmap[l] + self.lod_fraction*mipmap[l+1]\n            mask = mask[None, None, ...]\n            mask = mask.repeat(1, image.size(1), 1, 1)\n            output[mask] = blended_levels[mask]\n\n        return output\n
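A minimal sketch, assuming RadiallyVaryingBlur is importable from odak.learn.perception: blur a single NCHW image with the gaze at the centre. Re-running with the same parameters and image size reuses the cached pooling maps, as described above.

```python
import torch
from odak.learn.perception import RadiallyVaryingBlur  # assumed import path

blur = RadiallyVaryingBlur()
image = torch.rand(1, 3, 512, 512)   # stand-in for a real image, NCHW

blurred = blur.blur(image, alpha=0.2,
                    real_image_width=0.2, real_viewing_distance=0.7,
                    centre=(0.5, 0.5), mode="quadratic")

# Same parameters and size: the cached LOD map is reused rather than recomputed.
blurred_again = blur.blur(image, centre=(0.5, 0.5))
```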
"},{"location":"odak/learn_perception/#odak.learn.perception.radially_varying_blur.RadiallyVaryingBlur.blur","title":"blur(image, alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, centre=None, mode='quadratic', equi=False)","text":"

Apply the radially varying blur to an image.

Parameters:

  • image – The image to blur, in NCHW format.
  • alpha – Parameter controlling foveation - larger values mean bigger pooling regions.
  • real_image_width – The real width of the image as displayed to the user. Units don't matter as long as they are the same as for real_viewing_distance. Ignored in equirectangular mode (equi==True).
  • real_viewing_distance – The real distance of the observer's eyes to the image plane. Units don't matter as long as they are the same as for real_image_width. Ignored in equirectangular mode (equi==True).
  • centre – The centre of the radially varying blur (the gaze location). Should be a tuple of floats containing normalised image coordinates in range [0,1]. In equirectangular mode this should be yaw & pitch angles in [-pi,pi]x[-pi/2,pi/2].
  • mode – Foveation mode, either "quadratic" or "linear". Controls how pooling regions grow as you move away from the fovea. We got best results with "quadratic".
  • equi – If true, run the blur function in equirectangular mode. The input is assumed to be an equirectangular format 360 image. The settings real_image_width and real_viewing_distance are ignored. The centre argument is instead interpreted as gaze angles, and should be in the range [-pi,pi]x[-pi/2,pi/2].

Returns:

  • output (tensor) – The blurred image.

Source code in odak/learn/perception/radially_varying_blur.py
def blur(self, image, alpha=0.2, real_image_width=0.2, real_viewing_distance=0.7, centre=None, mode=\"quadratic\", equi=False):\n    \"\"\"\n    Apply the radially varying blur to an image.\n\n    Parameters\n    ----------\n\n    image                   : torch.tensor\n                                The image to blur, in NCHW format.\n    alpha                   : float\n                                parameter controlling foveation - larger values mean bigger pooling regions.\n    real_image_width        : float \n                                The real width of the image as displayed to the user.\n                                Units don't matter as long as they are the same as for real_viewing_distance.\n                                Ignored in equirectangular mode (equi==True)\n    real_viewing_distance   : float \n                                The real distance of the observer's eyes to the image plane.\n                                Units don't matter as long as they are the same as for real_image_width.\n                                Ignored in equirectangular mode (equi==True)\n    centre                  : tuple of floats\n                                The centre of the radially varying blur (the gaze location).\n                                Should be a tuple of floats containing normalised image coordinates in range [0,1]\n                                In equirectangular mode this should be yaw & pitch angles in [-pi,pi]x[-pi/2,pi/2]\n    mode                    : str \n                                Foveation mode, either \"quadratic\" or \"linear\". Controls how pooling regions grow\n                                as you move away from the fovea. We got best results with \"quadratic\".\n    equi                    : bool\n                                If true, run the blur function in equirectangular mode. The input is assumed to be an equirectangular\n                                format 360 image. 
The settings real_image_width and real_viewing distance are ignored.\n                                The centre argument is instead interpreted as gaze angles, and should be in the range\n                                [-pi,pi]x[-pi/2,pi]\n\n    Returns\n    -------\n\n    output                  : torch.tensor\n                                The blurred image\n    \"\"\"\n    size = (image.size(-2), image.size(-1))\n\n    # LOD map caching\n    if self.lod_map is None or\\\n            self.size != size or\\\n            self.n_channels != image.size(1) or\\\n            self.alpha != alpha or\\\n            self.real_image_width != real_image_width or\\\n            self.real_viewing_distance != real_viewing_distance or\\\n            self.centre != centre or\\\n            self.mode != mode or\\\n            self.equi != equi:\n        if not equi:\n            self.lod_map = make_pooling_size_map_lod(\n                centre, (image.size(-2), image.size(-1)), alpha, real_image_width, real_viewing_distance, mode)\n        else:\n            self.lod_map = make_equi_pooling_size_map_lod(\n                centre, (image.size(-2), image.size(-1)), alpha, mode)\n        self.size = size\n        self.n_channels = image.size(1)\n        self.alpha = alpha\n        self.real_image_width = real_image_width\n        self.real_viewing_distance = real_viewing_distance\n        self.centre = centre\n        self.lod_map = self.lod_map.to(image.device)\n        self.lod_fraction = torch.fmod(self.lod_map, 1.0)\n        self.lod_fraction = self.lod_fraction[None, None, ...].repeat(\n            1, image.size(1), 1, 1)\n        self.mode = mode\n        self.equi = equi\n\n    if self.lod_map.device != image.device:\n        self.lod_map = self.lod_map.to(image.device)\n    if self.lod_fraction.device != image.device:\n        self.lod_fraction = self.lod_fraction.to(image.device)\n\n    mipmap = [image]\n    while mipmap[-1].size(-1) > 1 and mipmap[-1].size(-2) > 1:\n        mipmap.append(torch.nn.functional.interpolate(\n            mipmap[-1], scale_factor=0.5, mode=\"area\", recompute_scale_factor=False))\n    if mipmap[-1].size(-1) == 2:\n        final_mip = torch.mean(mipmap[-1], axis=-1)[..., None]\n        mipmap.append(final_mip)\n    if mipmap[-1].size(-2) == 2:\n        final_mip = torch.mean(mipmap[-2], axis=-2)[..., None, :]\n        mipmap.append(final_mip)\n\n    for l in range(len(mipmap)):\n        if l == len(mipmap)-1:\n            mipmap[l] = mipmap[l] * \\\n                torch.ones(image.size(), device=image.device)\n        else:\n            for l2 in range(l-1, -1, -1):\n                mipmap[l] = torch.nn.functional.interpolate(mipmap[l], size=(\n                    image.size(-2), image.size(-1)), mode=\"bilinear\", align_corners=False, recompute_scale_factor=False)\n\n    output = torch.zeros(image.size(), device=image.device)\n    for l in range(len(mipmap)):\n        if l == 0:\n            mask = self.lod_map < (l+1)\n        elif l == len(mipmap)-1:\n            mask = self.lod_map >= l\n        else:\n            mask = torch.logical_and(\n                self.lod_map >= l, self.lod_map < (l+1))\n\n        if l == len(mipmap)-1:\n            blended_levels = mipmap[l]\n        else:\n            blended_levels = (1 - self.lod_fraction) * \\\n                mipmap[l] + self.lod_fraction*mipmap[l+1]\n        mask = mask[None, None, ...]\n        mask = mask.repeat(1, image.size(1), 1, 1)\n        output[mask] = blended_levels[mask]\n\n    return output\n
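Since the pooling maps are cached per image size (see the class description above), a sketch, under the same import-path assumption, that keeps one blur instance per size:

```python
import torch
from odak.learn.perception import RadiallyVaryingBlur  # assumed import path

sizes = [(512, 512), (256, 256), (128, 128)]
blurs = {size: RadiallyVaryingBlur() for size in sizes}   # one instance per image size

for size in sizes:
    image = torch.rand(1, 3, *size)
    output = blurs[size].blur(image, alpha=0.2, centre=(0.5, 0.5))
    print(size, output.shape)
```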
"},{"location":"odak/learn_perception/#odak.learn.perception.spatial_steerable_pyramid.SpatialSteerablePyramid","title":"SpatialSteerablePyramid","text":"

This implements a real-valued steerable pyramid where the filtering is carried out spatially (using convolution) as opposed to multiplication in the Fourier domain. This has a number of optimisations over previous implementations that increase efficiency, but introduce some reconstruction error.

Source code in odak/learn/perception/spatial_steerable_pyramid.py
class SpatialSteerablePyramid():\n    \"\"\"\n    This implements a real-valued steerable pyramid where the filtering is carried out spatially (using convolution)\n    as opposed to multiplication in the Fourier domain.\n    This has a number of optimisations over previous implementations that increase efficiency, but introduce some\n    reconstruction error.\n    \"\"\"\n\n\n    def __init__(self, use_bilinear_downup=True, n_channels=1,\n                 filter_size=9, n_orientations=6, filter_type=\"full\",\n                 device=torch.device('cpu')):\n        \"\"\"\n        Parameters\n        ----------\n\n        use_bilinear_downup     : bool\n                                    This uses bilinear filtering when upsampling/downsampling, rather than the original approach\n                                    of applying a large lowpass kernel and sampling even rows/columns\n        n_channels              : int\n                                    Number of channels in the input images (e.g. 3 for RGB input)\n        filter_size             : int\n                                    Desired size of filters (e.g. 3 will use 3x3 filters).\n        n_orientations          : int\n                                    Number of oriented bands in each level of the pyramid.\n        filter_type             : str\n                                    This can be used to select smaller filters than the original ones if desired.\n                                    full: Original filter sizes\n                                    cropped: Some filters are cut back in size by extracting the centre and scaling as appropriate.\n                                    trained: Same as reduced, but the oriented kernels are replaced by learned 5x5 kernels.\n        device                  : torch.device\n                                    torch device the input images will be supplied from.\n        \"\"\"\n        self.use_bilinear_downup = use_bilinear_downup\n        self.device = device\n\n        filters = get_steerable_pyramid_filters(\n            filter_size, n_orientations, filter_type)\n\n        def make_pad(filter):\n            filter_size = filter.size(-1)\n            pad_amt = (filter_size-1) // 2\n            return torch.nn.ReflectionPad2d((pad_amt, pad_amt, pad_amt, pad_amt))\n\n        if not self.use_bilinear_downup:\n            self.filt_l = filters[\"l\"].to(device)\n            self.pad_l = make_pad(self.filt_l)\n        self.filt_l0 = filters[\"l0\"].to(device)\n        self.pad_l0 = make_pad(self.filt_l0)\n        self.filt_h0 = filters[\"h0\"].to(device)\n        self.pad_h0 = make_pad(self.filt_h0)\n        for b in range(len(filters[\"b\"])):\n            filters[\"b\"][b] = filters[\"b\"][b].to(device)\n        self.band_filters = filters[\"b\"]\n        self.pad_b = make_pad(self.band_filters[0])\n\n        if n_channels != 1:\n            def add_channels_to_filter(filter):\n                padded = torch.zeros(n_channels, n_channels, filter.size()[\n                                     2], filter.size()[3]).to(device)\n                for channel in range(n_channels):\n                    padded[channel, channel, :, :] = filter\n                return padded\n            self.filt_h0 = add_channels_to_filter(self.filt_h0)\n            for b in range(len(self.band_filters)):\n                self.band_filters[b] = add_channels_to_filter(\n                    self.band_filters[b])\n            self.filt_l0 = add_channels_to_filter(self.filt_l0)\n            if not 
self.use_bilinear_downup:\n                self.filt_l = add_channels_to_filter(self.filt_l)\n\n    def construct_pyramid(self, image, n_levels, multiple_highpass=False):\n        \"\"\"\n        Constructs and returns a steerable pyramid for the provided image.\n\n        Parameters\n        ----------\n\n        image               : torch.tensor\n                                The input image, in NCHW format. The number of channels C should match num_channels\n                                when the pyramid maker was created.\n        n_levels            : int\n                                Number of levels in the constructed steerable pyramid.\n        multiple_highpass   : bool\n                                If true, computes a highpass for each level of the pyramid.\n                                These extra levels are redundant (not used for reconstruction).\n\n        Returns\n        -------\n\n        pyramid             : list of dicts of torch.tensor\n                                The computed steerable pyramid.\n                                Each level is an entry in a list. The pyramid is ordered from largest levels to smallest levels.\n                                Each level is stored as a dict, with the following keys:\n                                \"h\" Highpass residual\n                                \"l\" Lowpass residual\n                                \"b\" Oriented bands (a list of torch.tensor)\n        \"\"\"\n        pyramid = []\n\n        # Make level 0, containing highpass, lowpass and the bands\n        level0 = {}\n        level0['h'] = torch.nn.functional.conv2d(\n            self.pad_h0(image), self.filt_h0)\n        lowpass = torch.nn.functional.conv2d(self.pad_l0(image), self.filt_l0)\n        level0['l'] = lowpass.clone()\n        bands = []\n        for filt_b in self.band_filters:\n            bands.append(torch.nn.functional.conv2d(\n                self.pad_b(lowpass), filt_b))\n        level0['b'] = bands\n        pyramid.append(level0)\n\n        # Make intermediate levels\n        for l in range(n_levels-2):\n            level = {}\n            if self.use_bilinear_downup:\n                lowpass = torch.nn.functional.interpolate(\n                    lowpass, scale_factor=0.5, mode=\"area\", recompute_scale_factor=False)\n            else:\n                lowpass = torch.nn.functional.conv2d(\n                    self.pad_l(lowpass), self.filt_l)\n                lowpass = lowpass[:, :, ::2, ::2]\n            level['l'] = lowpass.clone()\n            bands = []\n            for filt_b in self.band_filters:\n                bands.append(torch.nn.functional.conv2d(\n                    self.pad_b(lowpass), filt_b))\n            level['b'] = bands\n            if multiple_highpass:\n                level['h'] = torch.nn.functional.conv2d(\n                    self.pad_h0(lowpass), self.filt_h0)\n            pyramid.append(level)\n\n        # Make final level (lowpass residual)\n        level = {}\n        if self.use_bilinear_downup:\n            lowpass = torch.nn.functional.interpolate(\n                lowpass, scale_factor=0.5, mode=\"area\", recompute_scale_factor=False)\n        else:\n            lowpass = torch.nn.functional.conv2d(\n                self.pad_l(lowpass), self.filt_l)\n            lowpass = lowpass[:, :, ::2, ::2]\n        level['l'] = lowpass\n        pyramid.append(level)\n\n        return pyramid\n\n    def reconstruct_from_pyramid(self, pyramid):\n        \"\"\"\n        Reconstructs an input image 
from a steerable pyramid.\n\n        Parameters\n        ----------\n\n        pyramid : list of dicts of torch.tensor\n                    The steerable pyramid.\n                    Should be in the same format as output by construct_steerable_pyramid().\n                    The number of channels should match num_channels when the pyramid maker was created.\n\n        Returns\n        -------\n\n        image   : torch.tensor\n                    The reconstructed image, in NCHW format.         \n        \"\"\"\n        def upsample(image, size):\n            if self.use_bilinear_downup:\n                return torch.nn.functional.interpolate(image, size=size, mode=\"bilinear\", align_corners=False, recompute_scale_factor=False)\n            else:\n                zeros = torch.zeros((image.size()[0], image.size()[1], image.size()[\n                                    2]*2, image.size()[3]*2)).to(self.device)\n                zeros[:, :, ::2, ::2] = image\n                zeros = torch.nn.functional.conv2d(\n                    self.pad_l(zeros), self.filt_l)\n                return zeros\n\n        image = pyramid[-1]['l']\n        for level in reversed(pyramid[:-1]):\n            image = upsample(image, level['b'][0].size()[2:])\n            for b in range(len(level['b'])):\n                b_filtered = torch.nn.functional.conv2d(\n                    self.pad_b(level['b'][b]), -self.band_filters[b])\n                image += b_filtered\n\n        image = torch.nn.functional.conv2d(self.pad_l0(image), self.filt_l0)\n        image += torch.nn.functional.conv2d(\n            self.pad_h0(pyramid[0]['h']), self.filt_h0)\n\n        return image\n
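A round-trip sketch (assuming the class is importable from odak.learn.perception): build a pyramid, inspect its structure, and reconstruct. As noted above, a small reconstruction error is expected.

```python
import torch
from odak.learn.perception import SpatialSteerablePyramid  # assumed import path

pyramid_maker = SpatialSteerablePyramid(use_bilinear_downup=True, n_channels=1,
                                        filter_size=9, n_orientations=6, filter_type="full")

image = torch.rand(1, 1, 256, 256)
pyramid = pyramid_maker.construct_pyramid(image, n_levels=5)

# Level 0 holds the highpass residual 'h' and oriented bands 'b'; intermediate levels
# hold 'b' and a lowpass 'l'; the final level holds only the lowpass residual 'l'.
print(len(pyramid), pyramid[0]['h'].shape, len(pyramid[0]['b']), pyramid[-1]['l'].shape)

reconstruction = pyramid_maker.reconstruct_from_pyramid(pyramid)
print(torch.mean((reconstruction - image) ** 2))   # small but non-zero error
```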
"},{"location":"odak/learn_perception/#odak.learn.perception.spatial_steerable_pyramid.SpatialSteerablePyramid.__init__","title":"__init__(use_bilinear_downup=True, n_channels=1, filter_size=9, n_orientations=6, filter_type='full', device=torch.device('cpu'))","text":"

Parameters:

  • use_bilinear_downup – This uses bilinear filtering when upsampling/downsampling, rather than the original approach of applying a large lowpass kernel and sampling even rows/columns.
  • n_channels – Number of channels in the input images (e.g. 3 for RGB input).
  • filter_size – Desired size of filters (e.g. 3 will use 3x3 filters).
  • n_orientations – Number of oriented bands in each level of the pyramid.
  • filter_type – This can be used to select smaller filters than the original ones if desired. full: Original filter sizes. cropped: Some filters are cut back in size by extracting the centre and scaling as appropriate. trained: Same as cropped, but the oriented kernels are replaced by learned 5x5 kernels.
  • device – torch device the input images will be supplied from.
Source code in odak/learn/perception/spatial_steerable_pyramid.py
def __init__(self, use_bilinear_downup=True, n_channels=1,\n             filter_size=9, n_orientations=6, filter_type=\"full\",\n             device=torch.device('cpu')):\n    \"\"\"\n    Parameters\n    ----------\n\n    use_bilinear_downup     : bool\n                                This uses bilinear filtering when upsampling/downsampling, rather than the original approach\n                                of applying a large lowpass kernel and sampling even rows/columns\n    n_channels              : int\n                                Number of channels in the input images (e.g. 3 for RGB input)\n    filter_size             : int\n                                Desired size of filters (e.g. 3 will use 3x3 filters).\n    n_orientations          : int\n                                Number of oriented bands in each level of the pyramid.\n    filter_type             : str\n                                This can be used to select smaller filters than the original ones if desired.\n                                full: Original filter sizes\n                                cropped: Some filters are cut back in size by extracting the centre and scaling as appropriate.\n                                trained: Same as reduced, but the oriented kernels are replaced by learned 5x5 kernels.\n    device                  : torch.device\n                                torch device the input images will be supplied from.\n    \"\"\"\n    self.use_bilinear_downup = use_bilinear_downup\n    self.device = device\n\n    filters = get_steerable_pyramid_filters(\n        filter_size, n_orientations, filter_type)\n\n    def make_pad(filter):\n        filter_size = filter.size(-1)\n        pad_amt = (filter_size-1) // 2\n        return torch.nn.ReflectionPad2d((pad_amt, pad_amt, pad_amt, pad_amt))\n\n    if not self.use_bilinear_downup:\n        self.filt_l = filters[\"l\"].to(device)\n        self.pad_l = make_pad(self.filt_l)\n    self.filt_l0 = filters[\"l0\"].to(device)\n    self.pad_l0 = make_pad(self.filt_l0)\n    self.filt_h0 = filters[\"h0\"].to(device)\n    self.pad_h0 = make_pad(self.filt_h0)\n    for b in range(len(filters[\"b\"])):\n        filters[\"b\"][b] = filters[\"b\"][b].to(device)\n    self.band_filters = filters[\"b\"]\n    self.pad_b = make_pad(self.band_filters[0])\n\n    if n_channels != 1:\n        def add_channels_to_filter(filter):\n            padded = torch.zeros(n_channels, n_channels, filter.size()[\n                                 2], filter.size()[3]).to(device)\n            for channel in range(n_channels):\n                padded[channel, channel, :, :] = filter\n            return padded\n        self.filt_h0 = add_channels_to_filter(self.filt_h0)\n        for b in range(len(self.band_filters)):\n            self.band_filters[b] = add_channels_to_filter(\n                self.band_filters[b])\n        self.filt_l0 = add_channels_to_filter(self.filt_l0)\n        if not self.use_bilinear_downup:\n            self.filt_l = add_channels_to_filter(self.filt_l)\n
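A minimal construction sketch (not part of the odak documentation), using the module path shown above; the argument values are illustrative and simply echo the defaults apart from the channel count.

import torch
from odak.learn.perception.spatial_steerable_pyramid import SpatialSteerablePyramid

# Pyramid maker for 3-channel (RGB) images on the CPU.
pyramid_maker = SpatialSteerablePyramid(
    use_bilinear_downup=True,   # bilinear down/upsampling instead of the large lowpass kernel
    n_channels=3,
    filter_size=9,
    n_orientations=6,
    filter_type='full',
    device=torch.device('cpu')
)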
"},{"location":"odak/learn_perception/#odak.learn.perception.spatial_steerable_pyramid.SpatialSteerablePyramid.construct_pyramid","title":"construct_pyramid(image, n_levels, multiple_highpass=False)","text":"

Constructs and returns a steerable pyramid for the provided image.

Parameters:

  • image \u2013
                    The input image, in NCHW format. The number of channels C should match num_channels when the pyramid maker was created.
  • n_levels \u2013
                    Number of levels in the constructed steerable pyramid.
  • multiple_highpass \u2013
                    If true, computes a highpass for each level of the pyramid. These extra levels are redundant (not used for reconstruction).

Returns:

  • pyramid ( list of dicts of torch.tensor ) \u2013

    The computed steerable pyramid. Each level is an entry in a list. The pyramid is ordered from largest levels to smallest levels. Each level is stored as a dict, with the following keys: "h" (highpass residual), "l" (lowpass residual), "b" (oriented bands, a list of torch.tensor).

Source code in odak/learn/perception/spatial_steerable_pyramid.py
def construct_pyramid(self, image, n_levels, multiple_highpass=False):\n    \"\"\"\n    Constructs and returns a steerable pyramid for the provided image.\n\n    Parameters\n    ----------\n\n    image               : torch.tensor\n                            The input image, in NCHW format. The number of channels C should match num_channels\n                            when the pyramid maker was created.\n    n_levels            : int\n                            Number of levels in the constructed steerable pyramid.\n    multiple_highpass   : bool\n                            If true, computes a highpass for each level of the pyramid.\n                            These extra levels are redundant (not used for reconstruction).\n\n    Returns\n    -------\n\n    pyramid             : list of dicts of torch.tensor\n                            The computed steerable pyramid.\n                            Each level is an entry in a list. The pyramid is ordered from largest levels to smallest levels.\n                            Each level is stored as a dict, with the following keys:\n                            \"h\" Highpass residual\n                            \"l\" Lowpass residual\n                            \"b\" Oriented bands (a list of torch.tensor)\n    \"\"\"\n    pyramid = []\n\n    # Make level 0, containing highpass, lowpass and the bands\n    level0 = {}\n    level0['h'] = torch.nn.functional.conv2d(\n        self.pad_h0(image), self.filt_h0)\n    lowpass = torch.nn.functional.conv2d(self.pad_l0(image), self.filt_l0)\n    level0['l'] = lowpass.clone()\n    bands = []\n    for filt_b in self.band_filters:\n        bands.append(torch.nn.functional.conv2d(\n            self.pad_b(lowpass), filt_b))\n    level0['b'] = bands\n    pyramid.append(level0)\n\n    # Make intermediate levels\n    for l in range(n_levels-2):\n        level = {}\n        if self.use_bilinear_downup:\n            lowpass = torch.nn.functional.interpolate(\n                lowpass, scale_factor=0.5, mode=\"area\", recompute_scale_factor=False)\n        else:\n            lowpass = torch.nn.functional.conv2d(\n                self.pad_l(lowpass), self.filt_l)\n            lowpass = lowpass[:, :, ::2, ::2]\n        level['l'] = lowpass.clone()\n        bands = []\n        for filt_b in self.band_filters:\n            bands.append(torch.nn.functional.conv2d(\n                self.pad_b(lowpass), filt_b))\n        level['b'] = bands\n        if multiple_highpass:\n            level['h'] = torch.nn.functional.conv2d(\n                self.pad_h0(lowpass), self.filt_h0)\n        pyramid.append(level)\n\n    # Make final level (lowpass residual)\n    level = {}\n    if self.use_bilinear_downup:\n        lowpass = torch.nn.functional.interpolate(\n            lowpass, scale_factor=0.5, mode=\"area\", recompute_scale_factor=False)\n    else:\n        lowpass = torch.nn.functional.conv2d(\n            self.pad_l(lowpass), self.filt_l)\n        lowpass = lowpass[:, :, ::2, ::2]\n    level['l'] = lowpass\n    pyramid.append(level)\n\n    return pyramid\n
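A usage sketch (assumed, not taken from the odak docs): it builds a five-level pyramid from a random single-channel image whose sides are divisible by 2**5, then inspects the first level; the tensor sizes are illustrative.

import torch
from odak.learn.perception.spatial_steerable_pyramid import SpatialSteerablePyramid

pyramid_maker = SpatialSteerablePyramid(n_channels=1)
image = torch.rand(1, 1, 256, 256)                  # NCHW; 256 is divisible by 2**5
pyramid = pyramid_maker.construct_pyramid(image, n_levels=5)
print(len(pyramid))                                 # 5 levels, ordered largest to smallest
print(sorted(pyramid[0].keys()))                    # ['b', 'h', 'l'] for the first level
print(len(pyramid[0]['b']))                         # 6 oriented bands with the default settings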
"},{"location":"odak/learn_perception/#odak.learn.perception.spatial_steerable_pyramid.SpatialSteerablePyramid.reconstruct_from_pyramid","title":"reconstruct_from_pyramid(pyramid)","text":"

Reconstructs an input image from a steerable pyramid.

Parameters:

  • pyramid (list of dicts of torch.tensor) \u2013
        The steerable pyramid. Should be in the same format as output by construct_pyramid(). The number of channels should match num_channels when the pyramid maker was created.

Returns:

  • image ( tensor ) \u2013

    The reconstructed image, in NCHW format.

Source code in odak/learn/perception/spatial_steerable_pyramid.py
def reconstruct_from_pyramid(self, pyramid):\n    \"\"\"\n    Reconstructs an input image from a steerable pyramid.\n\n    Parameters\n    ----------\n\n    pyramid : list of dicts of torch.tensor\n                The steerable pyramid.\n                Should be in the same format as output by construct_steerable_pyramid().\n                The number of channels should match num_channels when the pyramid maker was created.\n\n    Returns\n    -------\n\n    image   : torch.tensor\n                The reconstructed image, in NCHW format.         \n    \"\"\"\n    def upsample(image, size):\n        if self.use_bilinear_downup:\n            return torch.nn.functional.interpolate(image, size=size, mode=\"bilinear\", align_corners=False, recompute_scale_factor=False)\n        else:\n            zeros = torch.zeros((image.size()[0], image.size()[1], image.size()[\n                                2]*2, image.size()[3]*2)).to(self.device)\n            zeros[:, :, ::2, ::2] = image\n            zeros = torch.nn.functional.conv2d(\n                self.pad_l(zeros), self.filt_l)\n            return zeros\n\n    image = pyramid[-1]['l']\n    for level in reversed(pyramid[:-1]):\n        image = upsample(image, level['b'][0].size()[2:])\n        for b in range(len(level['b'])):\n            b_filtered = torch.nn.functional.conv2d(\n                self.pad_b(level['b'][b]), -self.band_filters[b])\n            image += b_filtered\n\n    image = torch.nn.functional.conv2d(self.pad_l0(image), self.filt_l0)\n    image += torch.nn.functional.conv2d(\n        self.pad_h0(pyramid[0]['h']), self.filt_h0)\n\n    return image\n
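A round-trip sketch (assumed, not from the odak docs): it constructs a pyramid and feeds it straight back into reconstruct_from_pyramid. The comparison at the end is only indicative, since boundary handling and resampling mean the reconstruction is approximate rather than exact.

import torch
from odak.learn.perception.spatial_steerable_pyramid import SpatialSteerablePyramid

pyramid_maker = SpatialSteerablePyramid(n_channels=1)
image = torch.rand(1, 1, 128, 128)
pyramid = pyramid_maker.construct_pyramid(image, n_levels=4)
reconstruction = pyramid_maker.reconstruct_from_pyramid(pyramid)
# Compare against the input; perfect reconstruction is not guaranteed.
print((reconstruction - image).abs().mean())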
"},{"location":"odak/learn_perception/#odak.learn.perception.spatial_steerable_pyramid.pad_image_for_pyramid","title":"pad_image_for_pyramid(image, n_pyramid_levels)","text":"

Pads an image to the extent necessary to compute a steerable pyramid of the input image. This involves padding so both height and width are divisible by 2**n_pyramid_levels. Uses reflection padding.

Parameters:

  • image \u2013

    Image to pad, in NCHW format

  • n_pyramid_levels \u2013

    Number of levels in the pyramid you plan to construct.

Source code in odak/learn/perception/spatial_steerable_pyramid.py
def pad_image_for_pyramid(image, n_pyramid_levels):\n    \"\"\"\n    Pads an image to the extent necessary to compute a steerable pyramid of the input image.\n    This involves padding so both height and width are divisible by 2**n_pyramid_levels.\n    Uses reflection padding.\n\n    Parameters\n    ----------\n\n    image: torch.tensor\n        Image to pad, in NCHW format\n    n_pyramid_levels: int\n        Number of levels in the pyramid you plan to construct.\n    \"\"\"\n    min_divisor = 2 ** n_pyramid_levels\n    height = image.size(2)\n    width = image.size(3)\n    required_height = math.ceil(height / min_divisor) * min_divisor\n    required_width = math.ceil(width / min_divisor) * min_divisor\n    if required_height > height or required_width > width:\n        # We need to pad! ReflectionPad2d expects (left, right, top, bottom).\n        pad = torch.nn.ReflectionPad2d(\n            (0, required_width-width, 0, required_height-height))\n        return pad(image)\n    return image\n
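A padding sketch (assumed, not from the odak docs): an image whose sides are not divisible by 2**n_pyramid_levels is reflection-padded before the pyramid is constructed; the sizes and variable names are illustrative.

import torch
from odak.learn.perception.spatial_steerable_pyramid import (
    SpatialSteerablePyramid, pad_image_for_pyramid)

image = torch.rand(1, 1, 250, 300)           # neither side is divisible by 2**5
padded = pad_image_for_pyramid(image, 5)     # reflection padding, as described above
pyramid_maker = SpatialSteerablePyramid(n_channels=1)
pyramid = pyramid_maker.construct_pyramid(padded, n_levels=5)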
"},{"location":"odak/learn_perception/#odak.learn.perception.steerable_pyramid_filters.crop_steerable_pyramid_filters","title":"crop_steerable_pyramid_filters(filters, size)","text":"

Given original 9x9 NYU filters, this crops them to the desired size. The size must be an odd number >= 3. Note this only crops the h0, l0 and band filters (not the l downsampling filter).

Parameters:

  • filters \u2013
            Filters to crop (should be in the format used by get_steerable_pyramid_filters).
  • size \u2013
            Size to crop to. For example, an input of 3 will crop the filters to a size of 3x3.

Returns:

  • filters ( dict of torch.tensor ) \u2013

    The cropped filters.

Source code in odak/learn/perception/steerable_pyramid_filters.py
def crop_steerable_pyramid_filters(filters, size):\n    \"\"\"\n    Given original 9x9 NYU filters, this crops them to the desired size.\n    The size must be an odd number >= 3\n    Note this only crops the h0, l0 and band filters (not the l downsampling filter)\n\n    Parameters\n    ----------\n    filters     : dict of torch.tensor\n                    Filters to crop (should in format used by get_steerable_pyramid_filters.)\n    size        : int\n                    Size to crop to. For example, an input of 3 will crop the filters to a size of 3x3.\n\n    Returns\n    -------\n    filters     : dict of torch.tensor\n                    The cropped filters.\n    \"\"\"\n    assert(size >= 3)\n    assert(size % 2 == 1)\n    r = (size-1) // 2\n\n    def crop_filter(filter, r, normalise=True):\n        r2 = (filter.size(-1)-1)//2\n        filter = filter[:, :, r2-r:r2+r+1, r2-r:r2+r+1]\n        if normalise:\n            filter -= torch.sum(filter)\n        return filter\n\n    filters[\"h0\"] = crop_filter(filters[\"h0\"], r, normalise=False)\n    sum_l = torch.sum(filters[\"l\"])\n    filters[\"l\"] = crop_filter(filters[\"l\"], 6, normalise=False)\n    filters[\"l\"] *= sum_l / torch.sum(filters[\"l\"])\n    sum_l0 = torch.sum(filters[\"l0\"])\n    filters[\"l0\"] = crop_filter(filters[\"l0\"], 2, normalise=False)\n    filters[\"l0\"] *= sum_l0 / torch.sum(filters[\"l0\"])\n    for b in range(len(filters[\"b\"])):\n        filters[\"b\"][b] = crop_filter(filters[\"b\"][b], r, normalise=True)\n    return filters\n
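A cropping sketch (assumed, not from the odak docs): it fetches the full-size filters for four orientations and crops the residual and band filters down to 5x5; the chosen sizes are illustrative.

from odak.learn.perception.steerable_pyramid_filters import (
    crop_steerable_pyramid_filters, get_steerable_pyramid_filters)

filters = get_steerable_pyramid_filters(size=9, n_orientations=4, filter_type='full')
small_filters = crop_steerable_pyramid_filters(filters, size=5)
print(small_filters['h0'].shape)      # torch.Size([1, 1, 5, 5])
print(small_filters['b'][0].shape)    # torch.Size([1, 1, 5, 5])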
"},{"location":"odak/learn_perception/#odak.learn.perception.steerable_pyramid_filters.get_steerable_pyramid_filters","title":"get_steerable_pyramid_filters(size, n_orientations, filter_type)","text":"

This returns filters for a real-valued steerable pyramid.

Parameters:

  • size \u2013
                Width of the filters (e.g. 3 will return 3x3 filters).
  • n_orientations \u2013
                Number of oriented band filters.
  • filter_type \u2013
                This can be used to select between the original NYU filters and cropped or trained alternatives. full: Original NYU filters from https://github.com/LabForComputationalVision/pyrtools/blob/master/pyrtools/pyramids/filters.py cropped: Some filters are cut back in size by extracting the centre and scaling as appropriate. trained: Same as cropped, but the oriented kernels are replaced by learned 5x5 kernels.

Returns:

  • filters ( dict of torch.tensor ) \u2013

    The steerable pyramid filters. Returned as a dict with the following keys: "l" (the lowpass downsampling filter), "l0" (the lowpass residual filter), "h0" (the highpass residual filter), "b" (the band filters, a list of torch.tensor filters, one for each orientation).
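
A retrieval sketch (assumed, not from the odak docs): it requests the full six-orientation filter set and checks the dictionary layout described above.

from odak.learn.perception.steerable_pyramid_filters import get_steerable_pyramid_filters

filters = get_steerable_pyramid_filters(size=9, n_orientations=6, filter_type='full')
print(sorted(filters.keys()))    # ['b', 'h0', 'l', 'l0']
print(len(filters['b']))         # one band filter per orientation: 6
print(filters['h0'].shape)       # each filter is a 1x1xHxW torch tensor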

Source code in odak/learn/perception/steerable_pyramid_filters.py
def get_steerable_pyramid_filters(size, n_orientations, filter_type):\n    \"\"\"\n    This returns filters for a real-valued steerable pyramid.\n\n    Parameters\n    ----------\n\n    size            : int\n                        Width of the filters (e.g. 3 will return 3x3 filters)\n    n_orientations  : int\n                        Number of oriented band filters\n    filter_type     :  str\n                        This can be used to select between the original NYU filters and cropped or trained alternatives.\n                        full: Original NYU filters from https://github.com/LabForComputationalVision/pyrtools/blob/master/pyrtools/pyramids/filters.py\n                        cropped: Some filters are cut back in size by extracting the centre and scaling as appropriate.\n                        trained: Same as reduced, but the oriented kernels are replaced by learned 5x5 kernels.\n\n    Returns\n    -------\n    filters         : dict of torch.tensor\n                        The steerable pyramid filters. Returned as a dict with the following keys:\n                        \"l\" The lowpass downsampling filter\n                        \"l0\" The lowpass residual filter\n                        \"h0\" The highpass residual filter\n                        \"b\" The band filters (a list of torch.tensor filters, one for each orientation).\n    \"\"\"\n\n    if filter_type != \"full\" and filter_type != \"cropped\" and filter_type != \"trained\":\n        raise Exception(\n            \"Unknown filter type %s! Only filter types are full, cropped or trained.\" % filter_type)\n\n    filters = {}\n    if n_orientations == 1:\n        filters[\"l\"] = torch.tensor([\n            [-2.257000e-04, -8.064400e-04, -5.686000e-05, 8.741400e-04, -1.862800e-04, -1.031640e-03, -\n                1.871920e-03, -1.031640e-03, -1.862800e-04, 8.741400e-04, -5.686000e-05, -8.064400e-04, -2.257000e-04],\n            [-8.064400e-04, 1.417620e-03, -1.903800e-04, -2.449060e-03, -4.596420e-03, -7.006740e-03, -\n                6.948900e-03, -7.006740e-03, -4.596420e-03, -2.449060e-03, -1.903800e-04, 1.417620e-03, -8.064400e-04],\n            [-5.686000e-05, -1.903800e-04, -3.059760e-03, -6.401000e-03, -6.720800e-03, -5.236180e-03, -\n                3.781600e-03, -5.236180e-03, -6.720800e-03, -6.401000e-03, -3.059760e-03, -1.903800e-04, -5.686000e-05],\n            [8.741400e-04, -2.449060e-03, -6.401000e-03, -5.260020e-03, 3.938620e-03, 1.722078e-02, 2.449600e-02,\n                1.722078e-02, 3.938620e-03, -5.260020e-03, -6.401000e-03, -2.449060e-03, 8.741400e-04],\n            [-1.862800e-04, -4.596420e-03, -6.720800e-03, 3.938620e-03, 3.220744e-02, 6.306262e-02, 7.624674e-02,\n                6.306262e-02, 3.220744e-02, 3.938620e-03, -6.720800e-03, -4.596420e-03, -1.862800e-04],\n            [-1.031640e-03, -7.006740e-03, -5.236180e-03, 1.722078e-02, 6.306262e-02, 1.116388e-01, 1.348999e-01,\n                1.116388e-01, 6.306262e-02, 1.722078e-02, -5.236180e-03, -7.006740e-03, -1.031640e-03],\n            [-1.871920e-03, -6.948900e-03, -3.781600e-03, 2.449600e-02, 7.624674e-02, 1.348999e-01, 1.576508e-01,\n                1.348999e-01, 7.624674e-02, 2.449600e-02, -3.781600e-03, -6.948900e-03, -1.871920e-03],\n            [-1.031640e-03, -7.006740e-03, -5.236180e-03, 1.722078e-02, 6.306262e-02, 1.116388e-01, 1.348999e-01,\n                1.116388e-01, 6.306262e-02, 1.722078e-02, -5.236180e-03, -7.006740e-03, -1.031640e-03],\n            [-1.862800e-04, -4.596420e-03, -6.720800e-03, 3.938620e-03, 
3.220744e-02, 6.306262e-02, 7.624674e-02,\n                6.306262e-02, 3.220744e-02, 3.938620e-03, -6.720800e-03, -4.596420e-03, -1.862800e-04],\n            [8.741400e-04, -2.449060e-03, -6.401000e-03, -5.260020e-03, 3.938620e-03, 1.722078e-02, 2.449600e-02,\n                1.722078e-02, 3.938620e-03, -5.260020e-03, -6.401000e-03, -2.449060e-03, 8.741400e-04],\n            [-5.686000e-05, -1.903800e-04, -3.059760e-03, -6.401000e-03, -6.720800e-03, -5.236180e-03, -\n                3.781600e-03, -5.236180e-03, -6.720800e-03, -6.401000e-03, -3.059760e-03, -1.903800e-04, -5.686000e-05],\n            [-8.064400e-04, 1.417620e-03, -1.903800e-04, -2.449060e-03, -4.596420e-03, -7.006740e-03, -\n                6.948900e-03, -7.006740e-03, -4.596420e-03, -2.449060e-03, -1.903800e-04, 1.417620e-03, -8.064400e-04],\n            [-2.257000e-04, -8.064400e-04, -5.686000e-05, 8.741400e-04, -1.862800e-04, -1.031640e-03, -1.871920e-03, -1.031640e-03, -1.862800e-04, 8.741400e-04, -5.686000e-05, -8.064400e-04, -2.257000e-04]]\n        ).reshape(1, 1, 13, 13)\n        filters[\"l0\"] = torch.tensor([\n            [-4.514000e-04, -1.137100e-04, -3.725800e-04, -\n                3.743860e-03, -3.725800e-04, -1.137100e-04, -4.514000e-04],\n            [-1.137100e-04, -6.119520e-03, -1.344160e-02, -\n                7.563200e-03, -1.344160e-02, -6.119520e-03, -1.137100e-04],\n            [-3.725800e-04, -1.344160e-02, 6.441488e-02, 1.524935e-01,\n                6.441488e-02, -1.344160e-02, -3.725800e-04],\n            [-3.743860e-03, -7.563200e-03, 1.524935e-01, 3.153017e-01,\n                1.524935e-01, -7.563200e-03, -3.743860e-03],\n            [-3.725800e-04, -1.344160e-02, 6.441488e-02, 1.524935e-01,\n                6.441488e-02, -1.344160e-02, -3.725800e-04],\n            [-1.137100e-04, -6.119520e-03, -1.344160e-02, -\n                7.563200e-03, -1.344160e-02, -6.119520e-03, -1.137100e-04],\n            [-4.514000e-04, -1.137100e-04, -3.725800e-04, -3.743860e-03, -3.725800e-04, -1.137100e-04, -4.514000e-04]]\n        ).reshape(1, 1, 7, 7)\n        filters[\"h0\"] = torch.tensor([\n            [5.997200e-04, -6.068000e-05, -3.324900e-04, -3.325600e-04, -\n                2.406600e-04, -3.325600e-04, -3.324900e-04, -6.068000e-05, 5.997200e-04],\n            [-6.068000e-05, 1.263100e-04, 4.927100e-04, 1.459700e-04, -\n                3.732100e-04, 1.459700e-04, 4.927100e-04, 1.263100e-04, -6.068000e-05],\n            [-3.324900e-04, 4.927100e-04, -1.616650e-03, -1.437358e-02, -\n                2.420138e-02, -1.437358e-02, -1.616650e-03, 4.927100e-04, -3.324900e-04],\n            [-3.325600e-04, 1.459700e-04, -1.437358e-02, -6.300923e-02, -\n                9.623594e-02, -6.300923e-02, -1.437358e-02, 1.459700e-04, -3.325600e-04],\n            [-2.406600e-04, -3.732100e-04, -2.420138e-02, -9.623594e-02,\n                8.554893e-01, -9.623594e-02, -2.420138e-02, -3.732100e-04, -2.406600e-04],\n            [-3.325600e-04, 1.459700e-04, -1.437358e-02, -6.300923e-02, -\n                9.623594e-02, -6.300923e-02, -1.437358e-02, 1.459700e-04, -3.325600e-04],\n            [-3.324900e-04, 4.927100e-04, -1.616650e-03, -1.437358e-02, -\n                2.420138e-02, -1.437358e-02, -1.616650e-03, 4.927100e-04, -3.324900e-04],\n            [-6.068000e-05, 1.263100e-04, 4.927100e-04, 1.459700e-04, -\n                3.732100e-04, 1.459700e-04, 4.927100e-04, 1.263100e-04, -6.068000e-05],\n            [5.997200e-04, -6.068000e-05, -3.324900e-04, -3.325600e-04, -2.406600e-04, -3.325600e-04, -3.324900e-04, 
-6.068000e-05, 5.997200e-04]]\n        ).reshape(1, 1, 9, 9)\n        filters[\"b\"] = []\n        filters[\"b\"].append(torch.tensor([\n            -9.066000e-05, -1.738640e-03, -4.942500e-03, -7.889390e-03, -\n            1.009473e-02, -7.889390e-03, -4.942500e-03, -1.738640e-03, -9.066000e-05,\n            -1.738640e-03, -4.625150e-03, -7.272540e-03, -7.623410e-03, -\n            9.091950e-03, -7.623410e-03, -7.272540e-03, -4.625150e-03, -1.738640e-03,\n            -4.942500e-03, -7.272540e-03, -2.129540e-02, -2.435662e-02, -\n            3.487008e-02, -2.435662e-02, -2.129540e-02, -7.272540e-03, -4.942500e-03,\n            -7.889390e-03, -7.623410e-03, -2.435662e-02, -1.730466e-02, -\n            3.158605e-02, -1.730466e-02, -2.435662e-02, -7.623410e-03, -7.889390e-03,\n            -1.009473e-02, -9.091950e-03, -3.487008e-02, -3.158605e-02, 9.464195e-01, -\n            3.158605e-02, -3.487008e-02, -9.091950e-03, -1.009473e-02,\n            -7.889390e-03, -7.623410e-03, -2.435662e-02, -1.730466e-02, -\n            3.158605e-02, -1.730466e-02, -2.435662e-02, -7.623410e-03, -7.889390e-03,\n            -4.942500e-03, -7.272540e-03, -2.129540e-02, -2.435662e-02, -\n            3.487008e-02, -2.435662e-02, -2.129540e-02, -7.272540e-03, -4.942500e-03,\n            -1.738640e-03, -4.625150e-03, -7.272540e-03, -7.623410e-03, -\n            9.091950e-03, -7.623410e-03, -7.272540e-03, -4.625150e-03, -1.738640e-03,\n            -9.066000e-05, -1.738640e-03, -4.942500e-03, -7.889390e-03, -1.009473e-02, -7.889390e-03, -4.942500e-03, -1.738640e-03, -9.066000e-05]\n        ).reshape(1, 1, 9, 9).permute(0, 1, 3, 2))\n\n    elif n_orientations == 2:\n        filters[\"l\"] = torch.tensor(\n            [[-4.350000e-05, 1.207800e-04, -6.771400e-04, -1.243400e-04, -8.006400e-04, -1.597040e-03, -2.516800e-04, -4.202000e-04, 1.262000e-03, -4.202000e-04, -2.516800e-04, -1.597040e-03, -8.006400e-04, -1.243400e-04, -6.771400e-04, 1.207800e-04, -4.350000e-05],\n             [1.207800e-04, 4.460600e-04, -5.814600e-04, 5.621600e-04, -1.368800e-04, 2.325540e-03, 2.889860e-03, 4.287280e-03, 5.589400e-03,\n                 4.287280e-03, 2.889860e-03, 2.325540e-03, -1.368800e-04, 5.621600e-04, -5.814600e-04, 4.460600e-04, 1.207800e-04],\n             [-6.771400e-04, -5.814600e-04, 1.460780e-03, 2.160540e-03, 3.761360e-03, 3.080980e-03, 4.112200e-03, 2.221220e-03, 5.538200e-04,\n                 2.221220e-03, 4.112200e-03, 3.080980e-03, 3.761360e-03, 2.160540e-03, 1.460780e-03, -5.814600e-04, -6.771400e-04],\n             [-1.243400e-04, 5.621600e-04, 2.160540e-03, 3.175780e-03, 3.184680e-03, -1.777480e-03, -7.431700e-03, -9.056920e-03, -\n                 9.637220e-03, -9.056920e-03, -7.431700e-03, -1.777480e-03, 3.184680e-03, 3.175780e-03, 2.160540e-03, 5.621600e-04, -1.243400e-04],\n             [-8.006400e-04, -1.368800e-04, 3.761360e-03, 3.184680e-03, -3.530640e-03, -1.260420e-02, -1.884744e-02, -1.750818e-02, -\n                 1.648568e-02, -1.750818e-02, -1.884744e-02, -1.260420e-02, -3.530640e-03, 3.184680e-03, 3.761360e-03, -1.368800e-04, -8.006400e-04],\n             [-1.597040e-03, 2.325540e-03, 3.080980e-03, -1.777480e-03, -1.260420e-02, -2.022938e-02, -1.109170e-02, 3.955660e-03, 1.438512e-02,\n                 3.955660e-03, -1.109170e-02, -2.022938e-02, -1.260420e-02, -1.777480e-03, 3.080980e-03, 2.325540e-03, -1.597040e-03],\n             [-2.516800e-04, 2.889860e-03, 4.112200e-03, -7.431700e-03, -1.884744e-02, -1.109170e-02, 2.190660e-02, 6.806584e-02, 9.058014e-02,\n                 6.806584e-02, 
2.190660e-02, -1.109170e-02, -1.884744e-02, -7.431700e-03, 4.112200e-03, 2.889860e-03, -2.516800e-04],\n             [-4.202000e-04, 4.287280e-03, 2.221220e-03, -9.056920e-03, -1.750818e-02, 3.955660e-03, 6.806584e-02, 1.445500e-01, 1.773651e-01,\n                 1.445500e-01, 6.806584e-02, 3.955660e-03, -1.750818e-02, -9.056920e-03, 2.221220e-03, 4.287280e-03, -4.202000e-04],\n             [1.262000e-03, 5.589400e-03, 5.538200e-04, -9.637220e-03, -1.648568e-02, 1.438512e-02, 9.058014e-02, 1.773651e-01, 2.120374e-01,\n                 1.773651e-01, 9.058014e-02, 1.438512e-02, -1.648568e-02, -9.637220e-03, 5.538200e-04, 5.589400e-03, 1.262000e-03],\n             [-4.202000e-04, 4.287280e-03, 2.221220e-03, -9.056920e-03, -1.750818e-02, 3.955660e-03, 6.806584e-02, 1.445500e-01, 1.773651e-01,\n                 1.445500e-01, 6.806584e-02, 3.955660e-03, -1.750818e-02, -9.056920e-03, 2.221220e-03, 4.287280e-03, -4.202000e-04],\n             [-2.516800e-04, 2.889860e-03, 4.112200e-03, -7.431700e-03, -1.884744e-02, -1.109170e-02, 2.190660e-02, 6.806584e-02, 9.058014e-02,\n                 6.806584e-02, 2.190660e-02, -1.109170e-02, -1.884744e-02, -7.431700e-03, 4.112200e-03, 2.889860e-03, -2.516800e-04],\n             [-1.597040e-03, 2.325540e-03, 3.080980e-03, -1.777480e-03, -1.260420e-02, -2.022938e-02, -1.109170e-02, 3.955660e-03, 1.438512e-02,\n                 3.955660e-03, -1.109170e-02, -2.022938e-02, -1.260420e-02, -1.777480e-03, 3.080980e-03, 2.325540e-03, -1.597040e-03],\n             [-8.006400e-04, -1.368800e-04, 3.761360e-03, 3.184680e-03, -3.530640e-03, -1.260420e-02, -1.884744e-02, -1.750818e-02, -\n                 1.648568e-02, -1.750818e-02, -1.884744e-02, -1.260420e-02, -3.530640e-03, 3.184680e-03, 3.761360e-03, -1.368800e-04, -8.006400e-04],\n             [-1.243400e-04, 5.621600e-04, 2.160540e-03, 3.175780e-03, 3.184680e-03, -1.777480e-03, -7.431700e-03, -9.056920e-03, -\n                 9.637220e-03, -9.056920e-03, -7.431700e-03, -1.777480e-03, 3.184680e-03, 3.175780e-03, 2.160540e-03, 5.621600e-04, -1.243400e-04],\n             [-6.771400e-04, -5.814600e-04, 1.460780e-03, 2.160540e-03, 3.761360e-03, 3.080980e-03, 4.112200e-03, 2.221220e-03, 5.538200e-04,\n                 2.221220e-03, 4.112200e-03, 3.080980e-03, 3.761360e-03, 2.160540e-03, 1.460780e-03, -5.814600e-04, -6.771400e-04],\n             [1.207800e-04, 4.460600e-04, -5.814600e-04, 5.621600e-04, -1.368800e-04, 2.325540e-03, 2.889860e-03, 4.287280e-03, 5.589400e-03,\n                 4.287280e-03, 2.889860e-03, 2.325540e-03, -1.368800e-04, 5.621600e-04, -5.814600e-04, 4.460600e-04, 1.207800e-04],\n             [-4.350000e-05, 1.207800e-04, -6.771400e-04, -1.243400e-04, -8.006400e-04, -1.597040e-03, -2.516800e-04, -4.202000e-04, 1.262000e-03, -4.202000e-04, -2.516800e-04, -1.597040e-03, -8.006400e-04, -1.243400e-04, -6.771400e-04, 1.207800e-04, -4.350000e-05]]\n        ).reshape(1, 1, 17, 17)\n        filters[\"l0\"] = torch.tensor(\n            [[-8.701000e-05, -1.354280e-03, -1.601260e-03, -5.033700e-04, 2.524010e-03, -5.033700e-04, -1.601260e-03, -1.354280e-03, -8.701000e-05],\n             [-1.354280e-03, 2.921580e-03, 7.522720e-03, 8.224420e-03, 1.107620e-03,\n                 8.224420e-03, 7.522720e-03, 2.921580e-03, -1.354280e-03],\n             [-1.601260e-03, 7.522720e-03, -7.061290e-03, -3.769487e-02, -\n                 3.297137e-02, -3.769487e-02, -7.061290e-03, 7.522720e-03, -1.601260e-03],\n             [-5.033700e-04, 8.224420e-03, -3.769487e-02, 4.381320e-02, 1.811603e-01,\n                 
4.381320e-02, -3.769487e-02, 8.224420e-03, -5.033700e-04],\n             [2.524010e-03, 1.107620e-03, -3.297137e-02, 1.811603e-01, 4.376250e-01,\n                 1.811603e-01, -3.297137e-02, 1.107620e-03, 2.524010e-03],\n             [-5.033700e-04, 8.224420e-03, -3.769487e-02, 4.381320e-02, 1.811603e-01,\n                 4.381320e-02, -3.769487e-02, 8.224420e-03, -5.033700e-04],\n             [-1.601260e-03, 7.522720e-03, -7.061290e-03, -3.769487e-02, -\n                 3.297137e-02, -3.769487e-02, -7.061290e-03, 7.522720e-03, -1.601260e-03],\n             [-1.354280e-03, 2.921580e-03, 7.522720e-03, 8.224420e-03, 1.107620e-03,\n                 8.224420e-03, 7.522720e-03, 2.921580e-03, -1.354280e-03],\n             [-8.701000e-05, -1.354280e-03, -1.601260e-03, -5.033700e-04, 2.524010e-03, -5.033700e-04, -1.601260e-03, -1.354280e-03, -8.701000e-05]]\n        ).reshape(1, 1, 9, 9)\n        filters[\"h0\"] = torch.tensor(\n            [[-9.570000e-04, -2.424100e-04, -1.424720e-03, -8.742600e-04, -1.166810e-03, -8.742600e-04, -1.424720e-03, -2.424100e-04, -9.570000e-04],\n             [-2.424100e-04, -4.317530e-03, 8.998600e-04, 9.156420e-03, 1.098012e-02,\n                 9.156420e-03, 8.998600e-04, -4.317530e-03, -2.424100e-04],\n             [-1.424720e-03, 8.998600e-04, 1.706347e-02, 1.094866e-02, -\n                 5.897780e-03, 1.094866e-02, 1.706347e-02, 8.998600e-04, -1.424720e-03],\n             [-8.742600e-04, 9.156420e-03, 1.094866e-02, -7.841370e-02, -\n                 1.562827e-01, -7.841370e-02, 1.094866e-02, 9.156420e-03, -8.742600e-04],\n             [-1.166810e-03, 1.098012e-02, -5.897780e-03, -1.562827e-01,\n                 7.282593e-01, -1.562827e-01, -5.897780e-03, 1.098012e-02, -1.166810e-03],\n             [-8.742600e-04, 9.156420e-03, 1.094866e-02, -7.841370e-02, -\n                 1.562827e-01, -7.841370e-02, 1.094866e-02, 9.156420e-03, -8.742600e-04],\n             [-1.424720e-03, 8.998600e-04, 1.706347e-02, 1.094866e-02, -\n                 5.897780e-03, 1.094866e-02, 1.706347e-02, 8.998600e-04, -1.424720e-03],\n             [-2.424100e-04, -4.317530e-03, 8.998600e-04, 9.156420e-03, 1.098012e-02,\n                 9.156420e-03, 8.998600e-04, -4.317530e-03, -2.424100e-04],\n             [-9.570000e-04, -2.424100e-04, -1.424720e-03, -8.742600e-04, -1.166810e-03, -8.742600e-04, -1.424720e-03, -2.424100e-04, -9.570000e-04]]\n        ).reshape(1, 1, 9, 9)\n        filters[\"b\"] = []\n        filters[\"b\"].append(torch.tensor(\n            [6.125880e-03, -8.052600e-03, -2.103714e-02, -1.536890e-02, -1.851466e-02, -1.536890e-02, -2.103714e-02, -8.052600e-03, 6.125880e-03,\n             -1.287416e-02, -9.611520e-03, 1.023569e-02, 6.009450e-03, 1.872620e-03, 6.009450e-03, 1.023569e-02, -\n             9.611520e-03, -1.287416e-02,\n             -5.641530e-03, 4.168400e-03, -2.382180e-02, -5.375324e-02, -\n             2.076086e-02, -5.375324e-02, -2.382180e-02, 4.168400e-03, -5.641530e-03,\n             -8.957260e-03, -1.751170e-03, -1.836909e-02, 1.265655e-01, 2.996168e-01, 1.265655e-01, -\n             1.836909e-02, -1.751170e-03, -8.957260e-03,\n             0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00,\n             8.957260e-03, 1.751170e-03, 1.836909e-02, -1.265655e-01, -\n             2.996168e-01, -1.265655e-01, 1.836909e-02, 1.751170e-03, 8.957260e-03,\n             5.641530e-03, -4.168400e-03, 2.382180e-02, 5.375324e-02, 2.076086e-02, 5.375324e-02, 2.382180e-02, -\n             
4.168400e-03, 5.641530e-03,\n             1.287416e-02, 9.611520e-03, -1.023569e-02, -6.009450e-03, -\n             1.872620e-03, -6.009450e-03, -1.023569e-02, 9.611520e-03, 1.287416e-02,\n             -6.125880e-03, 8.052600e-03, 2.103714e-02, 1.536890e-02, 1.851466e-02, 1.536890e-02, 2.103714e-02, 8.052600e-03, -6.125880e-03]).reshape(1, 1, 9, 9).permute(0, 1, 3, 2))\n        filters[\"b\"].append(torch.tensor(\n            [-6.125880e-03, 1.287416e-02, 5.641530e-03, 8.957260e-03, 0.000000e+00, -8.957260e-03, -5.641530e-03, -1.287416e-02, 6.125880e-03,\n             8.052600e-03, 9.611520e-03, -4.168400e-03, 1.751170e-03, 0.000000e+00, -\n             1.751170e-03, 4.168400e-03, -9.611520e-03, -8.052600e-03,\n             2.103714e-02, -1.023569e-02, 2.382180e-02, 1.836909e-02, 0.000000e+00, -\n             1.836909e-02, -2.382180e-02, 1.023569e-02, -2.103714e-02,\n             1.536890e-02, -6.009450e-03, 5.375324e-02, -\n             1.265655e-01, 0.000000e+00, 1.265655e-01, -\n             5.375324e-02, 6.009450e-03, -1.536890e-02,\n             1.851466e-02, -1.872620e-03, 2.076086e-02, -\n             2.996168e-01, 0.000000e+00, 2.996168e-01, -\n             2.076086e-02, 1.872620e-03, -1.851466e-02,\n             1.536890e-02, -6.009450e-03, 5.375324e-02, -\n             1.265655e-01, 0.000000e+00, 1.265655e-01, -\n             5.375324e-02, 6.009450e-03, -1.536890e-02,\n             2.103714e-02, -1.023569e-02, 2.382180e-02, 1.836909e-02, 0.000000e+00, -\n             1.836909e-02, -2.382180e-02, 1.023569e-02, -2.103714e-02,\n             8.052600e-03, 9.611520e-03, -4.168400e-03, 1.751170e-03, 0.000000e+00, -\n             1.751170e-03, 4.168400e-03, -9.611520e-03, -8.052600e-03,\n             -6.125880e-03, 1.287416e-02, 5.641530e-03, 8.957260e-03, 0.000000e+00, -8.957260e-03, -5.641530e-03, -1.287416e-02, 6.125880e-03]).reshape(1, 1, 9, 9).permute(0, 1, 3, 2))\n\n    elif n_orientations == 4:\n        filters[\"l\"] = torch.tensor([\n            [-4.3500000174E-5, 1.2078000145E-4, -6.7714002216E-4, -1.2434000382E-4, -8.0063997302E-4, -1.5970399836E-3, -2.5168000138E-4, -4.2019999819E-4,\n                1.2619999470E-3, -4.2019999819E-4, -2.5168000138E-4, -1.5970399836E-3, -8.0063997302E-4, -1.2434000382E-4, -6.7714002216E-4, 1.2078000145E-4, -4.3500000174E-5],\n            [1.2078000145E-4, 4.4606000301E-4, -5.8146001538E-4, 5.6215998484E-4, -1.3688000035E-4, 2.3255399428E-3, 2.8898599558E-3, 4.2872801423E-3, 5.5893999524E-3,\n                4.2872801423E-3, 2.8898599558E-3, 2.3255399428E-3, -1.3688000035E-4, 5.6215998484E-4, -5.8146001538E-4, 4.4606000301E-4, 1.2078000145E-4],\n            [-6.7714002216E-4, -5.8146001538E-4, 1.4607800404E-3, 2.1605400834E-3, 3.7613599561E-3, 3.0809799209E-3, 4.1121998802E-3, 2.2212199401E-3, 5.5381999118E-4,\n                2.2212199401E-3, 4.1121998802E-3, 3.0809799209E-3, 3.7613599561E-3, 2.1605400834E-3, 1.4607800404E-3, -5.8146001538E-4, -6.7714002216E-4],\n            [-1.2434000382E-4, 5.6215998484E-4, 2.1605400834E-3, 3.1757799443E-3, 3.1846798956E-3, -1.7774800071E-3, -7.4316998944E-3, -9.0569201857E-3, -\n                9.6372198313E-3, -9.0569201857E-3, -7.4316998944E-3, -1.7774800071E-3, 3.1846798956E-3, 3.1757799443E-3, 2.1605400834E-3, 5.6215998484E-4, -1.2434000382E-4],\n            [-8.0063997302E-4, -1.3688000035E-4, 3.7613599561E-3, 3.1846798956E-3, -3.5306399222E-3, -1.2604200281E-2, -1.8847439438E-2, -1.7508180812E-2, -\n                1.6485679895E-2, -1.7508180812E-2, -1.8847439438E-2, -1.2604200281E-2, 
-3.5306399222E-3, 3.1846798956E-3, 3.7613599561E-3, -1.3688000035E-4, -8.0063997302E-4],\n            [-1.5970399836E-3, 2.3255399428E-3, 3.0809799209E-3, -1.7774800071E-3, -1.2604200281E-2, -2.0229380578E-2, -1.1091699824E-2, 3.9556599222E-3, 1.4385120012E-2,\n                3.9556599222E-3, -1.1091699824E-2, -2.0229380578E-2, -1.2604200281E-2, -1.7774800071E-3, 3.0809799209E-3, 2.3255399428E-3, -1.5970399836E-3],\n            [-2.5168000138E-4, 2.8898599558E-3, 4.1121998802E-3, -7.4316998944E-3, -1.8847439438E-2, -1.1091699824E-2, 2.1906599402E-2, 6.8065837026E-2, 9.0580143034E-2,\n                6.8065837026E-2, 2.1906599402E-2, -1.1091699824E-2, -1.8847439438E-2, -7.4316998944E-3, 4.1121998802E-3, 2.8898599558E-3, -2.5168000138E-4],\n            [-4.2019999819E-4, 4.2872801423E-3, 2.2212199401E-3, -9.0569201857E-3, -1.7508180812E-2, 3.9556599222E-3, 6.8065837026E-2, 0.1445499808, 0.1773651242,\n                0.1445499808, 6.8065837026E-2, 3.9556599222E-3, -1.7508180812E-2, -9.0569201857E-3, 2.2212199401E-3, 4.2872801423E-3, -4.2019999819E-4],\n            [1.2619999470E-3, 5.5893999524E-3, 5.5381999118E-4, -9.6372198313E-3, -1.6485679895E-2, 1.4385120012E-2, 9.0580143034E-2, 0.1773651242, 0.2120374441,\n                0.1773651242, 9.0580143034E-2, 1.4385120012E-2, -1.6485679895E-2, -9.6372198313E-3, 5.5381999118E-4, 5.5893999524E-3, 1.2619999470E-3],\n            [-4.2019999819E-4, 4.2872801423E-3, 2.2212199401E-3, -9.0569201857E-3, -1.7508180812E-2, 3.9556599222E-3, 6.8065837026E-2, 0.1445499808, 0.1773651242,\n                0.1445499808, 6.8065837026E-2, 3.9556599222E-3, -1.7508180812E-2, -9.0569201857E-3, 2.2212199401E-3, 4.2872801423E-3, -4.2019999819E-4],\n            [-2.5168000138E-4, 2.8898599558E-3, 4.1121998802E-3, -7.4316998944E-3, -1.8847439438E-2, -1.1091699824E-2, 2.1906599402E-2, 6.8065837026E-2, 9.0580143034E-2,\n                6.8065837026E-2, 2.1906599402E-2, -1.1091699824E-2, -1.8847439438E-2, -7.4316998944E-3, 4.1121998802E-3, 2.8898599558E-3, -2.5168000138E-4],\n            [-1.5970399836E-3, 2.3255399428E-3, 3.0809799209E-3, -1.7774800071E-3, -1.2604200281E-2, -2.0229380578E-2, -1.1091699824E-2, 3.9556599222E-3, 1.4385120012E-2,\n                3.9556599222E-3, -1.1091699824E-2, -2.0229380578E-2, -1.2604200281E-2, -1.7774800071E-3, 3.0809799209E-3, 2.3255399428E-3, -1.5970399836E-3],\n            [-8.0063997302E-4, -1.3688000035E-4, 3.7613599561E-3, 3.1846798956E-3, -3.5306399222E-3, -1.2604200281E-2, -1.8847439438E-2, -1.7508180812E-2, -\n                1.6485679895E-2, -1.7508180812E-2, -1.8847439438E-2, -1.2604200281E-2, -3.5306399222E-3, 3.1846798956E-3, 3.7613599561E-3, -1.3688000035E-4, -8.0063997302E-4],\n            [-1.2434000382E-4, 5.6215998484E-4, 2.1605400834E-3, 3.1757799443E-3, 3.1846798956E-3, -1.7774800071E-3, -7.4316998944E-3, -9.0569201857E-3, -\n                9.6372198313E-3, -9.0569201857E-3, -7.4316998944E-3, -1.7774800071E-3, 3.1846798956E-3, 3.1757799443E-3, 2.1605400834E-3, 5.6215998484E-4, -1.2434000382E-4],\n            [-6.7714002216E-4, -5.8146001538E-4, 1.4607800404E-3, 2.1605400834E-3, 3.7613599561E-3, 3.0809799209E-3, 4.1121998802E-3, 2.2212199401E-3, 5.5381999118E-4,\n                2.2212199401E-3, 4.1121998802E-3, 3.0809799209E-3, 3.7613599561E-3, 2.1605400834E-3, 1.4607800404E-3, -5.8146001538E-4, -6.7714002216E-4],\n            [1.2078000145E-4, 4.4606000301E-4, -5.8146001538E-4, 5.6215998484E-4, -1.3688000035E-4, 2.3255399428E-3, 2.8898599558E-3, 4.2872801423E-3, 5.5893999524E-3,\n                4.2872801423E-3, 
2.8898599558E-3, 2.3255399428E-3, -1.3688000035E-4, 5.6215998484E-4, -5.8146001538E-4, 4.4606000301E-4, 1.2078000145E-4],\n            [-4.3500000174E-5, 1.2078000145E-4, -6.7714002216E-4, -1.2434000382E-4, -8.0063997302E-4, -1.5970399836E-3, -2.5168000138E-4, -4.2019999819E-4, 1.2619999470E-3, -4.2019999819E-4, -2.5168000138E-4, -1.5970399836E-3, -8.0063997302E-4, -1.2434000382E-4, -6.7714002216E-4, 1.2078000145E-4, -4.3500000174E-5]]\n        ).reshape(1, 1, 17, 17)\n        filters[\"l0\"] = torch.tensor([\n            [-8.7009997515E-5, -1.3542800443E-3, -1.6012600390E-3, -5.0337001448E-4,\n                2.5240099058E-3, -5.0337001448E-4, -1.6012600390E-3, -1.3542800443E-3, -8.7009997515E-5],\n            [-1.3542800443E-3, 2.9215801042E-3, 7.5227199122E-3, 8.2244202495E-3, 1.1076199589E-3,\n                8.2244202495E-3, 7.5227199122E-3, 2.9215801042E-3, -1.3542800443E-3],\n            [-1.6012600390E-3, 7.5227199122E-3, -7.0612900890E-3, -3.7694871426E-2, -\n                3.2971370965E-2, -3.7694871426E-2, -7.0612900890E-3, 7.5227199122E-3, -1.6012600390E-3],\n            [-5.0337001448E-4, 8.2244202495E-3, -3.7694871426E-2, 4.3813198805E-2, 0.1811603010,\n                4.3813198805E-2, -3.7694871426E-2, 8.2244202495E-3, -5.0337001448E-4],\n            [2.5240099058E-3, 1.1076199589E-3, -3.2971370965E-2, 0.1811603010, 0.4376249909,\n                0.1811603010, -3.2971370965E-2, 1.1076199589E-3, 2.5240099058E-3],\n            [-5.0337001448E-4, 8.2244202495E-3, -3.7694871426E-2, 4.3813198805E-2, 0.1811603010,\n                4.3813198805E-2, -3.7694871426E-2, 8.2244202495E-3, -5.0337001448E-4],\n            [-1.6012600390E-3, 7.5227199122E-3, -7.0612900890E-3, -3.7694871426E-2, -\n                3.2971370965E-2, -3.7694871426E-2, -7.0612900890E-3, 7.5227199122E-3, -1.6012600390E-3],\n            [-1.3542800443E-3, 2.9215801042E-3, 7.5227199122E-3, 8.2244202495E-3, 1.1076199589E-3,\n                8.2244202495E-3, 7.5227199122E-3, 2.9215801042E-3, -1.3542800443E-3],\n            [-8.7009997515E-5, -1.3542800443E-3, -1.6012600390E-3, -5.0337001448E-4, 2.5240099058E-3, -5.0337001448E-4, -1.6012600390E-3, -1.3542800443E-3, -8.7009997515E-5]]\n        ).reshape(1, 1, 9, 9)\n        filters[\"h0\"] = torch.tensor([\n            [-4.0483998600E-4, -6.2596000498E-4, -3.7829999201E-5, 8.8387000142E-4, 1.5450799838E-3, 1.9235999789E-3, 2.0687500946E-3, 2.0898699295E-3,\n                2.0687500946E-3, 1.9235999789E-3, 1.5450799838E-3, 8.8387000142E-4, -3.7829999201E-5, -6.2596000498E-4, -4.0483998600E-4],\n            [-6.2596000498E-4, -3.2734998967E-4, 7.7435001731E-4, 1.5874400269E-3, 2.1750701126E-3, 2.5626500137E-3, 2.2892199922E-3, 1.9755100366E-3,\n                2.2892199922E-3, 2.5626500137E-3, 2.1750701126E-3, 1.5874400269E-3, 7.7435001731E-4, -3.2734998967E-4, -6.2596000498E-4],\n            [-3.7829999201E-5, 7.7435001731E-4, 1.1793200392E-3, 1.4050999889E-3, 2.2253401112E-3, 2.1145299543E-3, 3.3578000148E-4, -\n                8.3368999185E-4, 3.3578000148E-4, 2.1145299543E-3, 2.2253401112E-3, 1.4050999889E-3, 1.1793200392E-3, 7.7435001731E-4, -3.7829999201E-5],\n            [8.8387000142E-4, 1.5874400269E-3, 1.4050999889E-3, 1.2960999738E-3, -4.9274001503E-4, -3.1295299996E-3, -4.5751798898E-3, -\n                5.1014497876E-3, -4.5751798898E-3, -3.1295299996E-3, -4.9274001503E-4, 1.2960999738E-3, 1.4050999889E-3, 1.5874400269E-3, 8.8387000142E-4],\n            [1.5450799838E-3, 2.1750701126E-3, 2.2253401112E-3, -4.9274001503E-4, -6.3222697936E-3, 
-2.7556000277E-3, 5.3632198833E-3, 7.3032598011E-3,\n                5.3632198833E-3, -2.7556000277E-3, -6.3222697936E-3, -4.9274001503E-4, 2.2253401112E-3, 2.1750701126E-3, 1.5450799838E-3],\n            [1.9235999789E-3, 2.5626500137E-3, 2.1145299543E-3, -3.1295299996E-3, -2.7556000277E-3, 1.3962360099E-2, 7.8046298586E-3, -\n                9.3812197447E-3, 7.8046298586E-3, 1.3962360099E-2, -2.7556000277E-3, -3.1295299996E-3, 2.1145299543E-3, 2.5626500137E-3, 1.9235999789E-3],\n            [2.0687500946E-3, 2.2892199922E-3, 3.3578000148E-4, -4.5751798898E-3, 5.3632198833E-3, 7.8046298586E-3, -7.9501636326E-2, -\n                0.1554141641, -7.9501636326E-2, 7.8046298586E-3, 5.3632198833E-3, -4.5751798898E-3, 3.3578000148E-4, 2.2892199922E-3, 2.0687500946E-3],\n            [2.0898699295E-3, 1.9755100366E-3, -8.3368999185E-4, -5.1014497876E-3, 7.3032598011E-3, -9.3812197447E-3, -0.1554141641,\n                0.7303866148, -0.1554141641, -9.3812197447E-3, 7.3032598011E-3, -5.1014497876E-3, -8.3368999185E-4, 1.9755100366E-3, 2.0898699295E-3],\n            [2.0687500946E-3, 2.2892199922E-3, 3.3578000148E-4, -4.5751798898E-3, 5.3632198833E-3, 7.8046298586E-3, -7.9501636326E-2, -\n                0.1554141641, -7.9501636326E-2, 7.8046298586E-3, 5.3632198833E-3, -4.5751798898E-3, 3.3578000148E-4, 2.2892199922E-3, 2.0687500946E-3],\n            [1.9235999789E-3, 2.5626500137E-3, 2.1145299543E-3, -3.1295299996E-3, -2.7556000277E-3, 1.3962360099E-2, 7.8046298586E-3, -\n                9.3812197447E-3, 7.8046298586E-3, 1.3962360099E-2, -2.7556000277E-3, -3.1295299996E-3, 2.1145299543E-3, 2.5626500137E-3, 1.9235999789E-3],\n            [1.5450799838E-3, 2.1750701126E-3, 2.2253401112E-3, -4.9274001503E-4, -6.3222697936E-3, -2.7556000277E-3, 5.3632198833E-3, 7.3032598011E-3,\n                5.3632198833E-3, -2.7556000277E-3, -6.3222697936E-3, -4.9274001503E-4, 2.2253401112E-3, 2.1750701126E-3, 1.5450799838E-3],\n            [8.8387000142E-4, 1.5874400269E-3, 1.4050999889E-3, 1.2960999738E-3, -4.9274001503E-4, -3.1295299996E-3, -4.5751798898E-3, -\n                5.1014497876E-3, -4.5751798898E-3, -3.1295299996E-3, -4.9274001503E-4, 1.2960999738E-3, 1.4050999889E-3, 1.5874400269E-3, 8.8387000142E-4],\n            [-3.7829999201E-5, 7.7435001731E-4, 1.1793200392E-3, 1.4050999889E-3, 2.2253401112E-3, 2.1145299543E-3, 3.3578000148E-4, -\n                8.3368999185E-4, 3.3578000148E-4, 2.1145299543E-3, 2.2253401112E-3, 1.4050999889E-3, 1.1793200392E-3, 7.7435001731E-4, -3.7829999201E-5],\n            [-6.2596000498E-4, -3.2734998967E-4, 7.7435001731E-4, 1.5874400269E-3, 2.1750701126E-3, 2.5626500137E-3, 2.2892199922E-3, 1.9755100366E-3,\n                2.2892199922E-3, 2.5626500137E-3, 2.1750701126E-3, 1.5874400269E-3, 7.7435001731E-4, -3.2734998967E-4, -6.2596000498E-4],\n            [-4.0483998600E-4, -6.2596000498E-4, -3.7829999201E-5, 8.8387000142E-4, 1.5450799838E-3, 1.9235999789E-3, 2.0687500946E-3, 2.0898699295E-3, 2.0687500946E-3, 1.9235999789E-3, 1.5450799838E-3, 8.8387000142E-4, -3.7829999201E-5, -6.2596000498E-4, -4.0483998600E-4]]\n        ).reshape(1, 1, 15, 15)\n        filters[\"b\"] = []\n        filters[\"b\"].append(torch.tensor(\n            [-8.1125000725E-4, 4.4451598078E-3, 1.2316980399E-2, 1.3955879956E-2,  1.4179450460E-2, 1.3955879956E-2, 1.2316980399E-2, 4.4451598078E-3, -8.1125000725E-4,\n             3.9103501476E-3, 4.4565401040E-3, -5.8724298142E-3, -2.8760801069E-3, 8.5267601535E-3, -\n             2.8760801069E-3, -5.8724298142E-3, 4.4565401040E-3, 3.9103501476E-3,\n 
            1.3462699717E-3, -3.7740699481E-3, 8.2581602037E-3, 3.9442278445E-2, 5.3605638444E-2, 3.9442278445E-2, 8.2581602037E-3, -\n             3.7740699481E-3, 1.3462699717E-3,\n             7.4700999539E-4, -3.6522001028E-4, -2.2522680461E-2, -0.1105690673, -\n             0.1768419296, -0.1105690673, -2.2522680461E-2, -3.6522001028E-4, 7.4700999539E-4,\n             0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000,\n             -7.4700999539E-4, 3.6522001028E-4, 2.2522680461E-2, 0.1105690673, 0.1768419296, 0.1105690673, 2.2522680461E-2, 3.6522001028E-4, -7.4700999539E-4,\n             -1.3462699717E-3, 3.7740699481E-3, -8.2581602037E-3, -3.9442278445E-2, -\n             5.3605638444E-2, -3.9442278445E-2, -\n             8.2581602037E-3, 3.7740699481E-3, -1.3462699717E-3,\n             -3.9103501476E-3, -4.4565401040E-3, 5.8724298142E-3, 2.8760801069E-3, -\n             8.5267601535E-3, 2.8760801069E-3, 5.8724298142E-3, -\n             4.4565401040E-3, -3.9103501476E-3,\n             8.1125000725E-4, -4.4451598078E-3, -1.2316980399E-2, -1.3955879956E-2, -1.4179450460E-2, -1.3955879956E-2, -1.2316980399E-2, -4.4451598078E-3, 8.1125000725E-4]\n        ).reshape(1, 1, 9, 9).permute(0, 1, 3, 2))\n        filters[\"b\"].append(torch.tensor(\n            [0.0000000000, -8.2846998703E-4, -5.7109999034E-5, 4.0110000555E-5, 4.6670897864E-3, 8.0871898681E-3, 1.4807609841E-2, 8.6204400286E-3, -3.1221499667E-3,\n             8.2846998703E-4, 0.0000000000, -9.7479997203E-4, -6.9718998857E-3, -\n             2.0865600090E-3, 2.3298799060E-3, -\n             4.4814897701E-3, 1.4917500317E-2, 8.6204400286E-3,\n             5.7109999034E-5, 9.7479997203E-4, 0.0000000000, -1.2145539746E-2, -\n             2.4427289143E-2, 5.0797060132E-2, 3.2785870135E-2, -\n             4.4814897701E-3, 1.4807609841E-2,\n             -4.0110000555E-5, 6.9718998857E-3, 1.2145539746E-2, 0.0000000000, -\n             0.1510555595, -8.2495503128E-2, 5.0797060132E-2, 2.3298799060E-3, 8.0871898681E-3,\n             -4.6670897864E-3, 2.0865600090E-3, 2.4427289143E-2, 0.1510555595, 0.0000000000, -\n             0.1510555595, -2.4427289143E-2, -2.0865600090E-3, 4.6670897864E-3,\n             -8.0871898681E-3, -2.3298799060E-3, -5.0797060132E-2, 8.2495503128E-2, 0.1510555595, 0.0000000000, -\n             1.2145539746E-2, -6.9718998857E-3, 4.0110000555E-5,\n             -1.4807609841E-2, 4.4814897701E-3, -3.2785870135E-2, -\n             5.0797060132E-2, 2.4427289143E-2, 1.2145539746E-2, 0.0000000000, -\n             9.7479997203E-4, -5.7109999034E-5,\n             -8.6204400286E-3, -1.4917500317E-2, 4.4814897701E-3, -\n             2.3298799060E-3, 2.0865600090E-3, 6.9718998857E-3, 9.7479997203E-4, 0.0000000000, -8.2846998703E-4,\n             3.1221499667E-3, -8.6204400286E-3, -1.4807609841E-2, -8.0871898681E-3, -4.6670897864E-3, -4.0110000555E-5, 5.7109999034E-5, 8.2846998703E-4, 0.0000000000]\n        ).reshape(1, 1, 9, 9).permute(0, 1, 3, 2))\n        filters[\"b\"].append(torch.tensor(\n            [8.1125000725E-4, -3.9103501476E-3, -1.3462699717E-3, -7.4700999539E-4, 0.0000000000, 7.4700999539E-4, 1.3462699717E-3, 3.9103501476E-3, -8.1125000725E-4,\n             -4.4451598078E-3, -4.4565401040E-3, 3.7740699481E-3, 3.6522001028E-4, 0.0000000000, -\n             3.6522001028E-4, -3.7740699481E-3, 4.4565401040E-3, 4.4451598078E-3,\n             -1.2316980399E-2, 5.8724298142E-3, -8.2581602037E-3, 2.2522680461E-2, 0.0000000000, -\n             
2.2522680461E-2, 8.2581602037E-3, -5.8724298142E-3, 1.2316980399E-2,\n             -1.3955879956E-2, 2.8760801069E-3, -3.9442278445E-2, 0.1105690673, 0.0000000000, -\n             0.1105690673, 3.9442278445E-2, -2.8760801069E-3, 1.3955879956E-2,\n             -1.4179450460E-2, -8.5267601535E-3, -5.3605638444E-2, 0.1768419296, 0.0000000000, -\n             0.1768419296, 5.3605638444E-2, 8.5267601535E-3, 1.4179450460E-2,\n             -1.3955879956E-2, 2.8760801069E-3, -3.9442278445E-2, 0.1105690673, 0.0000000000, -\n             0.1105690673, 3.9442278445E-2, -2.8760801069E-3, 1.3955879956E-2,\n             -1.2316980399E-2, 5.8724298142E-3, -8.2581602037E-3, 2.2522680461E-2, 0.0000000000, -\n             2.2522680461E-2, 8.2581602037E-3, -5.8724298142E-3, 1.2316980399E-2,\n             -4.4451598078E-3, -4.4565401040E-3, 3.7740699481E-3, 3.6522001028E-4, 0.0000000000, -\n             3.6522001028E-4, -3.7740699481E-3, 4.4565401040E-3, 4.4451598078E-3,\n             8.1125000725E-4, -3.9103501476E-3, -1.3462699717E-3, -7.4700999539E-4, 0.0000000000, 7.4700999539E-4, 1.3462699717E-3, 3.9103501476E-3, -8.1125000725E-4]\n        ).reshape(1, 1, 9, 9).permute(0, 1, 3, 2))\n        filters[\"b\"].append(torch.tensor(\n            [3.1221499667E-3, -8.6204400286E-3, -1.4807609841E-2, -8.0871898681E-3, -4.6670897864E-3, -4.0110000555E-5, 5.7109999034E-5, 8.2846998703E-4, 0.0000000000,\n             -8.6204400286E-3, -1.4917500317E-2, 4.4814897701E-3, -\n             2.3298799060E-3, 2.0865600090E-3, 6.9718998857E-3, 9.7479997203E-4, -\n             0.0000000000, -8.2846998703E-4,\n             -1.4807609841E-2, 4.4814897701E-3, -3.2785870135E-2, -\n             5.0797060132E-2, 2.4427289143E-2, 1.2145539746E-2, 0.0000000000, -\n             9.7479997203E-4, -5.7109999034E-5,\n             -8.0871898681E-3, -2.3298799060E-3, -5.0797060132E-2, 8.2495503128E-2, 0.1510555595, -\n             0.0000000000, -1.2145539746E-2, -6.9718998857E-3, 4.0110000555E-5,\n             -4.6670897864E-3, 2.0865600090E-3, 2.4427289143E-2, 0.1510555595, 0.0000000000, -\n             0.1510555595, -2.4427289143E-2, -2.0865600090E-3, 4.6670897864E-3,\n             -4.0110000555E-5, 6.9718998857E-3, 1.2145539746E-2, 0.0000000000, -\n             0.1510555595, -8.2495503128E-2, 5.0797060132E-2, 2.3298799060E-3, 8.0871898681E-3,\n             5.7109999034E-5, 9.7479997203E-4, -0.0000000000, -1.2145539746E-2, -\n             2.4427289143E-2, 5.0797060132E-2, 3.2785870135E-2, -\n             4.4814897701E-3, 1.4807609841E-2,\n             8.2846998703E-4, -0.0000000000, -9.7479997203E-4, -6.9718998857E-3, -\n             2.0865600090E-3, 2.3298799060E-3, -\n             4.4814897701E-3, 1.4917500317E-2, 8.6204400286E-3,\n             0.0000000000, -8.2846998703E-4, -5.7109999034E-5, 4.0110000555E-5, 4.6670897864E-3, 8.0871898681E-3, 1.4807609841E-2, 8.6204400286E-3, -3.1221499667E-3]\n        ).reshape(1, 1, 9, 9).permute(0, 1, 3, 2))\n\n    elif n_orientations == 6:\n        filters[\"l\"] = 2 * torch.tensor([\n            [0.00085404, -0.00244917, -0.00387812, -0.00944432, -\n                0.00962054, -0.00944432, -0.00387812, -0.00244917, 0.00085404],\n            [-0.00244917, -0.00523281, -0.00661117, 0.00410600, 0.01002988,\n                0.00410600, -0.00661117, -0.00523281, -0.00244917],\n            [-0.00387812, -0.00661117, 0.01396746, 0.03277038, 0.03981393,\n                0.03277038, 0.01396746, -0.00661117, -0.00387812],\n            [-0.00944432, 0.00410600, 0.03277038, 0.06426333, 0.08169618,\n           
     0.06426333, 0.03277038, 0.00410600, -0.00944432],\n            [-0.00962054, 0.01002988, 0.03981393, 0.08169618, 0.10096540,\n                0.08169618, 0.03981393, 0.01002988, -0.00962054],\n            [-0.00944432, 0.00410600, 0.03277038, 0.06426333, 0.08169618,\n                0.06426333, 0.03277038, 0.00410600, -0.00944432],\n            [-0.00387812, -0.00661117, 0.01396746, 0.03277038, 0.03981393,\n                0.03277038, 0.01396746, -0.00661117, -0.00387812],\n            [-0.00244917, -0.00523281, -0.00661117, 0.00410600, 0.01002988,\n                0.00410600, -0.00661117, -0.00523281, -0.00244917],\n            [0.00085404, -0.00244917, -0.00387812, -0.00944432, -0.00962054, -0.00944432, -0.00387812, -0.00244917, 0.00085404]]\n        ).reshape(1, 1, 9, 9)\n        filters[\"l0\"] = torch.tensor([\n            [0.00341614, -0.01551246, -0.03848215, -0.01551246, 0.00341614],\n            [-0.01551246, 0.05586982, 0.15925570, 0.05586982, -0.01551246],\n            [-0.03848215, 0.15925570, 0.40304148, 0.15925570, -0.03848215],\n            [-0.01551246, 0.05586982, 0.15925570, 0.05586982, -0.01551246],\n            [0.00341614, -0.01551246, -0.03848215, -0.01551246, 0.00341614]]\n        ).reshape(1, 1, 5, 5)\n        filters[\"h0\"] = torch.tensor([\n            [-0.00033429, -0.00113093, -0.00171484, -0.00133542, -\n                0.00080639, -0.00133542, -0.00171484, -0.00113093, -0.00033429],\n            [-0.00113093, -0.00350017, -0.00243812, 0.00631653, 0.01261227,\n                0.00631653, -0.00243812, -0.00350017, -0.00113093],\n            [-0.00171484, -0.00243812, -0.00290081, -0.00673482, -\n                0.00981051, -0.00673482, -0.00290081, -0.00243812, -0.00171484],\n            [-0.00133542, 0.00631653, -0.00673482, -0.07027679, -\n                0.11435863, -0.07027679, -0.00673482, 0.00631653, -0.00133542],\n            [-0.00080639, 0.01261227, -0.00981051, -0.11435863,\n                0.81380200, -0.11435863, -0.00981051, 0.01261227, -0.00080639],\n            [-0.00133542, 0.00631653, -0.00673482, -0.07027679, -\n                0.11435863, -0.07027679, -0.00673482, 0.00631653, -0.00133542],\n            [-0.00171484, -0.00243812, -0.00290081, -0.00673482, -\n                0.00981051, -0.00673482, -0.00290081, -0.00243812, -0.00171484],\n            [-0.00113093, -0.00350017, -0.00243812, 0.00631653, 0.01261227,\n                0.00631653, -0.00243812, -0.00350017, -0.00113093],\n            [-0.00033429, -0.00113093, -0.00171484, -0.00133542, -0.00080639, -0.00133542, -0.00171484, -0.00113093, -0.00033429]]\n        ).reshape(1, 1, 9, 9)\n        filters[\"b\"] = []\n        filters[\"b\"].append(torch.tensor([\n            0.00277643, 0.00496194, 0.01026699, 0.01455399, 0.01026699, 0.00496194, 0.00277643,\n            -0.00986904, -0.00893064, 0.01189859, 0.02755155, 0.01189859, -0.00893064, -0.00986904,\n            -0.01021852, -0.03075356, -0.08226445, -\n            0.11732297, -0.08226445, -0.03075356, -0.01021852,\n            0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,\n            0.01021852, 0.03075356, 0.08226445, 0.11732297, 0.08226445, 0.03075356, 0.01021852,\n            0.00986904, 0.00893064, -0.01189859, -\n            0.02755155, -0.01189859, 0.00893064, 0.00986904,\n            -0.00277643, -0.00496194, -0.01026699, -0.01455399, -0.01026699, -0.00496194, -0.00277643]\n        ).reshape(1, 1, 7, 7).permute(0, 1, 3, 2))\n        filters[\"b\"].append(torch.tensor([\n            
-0.00343249, -0.00640815, -0.00073141, 0.01124321, 0.00182078, 0.00285723, 0.01166982,\n            -0.00358461, -0.01977507, -0.04084211, -\n            0.00228219, 0.03930573, 0.01161195, 0.00128000,\n            0.01047717, 0.01486305, -0.04819057, -\n            0.12227230, -0.05394139, 0.00853965, -0.00459034,\n            0.00790407, 0.04435647, 0.09454202, -0.00000000, -\n            0.09454202, -0.04435647, -0.00790407,\n            0.00459034, -0.00853965, 0.05394139, 0.12227230, 0.04819057, -0.01486305, -0.01047717,\n            -0.00128000, -0.01161195, -0.03930573, 0.00228219, 0.04084211, 0.01977507, 0.00358461,\n            -0.01166982, -0.00285723, -0.00182078, -0.01124321, 0.00073141, 0.00640815, 0.00343249]\n        ).reshape(1, 1, 7, 7).permute(0, 1, 3, 2))\n        filters[\"b\"].append(torch.tensor([\n            0.00343249, 0.00358461, -0.01047717, -\n            0.00790407, -0.00459034, 0.00128000, 0.01166982,\n            0.00640815, 0.01977507, -0.01486305, -\n            0.04435647, 0.00853965, 0.01161195, 0.00285723,\n            0.00073141, 0.04084211, 0.04819057, -\n            0.09454202, -0.05394139, 0.03930573, 0.00182078,\n            -0.01124321, 0.00228219, 0.12227230, -\n            0.00000000, -0.12227230, -0.00228219, 0.01124321,\n            -0.00182078, -0.03930573, 0.05394139, 0.09454202, -\n            0.04819057, -0.04084211, -0.00073141,\n            -0.00285723, -0.01161195, -0.00853965, 0.04435647, 0.01486305, -0.01977507, -0.00640815,\n            -0.01166982, -0.00128000, 0.00459034, 0.00790407, 0.01047717, -0.00358461, -0.00343249]\n        ).reshape(1, 1, 7, 7).permute(0, 1, 3, 2))\n        filters[\"b\"].append(torch.tensor(\n            [-0.00277643, 0.00986904, 0.01021852, -0.00000000, -0.01021852, -0.00986904, 0.00277643,\n             -0.00496194, 0.00893064, 0.03075356, -\n             0.00000000, -0.03075356, -0.00893064, 0.00496194,\n             -0.01026699, -0.01189859, 0.08226445, -\n             0.00000000, -0.08226445, 0.01189859, 0.01026699,\n             -0.01455399, -0.02755155, 0.11732297, -\n             0.00000000, -0.11732297, 0.02755155, 0.01455399,\n             -0.01026699, -0.01189859, 0.08226445, -\n             0.00000000, -0.08226445, 0.01189859, 0.01026699,\n             -0.00496194, 0.00893064, 0.03075356, -\n             0.00000000, -0.03075356, -0.00893064, 0.00496194,\n             -0.00277643, 0.00986904, 0.01021852, -0.00000000, -0.01021852, -0.00986904, 0.00277643]\n        ).reshape(1, 1, 7, 7).permute(0, 1, 3, 2))\n        filters[\"b\"].append(torch.tensor([\n            -0.01166982, -0.00128000, 0.00459034, 0.00790407, 0.01047717, -0.00358461, -0.00343249,\n            -0.00285723, -0.01161195, -0.00853965, 0.04435647, 0.01486305, -0.01977507, -0.00640815,\n            -0.00182078, -0.03930573, 0.05394139, 0.09454202, -\n            0.04819057, -0.04084211, -0.00073141,\n            -0.01124321, 0.00228219, 0.12227230, -\n            0.00000000, -0.12227230, -0.00228219, 0.01124321,\n            0.00073141, 0.04084211, 0.04819057, -\n            0.09454202, -0.05394139, 0.03930573, 0.00182078,\n            0.00640815, 0.01977507, -0.01486305, -\n            0.04435647, 0.00853965, 0.01161195, 0.00285723,\n            0.00343249, 0.00358461, -0.01047717, -0.00790407, -0.00459034, 0.00128000, 0.01166982]\n        ).reshape(1, 1, 7, 7).permute(0, 1, 3, 2))\n        filters[\"b\"].append(torch.tensor([\n            -0.01166982, -0.00285723, -0.00182078, -\n            0.01124321, 0.00073141, 0.00640815, 
0.00343249,\n            -0.00128000, -0.01161195, -0.03930573, 0.00228219, 0.04084211, 0.01977507, 0.00358461,\n            0.00459034, -0.00853965, 0.05394139, 0.12227230, 0.04819057, -0.01486305, -0.01047717,\n            0.00790407, 0.04435647, 0.09454202, -0.00000000, -\n            0.09454202, -0.04435647, -0.00790407,\n            0.01047717, 0.01486305, -0.04819057, -\n            0.12227230, -0.05394139, 0.00853965, -0.00459034,\n            -0.00358461, -0.01977507, -0.04084211, -\n            0.00228219, 0.03930573, 0.01161195, 0.00128000,\n            -0.00343249, -0.00640815, -0.00073141, 0.01124321, 0.00182078, 0.00285723, 0.01166982]\n        ).reshape(1, 1, 7, 7).permute(0, 1, 3, 2))\n\n    else:\n        raise Exception(\n            \"Steerable filters not implemented for %d orientations\" % n_orientations)\n\n    if filter_type == \"trained\":\n        if size == 5:\n            # TODO maybe also train h0 and l0 filters\n            filters = crop_steerable_pyramid_filters(filters, 5)\n            filters[\"b\"][0] = torch.tensor([\n                [-0.0356752239, -0.0223877281, -0.0009542659,\n                    0.0244821459, 0.0322226137],\n                [-0.0593218654,  0.1245803162, -\n                    0.0023863907, -0.1230178699, 0.0589442067],\n                [-0.0281576272,  0.2976626456, -\n                    0.0020888755, -0.2953369915, 0.0284542721],\n                [-0.0586092323,  0.1251581162, -\n                    0.0024624448, -0.1227868199, 0.0587830991],\n                [-0.0327464789, -0.0223652460, -\n                    0.0042342511,  0.0245472137, 0.0359398536]\n            ]).reshape(1, 1, 5, 5)\n            filters[\"b\"][1] = torch.tensor([\n                [3.9758663625e-02,  6.0679119080e-02,  3.0146904290e-02,\n                    6.1198268086e-02,  3.6218870431e-02],\n                [2.3255519569e-02, -1.2505133450e-01, -\n                    2.9738345742e-01, -1.2518258393e-01,  2.3592948914e-02],\n                [-1.3602430699e-03, -1.2058277935e-04,  2.6399988565e-04, -\n                    2.3791544663e-04,  1.8450465286e-03],\n                [-2.1563466638e-02,  1.2572696805e-01,  2.9745018482e-01,\n                    1.2458638102e-01, -2.3847281933e-02],\n                [-3.7941932678e-02, -6.1060950160e-02, -\n                    2.9489086941e-02, -6.0411967337e-02, -3.8459088653e-02]\n            ]).reshape(1, 1, 5, 5)\n\n            # Below filters were optimised on 09/02/2021\n            # 20K iterations with multiple images at more scales.\n            filters[\"b\"][0] = torch.tensor([\n                [-4.5508436859e-02, -2.1767273545e-02, -1.9399923622e-04,\n                    2.1200872958e-02,  4.5475799590e-02],\n                [-6.3554823399e-02,  1.2832683325e-01, -\n                    5.3858719184e-05, -1.2809979916e-01,  6.3842624426e-02],\n                [-3.4809380770e-02,  2.9954621196e-01,  2.9066693969e-05, -\n                    2.9957753420e-01,  3.4806568176e-02],\n                [-6.3934154809e-02,  1.2806062400e-01,  9.0917674243e-05, -\n                    1.2832444906e-01,  6.3572973013e-02],\n                [-4.5492250472e-02, -2.1125273779e-02,  4.2229349492e-04,\n                    2.1804777905e-02,  4.5236673206e-02]\n            ]).reshape(1, 1, 5, 5)\n            filters[\"b\"][1] = torch.tensor([\n                [4.8947390169e-02,  6.3575074077e-02,  3.4955859184e-02,\n                    6.4085893333e-02,  4.9838040024e-02],\n                [2.2061849013e-02, -1.2936264277e-01, 
-\n                    3.0093491077e-01, -1.2997294962e-01,  2.0597217605e-02],\n                [-5.1290717238e-05, -1.7305796064e-05,  2.0256420612e-05, -\n                    1.1864109547e-04,  7.3973249528e-05],\n                [-2.0749464631e-02,  1.2988376617e-01,  3.0080935359e-01,\n                    1.2921217084e-01, -2.2159902379e-02],\n                [-4.9614857882e-02, -6.4021714032e-02, -\n                    3.4676689655e-02, -6.3446544111e-02, -4.8282280564e-02]\n            ]).reshape(1, 1, 5, 5)\n\n            # Trained on 17/02/2021 to match fourier pyramid in spatial domain\n            filters[\"b\"][0] = torch.tensor([\n                [3.3370e-02,  9.3934e-02, -3.5810e-04, -9.4038e-02, -3.3115e-02],\n                [1.7716e-01,  3.9378e-01,  6.8461e-05, -3.9343e-01, -1.7685e-01],\n                [2.9213e-01,  6.1042e-01,  7.0654e-04, -6.0939e-01, -2.9177e-01],\n                [1.7684e-01,  3.9392e-01,  1.0517e-03, -3.9268e-01, -1.7668e-01],\n                [3.3000e-02,  9.4029e-02,  7.3565e-04, -9.3366e-02, -3.3008e-02]\n            ]).reshape(1, 1, 5, 5) * 0.1\n\n            filters[\"b\"][1] = torch.tensor([\n                [0.0331,  0.1763,  0.2907,  0.1753,  0.0325],\n                [0.0941,  0.3932,  0.6079,  0.3904,  0.0922],\n                [0.0008,  0.0009, -0.0010, -0.0025, -0.0015],\n                [-0.0929, -0.3919, -0.6097, -0.3944, -0.0946],\n                [-0.0328, -0.1760, -0.2915, -0.1768, -0.0333]\n            ]).reshape(1, 1, 5, 5) * 0.1\n\n        else:\n            raise Exception(\n                \"Trained filters not implemented for size %d\" % size)\n\n    if filter_type == \"cropped\":\n        filters = crop_steerable_pyramid_filters(filters, size)\n\n    return filters\n
"},{"location":"odak/learn_perception/#odak.learn.perception.util.slice_rgbd_targets","title":"slice_rgbd_targets(target, depth, depth_plane_positions)","text":"

Slices the target RGBD image and depth map into multiple layers based on depth plane positions.

Parameters:

  • target \u2013
                     The RGBD target tensor with shape (C, H, W).\n
  • depth \u2013
                     The depth map corresponding to the target image with shape (H, W).\n
  • depth_plane_positions \u2013
                     The positions of the depth planes used for slicing.\n

Returns:

  • targets ( Tensor ) \u2013

    A tensor of shape (N, C, H, W) where N is the number of depth planes. Contains the sliced targets for each depth plane.

  • masks ( Tensor ) \u2013

    A tensor of shape (N, C, H, W) containing binary masks for each depth plane.

Source code in odak/learn/perception/util.py
def slice_rgbd_targets(target, depth, depth_plane_positions):\n    \"\"\"\n    Slices the target RGBD image and depth map into multiple layers based on depth plane positions.\n\n    Parameters\n    ----------\n    target                 : torch.Tensor\n                             The RGBD target tensor with shape (C, H, W).\n    depth                  : torch.Tensor\n                             The depth map corresponding to the target image with shape (H, W).\n    depth_plane_positions  : list or torch.Tensor\n                             The positions of the depth planes used for slicing.\n\n    Returns\n    -------\n    targets              : torch.Tensor\n                           A tensor of shape (N, C, H, W) where N is the number of depth planes. Contains the sliced targets for each depth plane.\n    masks                : torch.Tensor\n                           A tensor of shape (N, C, H, W) containing binary masks for each depth plane.\n    \"\"\"\n    device = target.device\n    number_of_planes = len(depth_plane_positions) - 1\n    targets = torch.zeros(\n                        number_of_planes,\n                        target.shape[0],\n                        target.shape[1],\n                        target.shape[2],\n                        requires_grad = False,\n                        device = device\n                        )\n    masks = torch.zeros_like(targets, dtype = torch.int).to(device)\n    mask_zeros = torch.zeros_like(depth, dtype = torch.int)\n    mask_ones = torch.ones_like(depth, dtype = torch.int)\n    for i in range(1, number_of_planes+1):\n        for ch in range(target.shape[0]):\n            pos = depth_plane_positions[i] \n            prev_pos = depth_plane_positions[i-1] \n            if i <= (number_of_planes - 1):\n                condition = torch.logical_and(prev_pos <= depth, depth < pos)\n            else:\n                condition = torch.logical_and(prev_pos <= depth, depth <= pos)\n            mask = torch.where(condition, mask_ones, mask_zeros)\n            new_target = target[ch] * mask\n            targets[i-1, ch] = new_target.squeeze(0)\n            masks[i-1, ch] = mask.detach().clone() \n    return targets, masks\n
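Example (a minimal sketch, assuming slice_rgbd_targets is importable from odak.learn.perception.util as listed above; the tensor sizes and the [0, 1] depth range are illustrative only):

import torch
from odak.learn.perception.util import slice_rgbd_targets

target = torch.rand(3, 64, 64)                      # color target [C x H x W]
depth = torch.rand(64, 64)                          # depth map normalized to [0, 1]
depth_plane_positions = torch.linspace(0., 1., 5)   # five edges define four depth planes
targets, masks = slice_rgbd_targets(target, depth, depth_plane_positions)
# targets and masks both come back with shape [4 x 3 x 64 x 64]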
"},{"location":"odak/learn_raytracing/","title":"odak.learn.raytracing","text":"

odak.learn.raytracing

Provides necessary definitions for geometric optics. See \"General Ray tracing procedure\" from G.H. Spencer and M.V.R.K. Murty for more theoretical explanation.

A class to represent a detector.

Source code in odak/learn/raytracing/detector.py
class detector():\n    \"\"\"\n    A class to represent a detector.\n    \"\"\"\n\n\n    def __init__(\n                 self,\n                 colors = 3,\n                 center = torch.tensor([0., 0., 0.]),\n                 tilt = torch.tensor([0., 0., 0.]),\n                 size = torch.tensor([10., 10.]),\n                 resolution = torch.tensor([100, 100]),\n                 device = torch.device('cpu')\n                ):\n        \"\"\"\n        Parameters\n        ----------\n        colors         : int\n                         Number of color channels to register (e.g., RGB).\n        center         : torch.tensor\n                         Center point of the detector [3].\n        tilt           : torch.tensor\n                         Tilt angles of the surface in degrees [3].\n        size           : torch.tensor\n                         Size of the detector [2].\n        resolution     : torch.tensor\n                         Resolution of the detector.\n        device         : torch.device\n                         Device for computation (e.g., cuda, cpu).\n        \"\"\"\n        self.device = device\n        self.colors = colors\n        self.resolution = resolution.to(self.device)\n        self.surface_center = center.to(self.device)\n        self.surface_tilt = tilt.to(self.device)\n        self.size = size.to(self.device)\n        self.pixel_size = torch.tensor([\n                                        self.size[0] / self.resolution[0],\n                                        self.size[1] / self.resolution[1]\n                                       ], device  = self.device)\n        self.pixel_diagonal_size = torch.sqrt(self.pixel_size[0] ** 2 + self.pixel_size[1] ** 2)\n        self.pixel_diagonal_half_size = self.pixel_diagonal_size / 2.\n        self.threshold = torch.nn.Threshold(self.pixel_diagonal_size, 1)\n        self.plane = define_plane(\n                                  point = self.surface_center,\n                                  angles = self.surface_tilt\n                                 )\n        self.pixel_locations, _, _, _ = grid_sample(\n                                                    size = self.size.tolist(),\n                                                    no = self.resolution.tolist(),\n                                                    center = self.surface_center.tolist(),\n                                                    angles = self.surface_tilt.tolist()\n                                                   )\n        self.pixel_locations = self.pixel_locations.to(self.device)\n        self.relu = torch.nn.ReLU()\n        self.clear()\n\n\n    def intersect(self, rays, color = 0):\n        \"\"\"\n        Function to intersect rays with the detector\n\n\n        Parameters\n        ----------\n        rays            : torch.tensor\n                          Rays to be intersected with a detector.\n                          Expected size is [1 x 2 x 3] or [m x 2 x 3].\n        color           : int\n                          Color channel to register.\n\n        Returns\n        -------\n        points          : torch.tensor\n                          Intersection points with the image detector [k x 3].\n        \"\"\"\n        normals, _ = intersect_w_surface(rays, self.plane)\n        points = normals[:, 0]\n        distances_xyz = torch.abs(points.unsqueeze(1) - self.pixel_locations.unsqueeze(0))\n        distances_x = 1e6 * self.relu( - (distances_xyz[:, :, 0] - self.pixel_size[0]))\n        distances_y = 1e6 
* self.relu( - (distances_xyz[:, :, 1] - self.pixel_size[1]))\n        hit_x = torch.clamp(distances_x, min = 0., max = 1.)\n        hit_y = torch.clamp(distances_y, min = 0., max = 1.)\n        hit = hit_x * hit_y\n        image = torch.sum(hit, dim = 0)\n        self.image[color] += image.reshape(\n                                           self.image.shape[-2], \n                                           self.image.shape[-1]\n                                          )\n        distances = torch.sum((points.unsqueeze(1) - self.pixel_locations.unsqueeze(0)) ** 2, dim = 2)\n        distance_image = distances\n#        distance_image = distances.reshape(\n#                                           -1,\n#                                           self.image.shape[-2],\n#                                           self.image.shape[-1]\n#                                          )\n        return points, image, distance_image\n\n\n    def get_image(self):\n        \"\"\"\n        Function to return the detector image.\n\n        Returns\n        -------\n        image           : torch.tensor\n                          Detector image.\n        \"\"\"\n        image = (self.image - self.image.min()) / (self.image.max() - self.image.min())\n        return image\n\n\n    def clear(self):\n        \"\"\"\n        Internal function to clear a detector.\n        \"\"\"\n        self.image = torch.zeros(\n\n                                 self.colors,\n                                 self.resolution[0],\n                                 self.resolution[1],\n                                 device = self.device,\n                                )\n
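Example (a minimal sketch, assuming detector and create_ray_from_two_points are importable from odak.learn.raytracing as the headings on this page suggest; the ray geometry below is illustrative):

import torch
from odak.learn.raytracing import detector, create_ray_from_two_points

det = detector(
               colors = 1,
               center = torch.tensor([0., 0., 0.]),
               size = torch.tensor([10., 10.]),
               resolution = torch.tensor([100, 100])
              )
starts = torch.tensor([[0., 0., -10.]]).repeat(5, 1)   # rays start behind the detector plane
ends = torch.zeros(5, 3)
ends[:, 0] = torch.linspace(-4., 4., 5)                # aim at different x positions on the plane
rays = create_ray_from_two_points(starts, ends)
points, image, distance_image = det.intersect(rays, color = 0)
normalized_image = det.get_image()                     # normalized detector image [1 x 100 x 100]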
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.detector","title":"detector","text":"

A class to represent a detector.

Source code in odak/learn/raytracing/detector.py
class detector():\n    \"\"\"\n    A class to represent a detector.\n    \"\"\"\n\n\n    def __init__(\n                 self,\n                 colors = 3,\n                 center = torch.tensor([0., 0., 0.]),\n                 tilt = torch.tensor([0., 0., 0.]),\n                 size = torch.tensor([10., 10.]),\n                 resolution = torch.tensor([100, 100]),\n                 device = torch.device('cpu')\n                ):\n        \"\"\"\n        Parameters\n        ----------\n        colors         : int\n                         Number of color channels to register (e.g., RGB).\n        center         : torch.tensor\n                         Center point of the detector [3].\n        tilt           : torch.tensor\n                         Tilt angles of the surface in degrees [3].\n        size           : torch.tensor\n                         Size of the detector [2].\n        resolution     : torch.tensor\n                         Resolution of the detector.\n        device         : torch.device\n                         Device for computation (e.g., cuda, cpu).\n        \"\"\"\n        self.device = device\n        self.colors = colors\n        self.resolution = resolution.to(self.device)\n        self.surface_center = center.to(self.device)\n        self.surface_tilt = tilt.to(self.device)\n        self.size = size.to(self.device)\n        self.pixel_size = torch.tensor([\n                                        self.size[0] / self.resolution[0],\n                                        self.size[1] / self.resolution[1]\n                                       ], device  = self.device)\n        self.pixel_diagonal_size = torch.sqrt(self.pixel_size[0] ** 2 + self.pixel_size[1] ** 2)\n        self.pixel_diagonal_half_size = self.pixel_diagonal_size / 2.\n        self.threshold = torch.nn.Threshold(self.pixel_diagonal_size, 1)\n        self.plane = define_plane(\n                                  point = self.surface_center,\n                                  angles = self.surface_tilt\n                                 )\n        self.pixel_locations, _, _, _ = grid_sample(\n                                                    size = self.size.tolist(),\n                                                    no = self.resolution.tolist(),\n                                                    center = self.surface_center.tolist(),\n                                                    angles = self.surface_tilt.tolist()\n                                                   )\n        self.pixel_locations = self.pixel_locations.to(self.device)\n        self.relu = torch.nn.ReLU()\n        self.clear()\n\n\n    def intersect(self, rays, color = 0):\n        \"\"\"\n        Function to intersect rays with the detector\n\n\n        Parameters\n        ----------\n        rays            : torch.tensor\n                          Rays to be intersected with a detector.\n                          Expected size is [1 x 2 x 3] or [m x 2 x 3].\n        color           : int\n                          Color channel to register.\n\n        Returns\n        -------\n        points          : torch.tensor\n                          Intersection points with the image detector [k x 3].\n        \"\"\"\n        normals, _ = intersect_w_surface(rays, self.plane)\n        points = normals[:, 0]\n        distances_xyz = torch.abs(points.unsqueeze(1) - self.pixel_locations.unsqueeze(0))\n        distances_x = 1e6 * self.relu( - (distances_xyz[:, :, 0] - self.pixel_size[0]))\n        distances_y = 1e6 
* self.relu( - (distances_xyz[:, :, 1] - self.pixel_size[1]))\n        hit_x = torch.clamp(distances_x, min = 0., max = 1.)\n        hit_y = torch.clamp(distances_y, min = 0., max = 1.)\n        hit = hit_x * hit_y\n        image = torch.sum(hit, dim = 0)\n        self.image[color] += image.reshape(\n                                           self.image.shape[-2], \n                                           self.image.shape[-1]\n                                          )\n        distances = torch.sum((points.unsqueeze(1) - self.pixel_locations.unsqueeze(0)) ** 2, dim = 2)\n        distance_image = distances\n#        distance_image = distances.reshape(\n#                                           -1,\n#                                           self.image.shape[-2],\n#                                           self.image.shape[-1]\n#                                          )\n        return points, image, distance_image\n\n\n    def get_image(self):\n        \"\"\"\n        Function to return the detector image.\n\n        Returns\n        -------\n        image           : torch.tensor\n                          Detector image.\n        \"\"\"\n        image = (self.image - self.image.min()) / (self.image.max() - self.image.min())\n        return image\n\n\n    def clear(self):\n        \"\"\"\n        Internal function to clear a detector.\n        \"\"\"\n        self.image = torch.zeros(\n\n                                 self.colors,\n                                 self.resolution[0],\n                                 self.resolution[1],\n                                 device = self.device,\n                                )\n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.detector.__init__","title":"__init__(colors=3, center=torch.tensor([0.0, 0.0, 0.0]), tilt=torch.tensor([0.0, 0.0, 0.0]), size=torch.tensor([10.0, 10.0]), resolution=torch.tensor([100, 100]), device=torch.device('cpu'))","text":"

Parameters:

  • colors \u2013
             Number of color channels to register (e.g., RGB).\n
  • center \u2013
             Center point of the detector [3].\n
  • tilt \u2013
             Tilt angles of the surface in degrees [3].\n
  • size \u2013
             Size of the detector [2].\n
  • resolution \u2013
             Resolution of the detector.\n
  • device \u2013
             Device for computation (e.g., cuda, cpu).\n
Source code in odak/learn/raytracing/detector.py
def __init__(\n             self,\n             colors = 3,\n             center = torch.tensor([0., 0., 0.]),\n             tilt = torch.tensor([0., 0., 0.]),\n             size = torch.tensor([10., 10.]),\n             resolution = torch.tensor([100, 100]),\n             device = torch.device('cpu')\n            ):\n    \"\"\"\n    Parameters\n    ----------\n    colors         : int\n                     Number of color channels to register (e.g., RGB).\n    center         : torch.tensor\n                     Center point of the detector [3].\n    tilt           : torch.tensor\n                     Tilt angles of the surface in degrees [3].\n    size           : torch.tensor\n                     Size of the detector [2].\n    resolution     : torch.tensor\n                     Resolution of the detector.\n    device         : torch.device\n                     Device for computation (e.g., cuda, cpu).\n    \"\"\"\n    self.device = device\n    self.colors = colors\n    self.resolution = resolution.to(self.device)\n    self.surface_center = center.to(self.device)\n    self.surface_tilt = tilt.to(self.device)\n    self.size = size.to(self.device)\n    self.pixel_size = torch.tensor([\n                                    self.size[0] / self.resolution[0],\n                                    self.size[1] / self.resolution[1]\n                                   ], device  = self.device)\n    self.pixel_diagonal_size = torch.sqrt(self.pixel_size[0] ** 2 + self.pixel_size[1] ** 2)\n    self.pixel_diagonal_half_size = self.pixel_diagonal_size / 2.\n    self.threshold = torch.nn.Threshold(self.pixel_diagonal_size, 1)\n    self.plane = define_plane(\n                              point = self.surface_center,\n                              angles = self.surface_tilt\n                             )\n    self.pixel_locations, _, _, _ = grid_sample(\n                                                size = self.size.tolist(),\n                                                no = self.resolution.tolist(),\n                                                center = self.surface_center.tolist(),\n                                                angles = self.surface_tilt.tolist()\n                                               )\n    self.pixel_locations = self.pixel_locations.to(self.device)\n    self.relu = torch.nn.ReLU()\n    self.clear()\n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.detector.clear","title":"clear()","text":"

Internal function to clear a detector.

Source code in odak/learn/raytracing/detector.py
def clear(self):\n    \"\"\"\n    Internal function to clear a detector.\n    \"\"\"\n    self.image = torch.zeros(\n\n                             self.colors,\n                             self.resolution[0],\n                             self.resolution[1],\n                             device = self.device,\n                            )\n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.detector.get_image","title":"get_image()","text":"

Function to return the detector image.

Returns:

  • image ( tensor ) \u2013

    Detector image.

Source code in odak/learn/raytracing/detector.py
def get_image(self):\n    \"\"\"\n    Function to return the detector image.\n\n    Returns\n    -------\n    image           : torch.tensor\n                      Detector image.\n    \"\"\"\n    image = (self.image - self.image.min()) / (self.image.max() - self.image.min())\n    return image\n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.detector.intersect","title":"intersect(rays, color=0)","text":"

Function to intersect rays with the detector

Parameters:

  • rays \u2013
              Rays to be intersected with a detector.\n          Expected size is [1 x 2 x 3] or [m x 2 x 3].\n
  • color \u2013
              Color channel to register.\n

Returns:

  • points ( tensor ) \u2013

    Intersection points with the image detector [k x 3].

Source code in odak/learn/raytracing/detector.py
    def intersect(self, rays, color = 0):\n        \"\"\"\n        Function to intersect rays with the detector\n\n\n        Parameters\n        ----------\n        rays            : torch.tensor\n                          Rays to be intersected with a detector.\n                          Expected size is [1 x 2 x 3] or [m x 2 x 3].\n        color           : int\n                          Color channel to register.\n\n        Returns\n        -------\n        points          : torch.tensor\n                          Intersection points with the image detector [k x 3].\n        \"\"\"\n        normals, _ = intersect_w_surface(rays, self.plane)\n        points = normals[:, 0]\n        distances_xyz = torch.abs(points.unsqueeze(1) - self.pixel_locations.unsqueeze(0))\n        distances_x = 1e6 * self.relu( - (distances_xyz[:, :, 0] - self.pixel_size[0]))\n        distances_y = 1e6 * self.relu( - (distances_xyz[:, :, 1] - self.pixel_size[1]))\n        hit_x = torch.clamp(distances_x, min = 0., max = 1.)\n        hit_y = torch.clamp(distances_y, min = 0., max = 1.)\n        hit = hit_x * hit_y\n        image = torch.sum(hit, dim = 0)\n        self.image[color] += image.reshape(\n                                           self.image.shape[-2], \n                                           self.image.shape[-1]\n                                          )\n        distances = torch.sum((points.unsqueeze(1) - self.pixel_locations.unsqueeze(0)) ** 2, dim = 2)\n        distance_image = distances\n#        distance_image = distances.reshape(\n#                                           -1,\n#                                           self.image.shape[-2],\n#                                           self.image.shape[-1]\n#                                          )\n        return points, image, distance_image\n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.planar_mesh","title":"planar_mesh","text":"Source code in odak/learn/raytracing/mesh.py
class planar_mesh():\n\n\n    def __init__(\n                 self,\n                 size = [1., 1.],\n                 number_of_meshes = [10, 10],\n                 angles = torch.tensor([0., 0., 0.]),\n                 offset = torch.tensor([0., 0., 0.]),\n                 device = torch.device('cpu'),\n                 heights = None\n                ):\n        \"\"\"\n        Definition to generate a plane with meshes.\n\n\n        Parameters\n        -----------\n        number_of_meshes  : torch.tensor\n                            Number of squares over plane.\n                            There are two triangles at each square.\n        size              : torch.tensor\n                            Size of the plane.\n        angles            : torch.tensor\n                            Rotation angles in degrees.\n        offset            : torch.tensor\n                            Offset along XYZ axes.\n                            Expected dimension is [1 x 3] or offset for each triangle [m x 3].\n                            m here refers to `2 * number_of_meshes[0]` times  `number_of_meshes[1]`.\n        device            : torch.device\n                            Computational resource to be used (e.g., cpu, cuda).\n        heights           : torch.tensor\n                            Load surface heights from a tensor.\n        \"\"\"\n        self.device = device\n        self.angles = angles.to(self.device)\n        self.offset = offset.to(self.device)\n        self.size = size.to(self.device)\n        self.number_of_meshes = number_of_meshes.to(self.device)\n        self.init_heights(heights)\n\n\n    def init_heights(self, heights = None):\n        \"\"\"\n        Internal function to initialize a height map.\n        Note that self.heights is a differentiable variable, and can be optimized or learned.\n        See unit test `test/test_learn_ray_detector.py` or `test/test_learn_ray_mesh.py` as examples.\n        \"\"\"\n        if not isinstance(heights, type(None)):\n            self.heights = heights.to(self.device)\n            self.heights.requires_grad = True\n        else:\n            self.heights = torch.zeros(\n                                       (self.number_of_meshes[0], self.number_of_meshes[1], 1),\n                                       requires_grad = True,\n                                       device = self.device,\n                                      )\n        x = torch.linspace(-self.size[0] / 2., self.size[0] / 2., self.number_of_meshes[0], device = self.device) \n        y = torch.linspace(-self.size[1] / 2., self.size[1] / 2., self.number_of_meshes[1], device = self.device)\n        X, Y = torch.meshgrid(x, y, indexing = 'ij')\n        self.X = X.unsqueeze(-1)\n        self.Y = Y.unsqueeze(-1)\n\n\n    def save_heights(self, filename = 'heights.pt'):\n        \"\"\"\n        Function to save heights to a file.\n\n        Parameters\n        ----------\n        filename          : str\n                            Filename.\n        \"\"\"\n        save_torch_tensor(filename, self.heights.detach().clone())\n\n\n    def save_heights_as_PLY(self, filename = 'mesh.ply'):\n        \"\"\"\n        Function to save mesh to a PLY file.\n\n        Parameters\n        ----------\n        filename          : str\n                            Filename.\n        \"\"\"\n        triangles = self.get_triangles()\n        write_PLY(triangles, filename)\n\n\n    def get_squares(self):\n        \"\"\"\n        Internal function to initiate squares over a 
plane.\n\n        Returns\n        -------\n        squares     : torch.tensor\n                      Squares over a plane.\n                      Expected size is [m x n x 3].\n        \"\"\"\n        squares = torch.cat((\n                             self.X,\n                             self.Y,\n                             self.heights\n                            ), dim = -1)\n        return squares\n\n\n    def get_triangles(self):\n        \"\"\"\n        Internal function to get triangles.\n        \"\"\" \n        squares = self.get_squares()\n        triangles = torch.zeros(2, self.number_of_meshes[0], self.number_of_meshes[1], 3, 3, device = self.device)\n        for i in range(0, self.number_of_meshes[0] - 1):\n            for j in range(0, self.number_of_meshes[1] - 1):\n                first_triangle = torch.cat((\n                                            squares[i + 1, j].unsqueeze(0),\n                                            squares[i + 1, j + 1].unsqueeze(0),\n                                            squares[i, j + 1].unsqueeze(0),\n                                           ), dim = 0)\n                second_triangle = torch.cat((\n                                             squares[i + 1, j].unsqueeze(0),\n                                             squares[i, j + 1].unsqueeze(0),\n                                             squares[i, j].unsqueeze(0),\n                                            ), dim = 0)\n                triangles[0, i, j], _, _, _ = rotate_points(first_triangle, angles = self.angles)\n                triangles[1, i, j], _, _, _ = rotate_points(second_triangle, angles = self.angles)\n        triangles = triangles.view(-1, 3, 3) + self.offset\n        return triangles \n\n\n    def mirror(self, rays):\n        \"\"\"\n        Function to bounce light rays off the meshes.\n\n        Parameters\n        ----------\n        rays              : torch.tensor\n                            Rays to be bounced.\n                            Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].\n\n        Returns\n        -------\n        reflected_rays    : torch.tensor\n                            Reflected rays.\n                            Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].\n        reflected_normals : torch.tensor\n                            Reflected normals.\n                            Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].\n\n        \"\"\"\n        if len(rays.shape) == 2:\n            rays = rays.unsqueeze(0)\n        triangles = self.get_triangles()\n        reflected_rays = torch.empty((0, 2, 3), requires_grad = True, device = self.device)\n        reflected_normals = torch.empty((0, 2, 3), requires_grad = True, device = self.device)\n        for triangle in triangles:\n            _, _, intersecting_rays, intersecting_normals, check = intersect_w_triangle(\n                                                                                        rays,\n                                                                                        triangle\n                                                                                       ) \n            triangle_reflected_rays = reflect(intersecting_rays, intersecting_normals)\n            if triangle_reflected_rays.shape[0] > 0:\n                reflected_rays = torch.cat((\n                                            reflected_rays,\n                                            triangle_reflected_rays\n                                          ))\n        
        reflected_normals = torch.cat((\n                                               reflected_normals,\n                                               intersecting_normals\n                                              ))\n        return reflected_rays, reflected_normals\n
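Example (a minimal sketch, assuming planar_mesh and create_ray_from_two_points are importable from odak.learn.raytracing; the mesh size, offset and the single test ray are illustrative). Since self.heights is differentiable, the reflected rays can feed a gradient-based optimization of the surface, as the unit tests referenced above do:

import torch
from odak.learn.raytracing import planar_mesh, create_ray_from_two_points

mesh = planar_mesh(
                   size = torch.tensor([1., 1.]),
                   number_of_meshes = torch.tensor([10, 10]),
                   angles = torch.tensor([0., 0., 0.]),
                   offset = torch.tensor([0., 0., 2.])   # place the mesh at z = 2
                  )
rays = create_ray_from_two_points(
                                  torch.tensor([[0., 0., 0.]]),
                                  torch.tensor([[0., 0., 2.]])
                                 )                       # one ray travelling towards the mesh
reflected_rays, reflected_normals = mesh.mirror(rays)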
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.planar_mesh.__init__","title":"__init__(size=[1.0, 1.0], number_of_meshes=[10, 10], angles=torch.tensor([0.0, 0.0, 0.0]), offset=torch.tensor([0.0, 0.0, 0.0]), device=torch.device('cpu'), heights=None)","text":"

Definition to generate a plane with meshes.

Parameters:

  • number_of_meshes \u2013
                Number of squares over plane.\n            There are two triangles at each square.\n
  • size \u2013
                Size of the plane.\n
  • angles \u2013
                Rotation angles in degrees.\n
  • offset \u2013
                Offset along XYZ axes.\n            Expected dimension is [1 x 3] or offset for each triangle [m x 3].\n            m here refers to `2 * number_of_meshes[0]` times  `number_of_meshes[1]`.\n
  • device \u2013
                Computational resource to be used (e.g., cpu, cuda).\n
  • heights \u2013
                Load surface heights from a tensor.\n
Source code in odak/learn/raytracing/mesh.py
def __init__(\n             self,\n             size = [1., 1.],\n             number_of_meshes = [10, 10],\n             angles = torch.tensor([0., 0., 0.]),\n             offset = torch.tensor([0., 0., 0.]),\n             device = torch.device('cpu'),\n             heights = None\n            ):\n    \"\"\"\n    Definition to generate a plane with meshes.\n\n\n    Parameters\n    -----------\n    number_of_meshes  : torch.tensor\n                        Number of squares over plane.\n                        There are two triangles at each square.\n    size              : torch.tensor\n                        Size of the plane.\n    angles            : torch.tensor\n                        Rotation angles in degrees.\n    offset            : torch.tensor\n                        Offset along XYZ axes.\n                        Expected dimension is [1 x 3] or offset for each triangle [m x 3].\n                        m here refers to `2 * number_of_meshes[0]` times  `number_of_meshes[1]`.\n    device            : torch.device\n                        Computational resource to be used (e.g., cpu, cuda).\n    heights           : torch.tensor\n                        Load surface heights from a tensor.\n    \"\"\"\n    self.device = device\n    self.angles = angles.to(self.device)\n    self.offset = offset.to(self.device)\n    self.size = size.to(self.device)\n    self.number_of_meshes = number_of_meshes.to(self.device)\n    self.init_heights(heights)\n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.planar_mesh.get_squares","title":"get_squares()","text":"

Internal function to initiate squares over a plane.

Returns:

  • squares ( tensor ) \u2013

    Squares over a plane. Expected size is [m x n x 3].

Source code in odak/learn/raytracing/mesh.py
def get_squares(self):\n    \"\"\"\n    Internal function to initiate squares over a plane.\n\n    Returns\n    -------\n    squares     : torch.tensor\n                  Squares over a plane.\n                  Expected size is [m x n x 3].\n    \"\"\"\n    squares = torch.cat((\n                         self.X,\n                         self.Y,\n                         self.heights\n                        ), dim = -1)\n    return squares\n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.planar_mesh.get_triangles","title":"get_triangles()","text":"

Internal function to get triangles.

Source code in odak/learn/raytracing/mesh.py
def get_triangles(self):\n    \"\"\"\n    Internal function to get triangles.\n    \"\"\" \n    squares = self.get_squares()\n    triangles = torch.zeros(2, self.number_of_meshes[0], self.number_of_meshes[1], 3, 3, device = self.device)\n    for i in range(0, self.number_of_meshes[0] - 1):\n        for j in range(0, self.number_of_meshes[1] - 1):\n            first_triangle = torch.cat((\n                                        squares[i + 1, j].unsqueeze(0),\n                                        squares[i + 1, j + 1].unsqueeze(0),\n                                        squares[i, j + 1].unsqueeze(0),\n                                       ), dim = 0)\n            second_triangle = torch.cat((\n                                         squares[i + 1, j].unsqueeze(0),\n                                         squares[i, j + 1].unsqueeze(0),\n                                         squares[i, j].unsqueeze(0),\n                                        ), dim = 0)\n            triangles[0, i, j], _, _, _ = rotate_points(first_triangle, angles = self.angles)\n            triangles[1, i, j], _, _, _ = rotate_points(second_triangle, angles = self.angles)\n    triangles = triangles.view(-1, 3, 3) + self.offset\n    return triangles \n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.planar_mesh.init_heights","title":"init_heights(heights=None)","text":"

Internal function to initialize a height map. Note that self.heights is a differentiable variable, and can be optimized or learned. See unit test test/test_learn_ray_detector.py or test/test_learn_ray_mesh.py as examples.

Source code in odak/learn/raytracing/mesh.py
def init_heights(self, heights = None):\n    \"\"\"\n    Internal function to initialize a height map.\n    Note that self.heights is a differentiable variable, and can be optimized or learned.\n    See unit test `test/test_learn_ray_detector.py` or `test/test_learn_ray_mesh.py` as examples.\n    \"\"\"\n    if not isinstance(heights, type(None)):\n        self.heights = heights.to(self.device)\n        self.heights.requires_grad = True\n    else:\n        self.heights = torch.zeros(\n                                   (self.number_of_meshes[0], self.number_of_meshes[1], 1),\n                                   requires_grad = True,\n                                   device = self.device,\n                                  )\n    x = torch.linspace(-self.size[0] / 2., self.size[0] / 2., self.number_of_meshes[0], device = self.device) \n    y = torch.linspace(-self.size[1] / 2., self.size[1] / 2., self.number_of_meshes[1], device = self.device)\n    X, Y = torch.meshgrid(x, y, indexing = 'ij')\n    self.X = X.unsqueeze(-1)\n    self.Y = Y.unsqueeze(-1)\n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.planar_mesh.mirror","title":"mirror(rays)","text":"

Function to bounce light rays off the meshes.

Parameters:

  • rays \u2013
                Rays to be bounced.\n            Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].\n

Returns:

  • reflected_rays ( tensor ) \u2013

    Reflected rays. Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].

  • reflected_normals ( tensor ) \u2013

    Reflected normals. Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].

Source code in odak/learn/raytracing/mesh.py
def mirror(self, rays):\n    \"\"\"\n    Function to bounce light rays off the meshes.\n\n    Parameters\n    ----------\n    rays              : torch.tensor\n                        Rays to be bounced.\n                        Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].\n\n    Returns\n    -------\n    reflected_rays    : torch.tensor\n                        Reflected rays.\n                        Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].\n    reflected_normals : torch.tensor\n                        Reflected normals.\n                        Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].\n\n    \"\"\"\n    if len(rays.shape) == 2:\n        rays = rays.unsqueeze(0)\n    triangles = self.get_triangles()\n    reflected_rays = torch.empty((0, 2, 3), requires_grad = True, device = self.device)\n    reflected_normals = torch.empty((0, 2, 3), requires_grad = True, device = self.device)\n    for triangle in triangles:\n        _, _, intersecting_rays, intersecting_normals, check = intersect_w_triangle(\n                                                                                    rays,\n                                                                                    triangle\n                                                                                   ) \n        triangle_reflected_rays = reflect(intersecting_rays, intersecting_normals)\n        if triangle_reflected_rays.shape[0] > 0:\n            reflected_rays = torch.cat((\n                                        reflected_rays,\n                                        triangle_reflected_rays\n                                      ))\n            reflected_normals = torch.cat((\n                                           reflected_normals,\n                                           intersecting_normals\n                                          ))\n    return reflected_rays, reflected_normals\n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.planar_mesh.save_heights","title":"save_heights(filename='heights.pt')","text":"

Function to save heights to a file.

Parameters:

  • filename \u2013
                Filename.\n
Source code in odak/learn/raytracing/mesh.py
def save_heights(self, filename = 'heights.pt'):\n    \"\"\"\n    Function to save heights to a file.\n\n    Parameters\n    ----------\n    filename          : str\n                        Filename.\n    \"\"\"\n    save_torch_tensor(filename, self.heights.detach().clone())\n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.planar_mesh.save_heights_as_PLY","title":"save_heights_as_PLY(filename='mesh.ply')","text":"

Function to save mesh to a PLY file.

Parameters:

  • filename \u2013
                Filename.\n
Source code in odak/learn/raytracing/mesh.py
def save_heights_as_PLY(self, filename = 'mesh.ply'):\n    \"\"\"\n    Function to save mesh to a PLY file.\n\n    Parameters\n    ----------\n    filename          : str\n                        Filename.\n    \"\"\"\n    triangles = self.get_triangles()\n    write_PLY(triangles, filename)\n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.center_of_triangle","title":"center_of_triangle(triangle)","text":"

Definition to calculate center of a triangle.

Parameters:

  • triangle \u2013
            An array that contains three points defining a triangle (Mx3).\n        It can also process many triangles in parallel (NxMx3).\n

Returns:

  • centers ( tensor ) \u2013

    Triangle centers.

Source code in odak/learn/raytracing/primitives.py
def center_of_triangle(triangle):\n    \"\"\"\n    Definition to calculate center of a triangle.\n\n    Parameters\n    ----------\n    triangle      : torch.tensor\n                    An array that contains three points defining a triangle (Mx3). \n                    It can also parallel process many triangles (NxMx3).\n\n    Returns\n    -------\n    centers       : torch.tensor\n                    Triangle centers.\n    \"\"\"\n    if len(triangle.shape) == 2:\n        triangle = triangle.view((1, 3, 3))\n    center = torch.mean(triangle, axis=1)\n    return center\n
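Example (a short illustrative sketch, assuming center_of_triangle is importable from odak.learn.raytracing):

import torch
from odak.learn.raytracing import center_of_triangle

triangle = torch.tensor([
                         [0., 0., 0.],
                         [1., 0., 0.],
                         [0., 1., 0.]
                        ])
center = center_of_triangle(triangle)   # tensor([[0.3333, 0.3333, 0.0000]])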
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.create_ray","title":"create_ray(xyz, abg, direction=False)","text":"

Definition to create a ray.

Parameters:

  • xyz \u2013
           List that contains X,Y and Z start locations of a ray.\n       Size could be [1 x 3], [3], [m x 3].\n
  • abg \u2013
           List that contains angles in degrees with respect to the X,Y and Z axes.\n       Size could be [1 x 3], [3], [m x 3].\n
  • direction \u2013
           If set to True, cosines of `abg` are not calculated.\n

Returns:

  • ray ( tensor ) \u2013

    Array that contains starting points and cosines of a created ray. Size will be either [1 x 3] or [m x 3].

Source code in odak/learn/raytracing/ray.py
def create_ray(xyz, abg, direction = False):\n    \"\"\"\n    Definition to create a ray.\n\n    Parameters\n    ----------\n    xyz          : torch.tensor\n                   List that contains X,Y and Z start locations of a ray.\n                   Size could be [1 x 3], [3], [m x 3].\n    abg          : torch.tensor\n                   List that contains angles in degrees with respect to the X,Y and Z axes.\n                   Size could be [1 x 3], [3], [m x 3].\n    direction    : bool\n                   If set to True, cosines of `abg` is not calculated.\n\n    Returns\n    ----------\n    ray          : torch.tensor\n                   Array that contains starting points and cosines of a created ray.\n                   Size will be either [1 x 3] or [m x 3].\n    \"\"\"\n    points = xyz\n    angles = abg\n    if len(xyz) == 1:\n        points = xyz.unsqueeze(0)\n    if len(abg) == 1:\n        angles = abg.unsqueeze(0)\n    ray = torch.zeros(points.shape[0], 2, 3, device = points.device)\n    ray[:, 0] = points\n    if direction:\n        ray[:, 1] = abg\n    else:\n        ray[:, 1] = torch.cos(torch.deg2rad(abg))\n    return ray\n
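Example (a short sketch, assuming create_ray is importable from odak.learn.raytracing; the angles are chosen so that the resulting direction cosines are easy to verify):

import torch
from odak.learn.raytracing import create_ray

xyz = torch.tensor([
                    [0., 0., 0.],
                    [1., 0., 0.]
                   ])                   # start locations [m x 3]
abg = torch.tensor([
                    [90., 90., 0.],     # cosines become roughly (0, 0, 1)
                    [0., 90., 90.]      # cosines become roughly (1, 0, 0)
                   ])
rays = create_ray(xyz, abg)             # shape [2 x 2 x 3]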
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.create_ray_from_all_pairs","title":"create_ray_from_all_pairs(x0y0z0, x1y1z1)","text":"

Creates rays from all possible pairs of points in x0y0z0 and x1y1z1.

Parameters:

  • x0y0z0 \u2013
           Tensor that contains X, Y, and Z start locations of rays.\n       Size should be [m x 3].\n
  • x1y1z1 \u2013
           Tensor that contains X, Y, and Z end locations of rays.\n       Size should be [n x 3].\n

Returns:

  • rays ( tensor ) \u2013

    Array that contains starting points and cosines of a created ray(s). Size of [n*m x 2 x 3]

Source code in odak/learn/raytracing/ray.py
def create_ray_from_all_pairs(x0y0z0, x1y1z1):\n    \"\"\"\n    Creates rays from all possible pairs of points in x0y0z0 and x1y1z1.\n\n    Parameters\n    ----------\n    x0y0z0       : torch.tensor\n                   Tensor that contains X, Y, and Z start locations of rays.\n                   Size should be [m x 3].\n    x1y1z1       : torch.tensor\n                   Tensor that contains X, Y, and Z end locations of rays.\n                   Size should be [n x 3].\n\n    Returns\n    ----------\n    rays         : torch.tensor\n                   Array that contains starting points and cosines of a created ray(s). Size of [n*m x 2 x 3]\n    \"\"\"\n\n    if len(x0y0z0.shape) == 1:\n        x0y0z0 = x0y0z0.unsqueeze(0)\n    if len(x1y1z1.shape) == 1:\n        x1y1z1 = x1y1z1.unsqueeze(0)\n\n    m, n = x0y0z0.shape[0], x1y1z1.shape[0]\n    start_points = x0y0z0.unsqueeze(1).expand(-1, n, -1).reshape(-1, 3)\n    end_points = x1y1z1.unsqueeze(0).expand(m, -1, -1).reshape(-1, 3)\n\n    directions = end_points - start_points\n    norms = torch.norm(directions, p=2, dim=1, keepdim=True)\n    norms[norms == 0] = float('nan')\n\n    normalized_directions = directions / norms\n\n    rays = torch.zeros(m * n, 2, 3, device=x0y0z0.device)\n    rays[:, 0, :] = start_points\n    rays[:, 1, :] = normalized_directions\n\n    return rays\n
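Example (a short sketch, assuming the function is importable from odak.learn.raytracing; with m = 2 start points and n = 3 end points the result holds six rays):

import torch
from odak.learn.raytracing import create_ray_from_all_pairs

starts = torch.tensor([
                       [0., 0., 0.],
                       [1., 0., 0.]
                      ])                            # m = 2 start points
ends = torch.tensor([
                     [0., 0., 1.],
                     [0., 1., 1.],
                     [1., 1., 1.]
                    ])                              # n = 3 end points
rays = create_ray_from_all_pairs(starts, ends)      # shape [6 x 2 x 3]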
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.create_ray_from_grid_w_luminous_angle","title":"create_ray_from_grid_w_luminous_angle(center, size, no, tilt, num_ray_per_light, angle_limit)","text":"

Generate a 2D array of lights, each emitting rays within a specified solid angle and tilt.

Parameters:

  • center \u2013
                          The center point of the light array, shape [3].\n
  • size \u2013
                          The size of the light array [height, width].\n
  • no \u2013
                          The number of lights in the array [number of lights in height, number of lights in width].\n
  • tilt \u2013
                          The tilt angles in degrees along x, y, z axes for the rays, shape [3].\n
  • num_ray_per_light \u2013
                          The number of rays each light should emit.\n
  • angle_limit \u2013
                          The maximum angle in degrees from the initial direction vector within which to emit rays.\n

Returns:

  • rays ( tensor ) \u2013

    Array that contains starting points and cosines of a created ray(s). Size of [n x 2 x 3].

Source code in odak/learn/raytracing/ray.py
def create_ray_from_grid_w_luminous_angle(center, size, no, tilt, num_ray_per_light, angle_limit):\n    \"\"\"\n    Generate a 2D array of lights, each emitting rays within a specified solid angle and tilt.\n\n    Parameters:\n    ----------\n    center              : torch.tensor\n                          The center point of the light array, shape [3].\n    size                : list[int]\n                          The size of the light array [height, width]\n    no                  : list[int]\n                          The number of the light arary [number of lights in height , number of lights inwidth]\n    tilt                : torch.tensor\n                          The tilt angles in degrees along x, y, z axes for the rays, shape [3].\n    angle_limit         : float\n                          The maximum angle in degrees from the initial direction vector within which to emit rays.\n    num_rays_per_light  : int\n                          The number of rays each light should emit.\n\n    Returns:\n    ----------\n    rays : torch.tensor\n           Array that contains starting points and cosines of a created ray(s). Size of [n x 2 x 3]\n    \"\"\"\n\n    samples = torch.zeros((no[0], no[1], 3))\n\n    x = torch.linspace(-size[0] / 2., size[0] / 2., no[0])\n    y = torch.linspace(-size[1] / 2., size[1] / 2., no[1])\n    X, Y = torch.meshgrid(x, y, indexing='ij')\n\n    samples[:, :, 0] = X.detach().clone()\n    samples[:, :, 1] = Y.detach().clone()\n    samples = samples.reshape((no[0]*no[1], 3))\n\n    samples, *_ = rotate_points(samples, angles=tilt)\n\n    samples = samples + center\n    angle_limit = torch.as_tensor(angle_limit)\n    cos_alpha = torch.cos(angle_limit * torch.pi / 180)\n    tilt = tilt * torch.pi / 180\n\n    theta = torch.acos(1 - 2 * torch.rand(num_ray_per_light*samples.size(0)) * (1-cos_alpha))\n    phi = 2 * torch.pi * torch.rand(num_ray_per_light*samples.size(0))  \n\n    directions = torch.stack([\n        torch.sin(theta) * torch.cos(phi),  \n        torch.sin(theta) * torch.sin(phi),  \n        torch.cos(theta)                    \n    ], dim=1)\n\n    c, s = torch.cos(tilt), torch.sin(tilt)\n\n    Rx = torch.tensor([\n        [1, 0, 0],\n        [0, c[0], -s[0]],\n        [0, s[0], c[0]]\n    ])\n\n    Ry = torch.tensor([\n        [c[1], 0, s[1]],\n        [0, 1, 0],\n        [-s[1], 0, c[1]]\n    ])\n\n    Rz = torch.tensor([\n        [c[2], -s[2], 0],\n        [s[2], c[2], 0],\n        [0, 0, 1]\n    ])\n\n    origins = samples.repeat(num_ray_per_light, 1)\n\n    directions = torch.matmul(directions, (Rz@Ry@Rx).T)\n\n\n    rays = torch.zeros(num_ray_per_light*samples.size(0), 2, 3)\n    rays[:, 0, :] = origins\n    rays[:, 1, :] = directions\n\n    return rays\n
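Example (a short sketch under the assumption that the function is importable from odak.learn.raytracing; the grid size, cone angle and ray count are illustrative):

import torch
from odak.learn.raytracing import create_ray_from_grid_w_luminous_angle

rays = create_ray_from_grid_w_luminous_angle(
                                             center = torch.tensor([0., 0., 0.]),
                                             size = [2., 2.],                     # 2 x 2 units light array
                                             no = [3, 3],                         # 3 x 3 = 9 point lights
                                             tilt = torch.tensor([0., 0., 0.]),
                                             num_ray_per_light = 10,
                                             angle_limit = 15.
                                            )
# rays has shape [90 x 2 x 3]: 9 lights times 10 rays each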
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.create_ray_from_point_w_luminous_angle","title":"create_ray_from_point_w_luminous_angle(origin, num_ray, tilt, angle_limit)","text":"

Generate rays from a point, tilted by specific angles along x, y, z axes, within a specified solid angle.

Parameters:

  • origin \u2013
                  The origin point of the rays, shape [3].\n
  • num_ray \u2013
                  The total number of rays to generate.\n
  • tilt \u2013
                  The tilt angles in degrees along x, y, z axes, shape [3].\n
  • angle_limit \u2013
                  The maximum angle in degrees from the initial direction vector within which to emit rays.\n

Returns:

  • rays ( tensor ) \u2013

    Array that contains starting points and cosines of a created ray(s). Size of [n x 2 x 3].

Source code in odak/learn/raytracing/ray.py
def create_ray_from_point_w_luminous_angle(origin, num_ray, tilt, angle_limit):\n    \"\"\"\n    Generate rays from a point, tilted by specific angles along x, y, z axes, within a specified solid angle.\n\n    Parameters:\n    ----------\n    origin      : torch.tensor\n                  The origin point of the rays, shape [3].\n    num_rays    : int\n                  The total number of rays to generate.\n    tilt        : torch.tensor\n                  The tilt angles in degrees along x, y, z axes, shape [3].\n    angle_limit : float\n                  The maximum angle in degrees from the initial direction vector within which to emit rays.\n\n    Returns:\n    ----------\n    rays : torch.tensor\n           Array that contains starting points and cosines of a created ray(s). Size of [n x 2 x 3]\n    \"\"\"\n    angle_limit = torch.as_tensor(angle_limit) \n    cos_alpha = torch.cos(angle_limit * torch.pi / 180)\n    tilt = tilt * torch.pi / 180\n\n    theta = torch.acos(1 - 2 * torch.rand(num_ray) * (1-cos_alpha))\n    phi = 2 * torch.pi * torch.rand(num_ray)  \n\n\n    directions = torch.stack([\n        torch.sin(theta) * torch.cos(phi),  \n        torch.sin(theta) * torch.sin(phi),  \n        torch.cos(theta)                    \n    ], dim=1)\n\n    c, s = torch.cos(tilt), torch.sin(tilt)\n\n    Rx = torch.tensor([\n        [1, 0, 0],\n        [0, c[0], -s[0]],\n        [0, s[0], c[0]]\n    ])\n\n    Ry = torch.tensor([\n        [c[1], 0, s[1]],\n        [0, 1, 0],\n        [-s[1], 0, c[1]]\n    ])\n\n    Rz = torch.tensor([\n        [c[2], -s[2], 0],\n        [s[2], c[2], 0],\n        [0, 0, 1]\n    ])\n\n    origins = origin.repeat(num_ray, 1)\n    directions = torch.matmul(directions, (Rz@Ry@Rx).T)\n\n\n    rays = torch.zeros(num_ray, 2, 3)\n    rays[:, 0, :] = origins\n    rays[:, 1, :] = directions\n\n    return rays\n
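Example (a short sketch, assuming the function is importable from odak.learn.raytracing; the cone angle and the ray count are illustrative):

import torch
from odak.learn.raytracing import create_ray_from_point_w_luminous_angle

rays = create_ray_from_point_w_luminous_angle(
                                              origin = torch.tensor([0., 0., 0.]),
                                              num_ray = 100,
                                              tilt = torch.tensor([0., 0., 0.]),
                                              angle_limit = 5.
                                             )
# rays has shape [100 x 2 x 3]; directions are sampled within a 5 degree cone around +z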
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.create_ray_from_two_points","title":"create_ray_from_two_points(x0y0z0, x1y1z1)","text":"

Definition to create a ray from two given points. Note that both inputs must match in shape.

Parameters:

  • x0y0z0 \u2013
           List that contains X,Y and Z start locations of a ray.\n       Size could be [1 x 3], [3], [m x 3].\n
  • x1y1z1 \u2013
           List that contains X,Y and Z ending locations of a ray or batch of rays.\n       Size could be [1 x 3], [3], [m x 3].\n

Returns:

  • ray ( tensor ) \u2013

    Array that contains starting points and cosines of a created ray(s).

Source code in odak/learn/raytracing/ray.py
def create_ray_from_two_points(x0y0z0, x1y1z1):\n    \"\"\"\n    Definition to create a ray from two given points. Note that both inputs must match in shape.\n\n    Parameters\n    ----------\n    x0y0z0       : torch.tensor\n                   List that contains X,Y and Z start locations of a ray.\n                   Size could be [1 x 3], [3], [m x 3].\n    x1y1z1       : torch.tensor\n                   List that contains X,Y and Z ending locations of a ray or batch of rays.\n                   Size could be [1 x 3], [3], [m x 3].\n\n    Returns\n    ----------\n    ray          : torch.tensor\n                   Array that contains starting points and cosines of a created ray(s).\n    \"\"\"\n    if len(x0y0z0.shape) == 1:\n        x0y0z0 = x0y0z0.unsqueeze(0)\n    if len(x1y1z1.shape) == 1:\n        x1y1z1 = x1y1z1.unsqueeze(0)\n    xdiff = x1y1z1[:, 0] - x0y0z0[:, 0]\n    ydiff = x1y1z1[:, 1] - x0y0z0[:, 1]\n    zdiff = x1y1z1[:, 2] - x0y0z0[:, 2]\n    s = (xdiff ** 2 + ydiff ** 2 + zdiff ** 2) ** 0.5\n    s[s == 0] = float('nan')\n    cosines = torch.zeros_like(x0y0z0 * x1y1z1)\n    cosines[:, 0] = xdiff / s\n    cosines[:, 1] = ydiff / s\n    cosines[:, 2] = zdiff / s\n    ray = torch.zeros(xdiff.shape[0], 2, 3, device = x0y0z0.device)\n    ray[:, 0] = x0y0z0\n    ray[:, 1] = cosines\n    return ray\n
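
For instance, a batch of rays can be created from matching start and end points (a minimal sketch with made-up coordinates):

import torch
from odak.learn.raytracing import create_ray_from_two_points

starting_points = torch.tensor([[0., 0., 0.], [1., 0., 0.]])
end_points = torch.tensor([[0., 0., 10.], [1., 0., 10.]])
rays = create_ray_from_two_points(starting_points, end_points)
# rays[:, 0] holds the start points and rays[:, 1] the unit direction cosines, shape [2 x 2 x 3].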
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.define_circle","title":"define_circle(center, radius, angles)","text":"

Definition to describe a circle in a single variable packed form.

Parameters:

  • center \u2013
      Center of a circle to be defined in 3D space.\n
  • radius \u2013
      Radius of a circle to be defined.\n
  • angles \u2013
      Angular tilt of a circle represented by rotations about x, y, and z axes.\n

Returns:

  • circle ( list ) \u2013

    Single variable packed form.

Source code in odak/learn/raytracing/primitives.py
def define_circle(center, radius, angles):\n    \"\"\"\n    Definition to describe a circle in a single variable packed form.\n\n    Parameters\n    ----------\n    center  : torch.Tensor\n              Center of a circle to be defined in 3D space.\n    radius  : float\n              Radius of a circle to be defined.\n    angles  : torch.Tensor\n              Angular tilt of a circle represented by rotations about x, y, and z axes.\n\n    Returns\n    ----------\n    circle  : list\n              Single variable packed form.\n    \"\"\"\n    points = define_plane(center, angles=angles)\n    circle = [\n        points,\n        center,\n        torch.tensor([radius])\n    ]\n    return circle\n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.define_plane","title":"define_plane(point, angles=torch.tensor([0.0, 0.0, 0.0]))","text":"

Definition to define a plane from a given center point and rotation angles.

Parameters:

  • point \u2013
           A point that is at the center of a plane.\n
  • angles \u2013
           Rotation angles in degrees.\n

Returns:

  • plane ( tensor ) \u2013

    Points defining plane.

Source code in odak/learn/raytracing/primitives.py
def define_plane(point, angles = torch.tensor([0., 0., 0.])):\n    \"\"\" \n    Definition to generate a rotation matrix along X axis.\n\n    Parameters\n    ----------\n    point        : torch.tensor\n                   A point that is at the center of a plane.\n    angles       : torch.tensor\n                   Rotation angles in degrees.\n\n    Returns\n    ----------\n    plane        : torch.tensor\n                   Points defining plane.\n    \"\"\"\n    plane = torch.tensor([\n                          [10., 10., 0.],\n                          [0., 10., 0.],\n                          [0.,  0., 0.]\n                         ], device = point.device)\n    for i in range(0, plane.shape[0]):\n        plane[i], _, _, _ = rotate_points(plane[i], angles = angles.to(point.device))\n        plane[i] = plane[i] + point\n    return plane\n
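
A brief sketch of defining a tilted plane (illustrative values):

import torch
from odak.learn.raytracing import define_plane

center = torch.tensor([0., 0., 10.])                 # plane center
angles = torch.tensor([0., 45., 0.])                 # rotation angles in degrees
plane = define_plane(center, angles = angles)        # three points spanning the plane, shape [3 x 3]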
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.define_plane_mesh","title":"define_plane_mesh(number_of_meshes=[10, 10], size=[1.0, 1.0], angles=torch.tensor([0.0, 0.0, 0.0]), offset=torch.tensor([[0.0, 0.0, 0.0]]))","text":"

Definition to generate a plane with meshes.

Parameters:

  • number_of_meshes \u2013
                Number of squares over plane.\n            There are two triangles at each square.\n
  • size \u2013
                Size of the plane.\n
  • angles \u2013
                Rotation angles in degrees.\n
  • offset \u2013
                Offset along XYZ axes.\n            Expected dimension is [1 x 3] or offset for each triangle [m x 3].\n            m here refers to `2 * number_of_meshes[0]` times  `number_of_meshes[1]`.\n

Returns:

  • triangles ( tensor ) \u2013

    Triangles [m x 3 x 3], where m is 2 * number_of_meshes[0] times number_of_meshes[1].

Source code in odak/learn/raytracing/primitives.py
def define_plane_mesh(\n                      number_of_meshes = [10, 10], \n                      size = [1., 1.], \n                      angles = torch.tensor([0., 0., 0.]), \n                      offset = torch.tensor([[0., 0., 0.]])\n                     ):\n    \"\"\"\n    Definition to generate a plane with meshes.\n\n\n    Parameters\n    -----------\n    number_of_meshes  : torch.tensor\n                        Number of squares over plane.\n                        There are two triangles at each square.\n    size              : list\n                        Size of the plane.\n    angles            : torch.tensor\n                        Rotation angles in degrees.\n    offset            : torch.tensor\n                        Offset along XYZ axes.\n                        Expected dimension is [1 x 3] or offset for each triangle [m x 3].\n                        m here refers to `2 * number_of_meshes[0]` times  `number_of_meshes[1]`. \n\n    Returns\n    -------\n    triangles         : torch.tensor\n                        Triangles [m x 3 x 3], where m is `2 * number_of_meshes[0]` times  `number_of_meshes[1]`.\n    \"\"\"\n    triangles = torch.zeros(2, number_of_meshes[0], number_of_meshes[1], 3, 3)\n    step = [size[0] / number_of_meshes[0], size[1] / number_of_meshes[1]]\n    for i in range(0, number_of_meshes[0] - 1):\n        for j in range(0, number_of_meshes[1] - 1):\n            first_triangle = torch.tensor([\n                                           [       -size[0] / 2. + step[0] * i,       -size[1] / 2. + step[0] * j, 0.],\n                                           [ -size[0] / 2. + step[0] * (i + 1),       -size[1] / 2. + step[0] * j, 0.],\n                                           [       -size[0] / 2. + step[0] * i, -size[1] / 2. + step[0] * (j + 1), 0.]\n                                          ])\n            second_triangle = torch.tensor([\n                                            [ -size[0] / 2. + step[0] * (i + 1), -size[1] / 2. + step[0] * (j + 1), 0.],\n                                            [ -size[0] / 2. + step[0] * (i + 1),       -size[1] / 2. + step[0] * j, 0.],\n                                            [       -size[0] / 2. + step[0] * i, -size[1] / 2. + step[0] * (j + 1), 0.]\n                                           ])\n            triangles[0, i, j], _, _, _ = rotate_points(first_triangle, angles = angles)\n            triangles[1, i, j], _, _, _ = rotate_points(second_triangle, angles = angles)\n    triangles = triangles.view(-1, 3, 3) + offset\n    return triangles\n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.define_sphere","title":"define_sphere(center=torch.tensor([[0.0, 0.0, 0.0]]), radius=torch.tensor([1.0]))","text":"

Definition to define a sphere.

Parameters:

  • center \u2013
          Center of the sphere(s) along XYZ axes.\n      Expected size is [3], [1, 3] or [m, 3].\n
  • radius \u2013
          Radius of that sphere(s).\n      Expected size is [1], [1, 1], [m] or [m, 1].\n

Returns:

  • parameters ( tensor ) \u2013

    Parameters of defined sphere(s), packing center and radius. Expected size is [1 x 4] or [m x 4].

Source code in odak/learn/raytracing/primitives.py
def define_sphere(center = torch.tensor([[0., 0., 0.]]), radius = torch.tensor([1.])):\n    \"\"\"\n    Definition to define a sphere.\n\n    Parameters\n    ----------\n    center      : torch.tensor\n                  Center of the sphere(s) along XYZ axes.\n                  Expected size is [3], [1, 3] or [m, 3].\n    radius      : torch.tensor\n                  Radius of that sphere(s).\n                  Expected size is [1], [1, 1], [m] or [m, 1].\n\n    Returns\n    -------\n    parameters  : torch.tensor\n                  Parameters of defined sphere(s).\n                  Expected size is [1, 3] or [m x 3].\n    \"\"\"\n    if len(radius.shape) == 1:\n        radius = radius.unsqueeze(0)\n    if len(center.shape) == 1:\n        center = center.unsqueeze(1)\n    parameters = torch.cat((center, radius), dim = 1)\n    return parameters\n
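
A minimal sketch showing the packed sphere parameters (illustrative values):

import torch
from odak.learn.raytracing import define_sphere

sphere = define_sphere(
                       center = torch.tensor([[0., 0., 10.]]),
                       radius = torch.tensor([3.])
                      )
# sphere concatenates center and radius into a [1 x 4] tensor.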
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.distance_between_two_points","title":"distance_between_two_points(point1, point2)","text":"

Definition to calculate distance between two given points.

Parameters:

  • point1 \u2013
          First point in X,Y,Z.\n
  • point2 \u2013
          Second point in X,Y,Z.\n

Returns:

  • distance ( Tensor ) \u2013

    Distance in between given two points.

Source code in odak/learn/tools/vector.py
def distance_between_two_points(point1, point2):\n    \"\"\"\n    Definition to calculate distance between two given points.\n\n    Parameters\n    ----------\n    point1      : torch.Tensor\n                  First point in X,Y,Z.\n    point2      : torch.Tensor\n                  Second point in X,Y,Z.\n\n    Returns\n    ----------\n    distance    : torch.Tensor\n                  Distance in between given two points.\n    \"\"\"\n    point1 = torch.tensor(point1) if not isinstance(point1, torch.Tensor) else point1\n    point2 = torch.tensor(point2) if not isinstance(point2, torch.Tensor) else point2\n\n    if len(point1.shape) == 1 and len(point2.shape) == 1:\n        distance = torch.sqrt(torch.sum((point1 - point2) ** 2))\n    elif len(point1.shape) == 2 or len(point2.shape) == 2:\n        distance = torch.sqrt(torch.sum((point1 - point2) ** 2, dim=-1))\n\n    return distance\n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.get_sphere_normal_torch","title":"get_sphere_normal_torch(point, sphere)","text":"

Definition to get a normal of a point on a given sphere.

Parameters:

  • point \u2013
            Point on sphere in X,Y,Z.\n
  • sphere \u2013
            Center defined in X,Y,Z and radius.\n

Returns:

  • normal_vector ( tensor ) \u2013

    Normal vector.

Source code in odak/learn/raytracing/boundary.py
def get_sphere_normal_torch(point, sphere):\n    \"\"\"\n    Definition to get a normal of a point on a given sphere.\n\n    Parameters\n    ----------\n    point         : torch.tensor\n                    Point on sphere in X,Y,Z.\n    sphere        : torch.tensor\n                    Center defined in X,Y,Z and radius.\n\n    Returns\n    ----------\n    normal_vector : torch.tensor\n                    Normal vector.\n    \"\"\"\n    if len(point.shape) == 1:\n        point = point.reshape((1, 3))\n    normal_vector = create_ray_from_two_points(point, sphere[0:3])\n    return normal_vector\n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.get_triangle_normal","title":"get_triangle_normal(triangle, triangle_center=None)","text":"

Definition to calculate surface normal of a triangle.

Parameters:

  • triangle \u2013
              Set of points in X,Y and Z to define a planar surface (3,3). It can also be list of triangles (mx3x3).\n
  • triangle_center (tensor, default: None ) \u2013
              Center point of the given triangle. See odak.learn.raytracing.center_of_triangle for more. In many scenarios you can accelerate things by precomputing triangle centers.\n

Returns:

  • normal ( tensor ) \u2013

    Surface normal at the point of intersection.

Source code in odak/learn/raytracing/boundary.py
def get_triangle_normal(triangle, triangle_center=None):\n    \"\"\"\n    Definition to calculate surface normal of a triangle.\n\n    Parameters\n    ----------\n    triangle        : torch.tensor\n                      Set of points in X,Y and Z to define a planar surface (3,3). It can also be list of triangles (mx3x3).\n    triangle_center : torch.tensor\n                      Center point of the given triangle. See odak.learn.raytracing.center_of_triangle for more. In many scenarios you can accelerate things by precomputing triangle centers.\n\n    Returns\n    ----------\n    normal          : torch.tensor\n                      Surface normal at the point of intersection.\n    \"\"\"\n    if len(triangle.shape) == 2:\n        triangle = triangle.view((1, 3, 3))\n    normal = torch.zeros((triangle.shape[0], 2, 3)).to(triangle.device)\n    direction = torch.linalg.cross(\n                                   triangle[:, 0] - triangle[:, 1], \n                                   triangle[:, 2] - triangle[:, 1]\n                                  )\n    if type(triangle_center) == type(None):\n        normal[:, 0] = center_of_triangle(triangle)\n    else:\n        normal[:, 0] = triangle_center\n    normal[:, 1] = direction / torch.sum(direction, axis=1)[0]\n    if normal.shape[0] == 1:\n        normal = normal.view((2, 3))\n    return normal\n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.grid_sample","title":"grid_sample(no=[10, 10], size=[100.0, 100.0], center=[0.0, 0.0, 0.0], angles=[0.0, 0.0, 0.0])","text":"

Definition to generate samples over a surface.

Parameters:

  • no \u2013
          Number of samples.\n
  • size \u2013
          Physical size of the surface.\n
  • center \u2013
          Center location of the surface.\n
  • angles \u2013
          Tilt of the surface.\n

Returns:

  • samples ( tensor ) \u2013

    Samples generated.

  • rotx ( tensor ) \u2013

    Rotation matrix at X axis.

  • roty ( tensor ) \u2013

    Rotation matrix at Y axis.

  • rotz ( tensor ) \u2013

    Rotation matrix at Z axis.

Source code in odak/learn/tools/sample.py
def grid_sample(\n                no = [10, 10],\n                size = [100., 100.], \n                center = [0., 0., 0.], \n                angles = [0., 0., 0.]):\n    \"\"\"\n    Definition to generate samples over a surface.\n\n    Parameters\n    ----------\n    no          : list\n                  Number of samples.\n    size        : list\n                  Physical size of the surface.\n    center      : list\n                  Center location of the surface.\n    angles      : list\n                  Tilt of the surface.\n\n    Returns\n    -------\n    samples     : torch.tensor\n                  Samples generated.\n    rotx        : torch.tensor\n                  Rotation matrix at X axis.\n    roty        : torch.tensor\n                  Rotation matrix at Y axis.\n    rotz        : torch.tensor\n                  Rotation matrix at Z axis.\n    \"\"\"\n    center = torch.tensor(center)\n    angles = torch.tensor(angles)\n    size = torch.tensor(size)\n    samples = torch.zeros((no[0], no[1], 3))\n    x = torch.linspace(-size[0] / 2., size[0] / 2., no[0])\n    y = torch.linspace(-size[1] / 2., size[1] / 2., no[1])\n    X, Y = torch.meshgrid(x, y, indexing='ij')\n    samples[:, :, 0] = X.detach().clone()\n    samples[:, :, 1] = Y.detach().clone()\n    samples = samples.reshape((samples.shape[0] * samples.shape[1], samples.shape[2]))\n    samples, rotx, roty, rotz = rotate_points(samples, angles = angles, offset = center)\n    return samples, rotx, roty, rotz\n
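
A quick sketch of sampling a tilted surface (illustrative values; grid_sample is assumed to be importable from odak.learn.tools, matching the source path above):

from odak.learn.tools import grid_sample

samples, rotx, roty, rotz = grid_sample(
                                        no = [20, 20],             # number of samples per axis
                                        size = [10., 10.],         # physical size of the surface
                                        center = [0., 0., 100.],   # surface center
                                        angles = [0., 10., 0.]     # surface tilt in degrees
                                       )
# samples has shape [400 x 3]; rotx, roty and rotz are the [3 x 3] rotation matrices used.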
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.intersect_w_circle","title":"intersect_w_circle(ray, circle)","text":"

Definition to find intersection point of a ray with a circle. Returns distance as zero if there isn't an intersection.

Parameters:

  • ray \u2013
           A vector/ray.\n
  • circle \u2013
           A list that contains (0) Set of points in X,Y and Z to define plane of a circle, (1) circle center, and (2) circle radius.\n

Returns:

  • normal ( Tensor ) \u2013

    Surface normal at the point of intersection.

  • distance ( Tensor ) \u2013

    Distance between the starting point of a ray and its intersection point with the given circle.

Source code in odak/learn/raytracing/boundary.py
def intersect_w_circle(ray, circle):\n    \"\"\"\n    Definition to find intersection point of a ray with a circle. \n    Returns distance as zero if there isn't an intersection.\n\n    Parameters\n    ----------\n    ray          : torch.Tensor\n                   A vector/ray.\n    circle       : list\n                   A list that contains (0) Set of points in X,Y and Z to define plane of a circle, (1) circle center, and (2) circle radius.\n\n    Returns\n    ----------\n    normal       : torch.Tensor\n                   Surface normal at the point of intersection.\n    distance     : torch.Tensor\n                   Distance in between a starting point of a ray and the intersection point with a given triangle.\n    \"\"\"\n    normal, distance = intersect_w_surface(ray, circle[0])\n\n    if len(normal.shape) == 2:\n        normal = normal.unsqueeze(0)\n\n    distance_to_center = distance_between_two_points(normal[:, 0], circle[1])\n    mask = distance_to_center > circle[2]\n    distance[mask] = 0\n\n    if len(ray.shape) == 2:\n        normal = normal.squeeze(0)\n\n    return normal, distance\n
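
Putting this together with define_circle and create_ray_from_two_points, a minimal sketch (illustrative values) looks like:

import torch
from odak.learn.raytracing import create_ray_from_two_points, define_circle, intersect_w_circle

circle = define_circle(
                       center = torch.tensor([0., 0., 10.]),
                       radius = 5.,
                       angles = torch.tensor([0., 0., 0.])
                      )
ray = create_ray_from_two_points(torch.tensor([0., 0., 0.]), torch.tensor([0., 0., 10.]))
normal, distance = intersect_w_circle(ray, circle)
# distance is set to zero for rays that miss the circular aperture.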
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.intersect_w_sphere","title":"intersect_w_sphere(ray, sphere, learning_rate=0.2, number_of_steps=5000, error_threshold=0.01)","text":"

Definition to find the intersection between ray(s) and sphere(s).

Parameters:

  • ray \u2013
                  Input ray(s).\n              Expected size is [1 x 2 x 3] or [m x 2 x 3].\n
  • sphere \u2013
                  Input sphere.\n              Expected size is [1 x 4].\n
  • learning_rate \u2013
                  Learning rate used in the optimizer for finding the propagation distances of the rays.\n
  • number_of_steps \u2013
                  Number of steps used in the optimizer.\n
  • error_threshold \u2013
                   The error threshold used to decide whether there is an intersection or not.\n

Returns:

  • intersecting_ray ( tensor ) \u2013

    Ray(s) intersecting with the given sphere. Expected size is [n x 2 x 3], where n is the number of intersecting rays.

  • intersecting_normal ( tensor ) \u2013

    Normal(s) for the ray(s) intersecting with the given sphere. Expected size is [n x 2 x 3], where n is the number of intersecting rays.

Source code in odak/learn/raytracing/boundary.py
def intersect_w_sphere(ray, sphere, learning_rate = 2e-1, number_of_steps = 5000, error_threshold = 1e-2):\n    \"\"\"\n    Definition to find the intersection between ray(s) and sphere(s).\n\n    Parameters\n    ----------\n    ray                 : torch.tensor\n                          Input ray(s).\n                          Expected size is [1 x 2 x 3] or [m x 2 x 3].\n    sphere              : torch.tensor\n                          Input sphere.\n                          Expected size is [1 x 4].\n    learning_rate       : float\n                          Learning rate used in the optimizer for finding the propagation distances of the rays.\n    number_of_steps     : int\n                          Number of steps used in the optimizer.\n    error_threshold     : float\n                          The error threshold that will help deciding intersection or no intersection.\n\n    Returns\n    -------\n    intersecting_ray    : torch.tensor\n                          Ray(s) that intersecting with the given sphere.\n                          Expected size is [n x 2 x 3], where n could be any real number.\n    intersecting_normal : torch.tensor\n                          Normal(s) for the ray(s) intersecting with the given sphere\n                          Expected size is [n x 2 x 3], where n could be any real number.\n\n    \"\"\"\n    if len(ray.shape) == 2:\n        ray = ray.unsqueeze(0)\n    if len(sphere.shape) == 1:\n        sphere = sphere.unsqueeze(0)\n    distance = torch.zeros(ray.shape[0], device = ray.device, requires_grad = True)\n    loss_l2 = torch.nn.MSELoss(reduction = 'sum')\n    optimizer = torch.optim.AdamW([distance], lr = learning_rate)    \n    t = tqdm(range(number_of_steps), leave = False, dynamic_ncols = True)\n    for step in t:\n        optimizer.zero_grad()\n        propagated_ray = propagate_ray(ray, distance)\n        test = torch.abs((propagated_ray[:, 0, 0] - sphere[:, 0]) ** 2 + (propagated_ray[:, 0, 1] - sphere[:, 1]) ** 2 + (propagated_ray[:, 0, 2] - sphere[:, 2]) ** 2 - sphere[:, 3] ** 2)\n        loss = loss_l2(\n                       test,\n                       torch.zeros_like(test)\n                      )\n        loss.backward(retain_graph = True)\n        optimizer.step()\n        t.set_description('Sphere intersection loss: {}'.format(loss.item()))\n    check = test < error_threshold\n    intersecting_ray = propagate_ray(ray[check == True], distance[check == True])\n    intersecting_normal = create_ray_from_two_points(\n                                                     sphere[:, 0:3],\n                                                     intersecting_ray[:, 0]\n                                                    )\n    return intersecting_ray, intersecting_normal, distance, check\n
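
As a usage sketch (illustrative values): note that, in addition to the two tensors documented above, the source shown here also returns the optimized distances and a boolean check, so four values are unpacked below.

import torch
from odak.learn.raytracing import create_ray_from_two_points, define_sphere, intersect_w_sphere

sphere = define_sphere(center = torch.tensor([[0., 0., 10.]]), radius = torch.tensor([3.]))
ray = create_ray_from_two_points(torch.tensor([0., 0., 0.]), torch.tensor([0., 0., 10.]))
intersecting_ray, intersecting_normal, distance, check = intersect_w_sphere(ray, sphere)
# check flags rays whose optimized propagation distance lands on the sphere surface.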
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.intersect_w_surface","title":"intersect_w_surface(ray, points)","text":"

Definition to find the intersection point between a ray and a planar surface. For more see: http://geomalgorithms.com/a06-_intersect-2.html

Parameters:

  • ray \u2013
           A vector/ray.\n
  • points \u2013
           Set of points in X,Y and Z to define a planar surface.\n

Returns:

  • normal ( tensor ) \u2013

    Surface normal at the point of intersection.

  • distance ( float ) \u2013

    Distance between the starting point of a ray and its intersection with the planar surface.

Source code in odak/learn/raytracing/boundary.py
def intersect_w_surface(ray, points):\n    \"\"\"\n    Definition to find intersection point inbetween a surface and a ray. For more see: http://geomalgorithms.com/a06-_intersect-2.html\n\n    Parameters\n    ----------\n    ray          : torch.tensor\n                   A vector/ray.\n    points       : torch.tensor\n                   Set of points in X,Y and Z to define a planar surface.\n\n    Returns\n    ----------\n    normal       : torch.tensor\n                   Surface normal at the point of intersection.\n    distance     : float\n                   Distance in between starting point of a ray with it's intersection with a planar surface.\n    \"\"\"\n    normal = get_triangle_normal(points)\n    if len(ray.shape) == 2:\n        ray = ray.unsqueeze(0)\n    if len(points.shape) == 2:\n        points = points.unsqueeze(0)\n    if len(normal.shape) == 2:\n        normal = normal.unsqueeze(0)\n    f = normal[:, 0] - ray[:, 0]\n    distance = (torch.mm(normal[:, 1], f.T) / torch.mm(normal[:, 1], ray[:, 1].T)).T\n    new_normal = torch.zeros_like(ray)\n    new_normal[:, 0] = ray[:, 0] + distance * ray[:, 1]\n    new_normal[:, 1] = normal[:, 1]\n    new_normal = torch.nan_to_num(\n                                  new_normal,\n                                  nan = float('nan'),\n                                  posinf = float('nan'),\n                                  neginf = float('nan')\n                                 )\n    distance = torch.nan_to_num(\n                                distance,\n                                nan = float('nan'),\n                                posinf = float('nan'),\n                                neginf = float('nan')\n                               )\n    return new_normal, distance\n
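
A compact sketch combining this with define_plane and create_ray_from_two_points (illustrative values):

import torch
from odak.learn.raytracing import create_ray_from_two_points, define_plane, intersect_w_surface

plane = define_plane(torch.tensor([0., 0., 10.]))
ray = create_ray_from_two_points(torch.tensor([0., 0., 0.]), torch.tensor([1., 1., 10.]))
normal, distance = intersect_w_surface(ray, plane)
# normal[:, 0] is the intersection point and normal[:, 1] the surface normal there.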
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.intersect_w_surface_batch","title":"intersect_w_surface_batch(ray, triangle)","text":"

Parameters:

  • ray \u2013
           A vector/ray (2 x 3). It can also be a list of rays (n x 2 x 3).\n
  • triangle \u2013
           Set of points in X,Y and Z to define a planar surface. It can also be a list of triangles (m x 3 x 3).\n

Returns:

  • normal ( tensor ) \u2013

    Surface normal at the point of intersection (m x n x 2 x 3).

  • distance ( tensor ) \u2013

    Distance between the starting point of a ray and its intersection with the planar surface (m x n).

Source code in odak/learn/raytracing/boundary.py
def intersect_w_surface_batch(ray, triangle):\n    \"\"\"\n    Parameters\n    ----------\n    ray          : torch.tensor\n                   A vector/ray (2 x 3). It can also be a list of rays (n x 2 x 3).\n    triangle     : torch.tensor\n                   Set of points in X,Y and Z to define a planar surface. It can also be a list of triangles (m x 3 x 3).\n\n    Returns\n    ----------\n    normal       : torch.tensor\n                   Surface normal at the point of intersection (m x n x 2 x 3).\n    distance     : torch.tensor\n                   Distance in between starting point of a ray with it's intersection with a planar surface (m x n).\n    \"\"\"\n    normal = get_triangle_normal(triangle)\n    if len(ray.shape) == 2:\n        ray = ray.unsqueeze(0)\n    if len(triangle.shape) == 2:\n        triangle = triangle.unsqueeze(0)\n    if len(normal.shape) == 2:\n        normal = normal.unsqueeze(0)\n\n    f = normal[:, None, 0] - ray[None, :, 0]\n    distance = (torch.bmm(normal[:, None, 1], f.permute(0, 2, 1)).squeeze(1) / torch.mm(normal[:, 1], ray[:, 1].T)).T\n\n    new_normal = torch.zeros((triangle.shape[0], )+ray.shape)\n    new_normal[:, :, 0] = ray[None, :, 0] + (distance[:, :, None] * ray[:, None, 1]).permute(1, 0, 2)\n    new_normal[:, :, 1] = normal[:, None, 1]\n    new_normal = torch.nan_to_num(\n                                  new_normal,\n                                  nan = float('nan'),\n                                  posinf = float('nan'),\n                                  neginf = float('nan')\n                                 )\n    distance = torch.nan_to_num(\n                                distance,\n                                nan = float('nan'),\n                                posinf = float('nan'),\n                                neginf = float('nan')\n                               )\n    return new_normal, distance.T\n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.intersect_w_triangle","title":"intersect_w_triangle(ray, triangle)","text":"

Definition to find intersection point of a ray with a triangle.

Parameters:

  • ray \u2013
                  A ray [1 x 2 x 3] or a batch of ray [m x 2 x 3].\n
  • triangle \u2013
                  Set of points in X,Y and Z to define a single triangle [1 x 3 x 3].\n

Returns:

  • normal ( tensor ) \u2013

    Surface normal at the point of intersection with the plane of the triangle. This could also include surface normals whose intersection points are not on the triangle. Expected size is [1 x 2 x 3] or [m x 2 x 3] depending on the input.

  • distance ( float ) \u2013

    Distance in between a starting point of a ray and the intersection point with a given triangle. Expected size is [1 x 1] or [m x 1] depending on the input.

  • intersecting_ray ( tensor ) \u2013

    Rays that intersect with the triangle plane and on the triangle. Expected size is [1 x 2 x 3] or [m x 2 x 3] depending on the input.

  • intersecting_normal ( tensor ) \u2013

    Normals that intersect with the triangle plane and on the triangle. Expected size is [1 x 2 x 3] or [m x 2 x 3] depending on the input.

  • check ( tensor ) \u2013

    A tensor that provides True or False for each input ray, testing whether the intersection point of the ray lies on the given triangle. Expected size is [1] or [m].

Source code in odak/learn/raytracing/boundary.py
def intersect_w_triangle(ray, triangle):\n    \"\"\"\n    Definition to find intersection point of a ray with a triangle. \n\n    Parameters\n    ----------\n    ray                 : torch.tensor\n                          A ray [1 x 2 x 3] or a batch of ray [m x 2 x 3].\n    triangle            : torch.tensor\n                          Set of points in X,Y and Z to define a single triangle [1 x 3 x 3].\n\n    Returns\n    ----------\n    normal              : torch.tensor\n                          Surface normal at the point of intersection with the surface of triangle.\n                          This could also involve surface normals that are not on the triangle.\n                          Expected size is [1 x 2 x 3] or [m x 2 x 3] depending on the input.\n    distance            : float\n                          Distance in between a starting point of a ray and the intersection point with a given triangle.\n                          Expected size is [1 x 1] or [m x 1] depending on the input.\n    intersecting_ray    : torch.tensor\n                          Rays that intersect with the triangle plane and on the triangle.\n                          Expected size is [1 x 2 x 3] or [m x 2 x 3] depending on the input.\n    intersecting_normal : torch.tensor\n                          Normals that intersect with the triangle plane and on the triangle.\n                          Expected size is [1 x 2 x 3] or [m x 2 x 3] depending on the input.\n    check               : torch.tensor\n                          A list that provides a bool as True or False for each ray used as input.\n                          A test to see is a ray could be on the given triangle.\n                          Expected size is [1] or [m].\n    \"\"\"\n    if len(triangle.shape) == 2:\n       triangle = triangle.unsqueeze(0)\n    if len(ray.shape) == 2:\n       ray = ray.unsqueeze(0)\n    normal, distance = intersect_w_surface(ray, triangle)\n    check = is_it_on_triangle(normal[:, 0], triangle)\n    intersecting_ray = ray.unsqueeze(0)\n    intersecting_ray = intersecting_ray.repeat(triangle.shape[0], 1, 1, 1)\n    intersecting_ray = intersecting_ray[check == True]\n    intersecting_normal = normal.unsqueeze(0)\n    intersecting_normal = intersecting_normal.repeat(triangle.shape[0], 1, 1, 1)\n    intersecting_normal = intersecting_normal[check ==  True]\n    return normal, distance, intersecting_ray, intersecting_normal, check\n
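
A minimal sketch with a hand-picked triangle and ray (illustrative values):

import torch
from odak.learn.raytracing import create_ray_from_two_points, intersect_w_triangle

triangle = torch.tensor([[
                          [-5., -5., 10.],
                          [ 5., -5., 10.],
                          [ 0.,  5., 10.]
                        ]])
ray = create_ray_from_two_points(torch.tensor([0., 0., 0.]), torch.tensor([0., 0., 10.]))
normal, distance, intersecting_ray, intersecting_normal, check = intersect_w_triangle(ray, triangle)
# check is True for rays whose intersection point falls inside the triangle.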
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.intersect_w_triangle_batch","title":"intersect_w_triangle_batch(ray, triangle)","text":"

Definition to find intersection points of rays with triangles. Returns False for each variable if the rays do not intersect with the given triangles.

Parameters:

  • ray \u2013
           vectors/rays (n x 2 x 3).\n
  • triangle \u2013
           Set of points in X,Y and Z to define triangles (m x 3 x 3).\n

Returns:

  • normal ( tensor ) \u2013

    Surface normal at the point of intersection (m x n x 2 x 3).

  • distance ( List ) \u2013

    Distance between the starting point of a ray and its intersection with the planar surface (m x n).

  • intersect_ray ( List ) \u2013

    List of intersecting rays (k x 2 x 3) where k <= n.

  • intersect_normal ( List ) \u2013

    List of intersecting normals (k x 2 x 3) where k <= n*m.

  • check ( tensor ) \u2013

    Boolean tensor (m x n) indicating whether each ray intersects with a triangle or not.

Source code in odak/learn/raytracing/boundary.py
def intersect_w_triangle_batch(ray, triangle):\n    \"\"\"\n    Definition to find intersection points of rays with triangles. Returns False for each variable if the rays doesn't intersect with given triangles.\n\n    Parameters\n    ----------\n    ray          : torch.tensor\n                   vectors/rays (n x 2 x 3).\n    triangle     : torch.tensor\n                   Set of points in X,Y and Z to define triangles (m x 3 x 3).\n\n    Returns\n    ----------\n    normal          : torch.tensor\n                      Surface normal at the point of intersection (m x n x 2 x 3).\n    distance        : List\n                      Distance in between starting point of a ray with it's intersection with a planar surface (m x n).\n    intersect_ray   : List\n                      List of intersecting rays (k x 2 x 3) where k <= n.\n    intersect_normal: List\n                      List of intersecting normals (k x 2 x 3) where k <= n*m.\n    check           : torch.tensor\n                      Boolean tensor (m x n) indicating whether each ray intersects with a triangle or not.\n    \"\"\"\n    if len(triangle.shape) == 2:\n       triangle = triangle.unsqueeze(0)\n    if len(ray.shape) == 2:\n       ray = ray.unsqueeze(0)\n\n    normal, distance = intersect_w_surface_batch(ray, triangle)\n\n    check = is_it_on_triangle_batch(normal[:, :, 0], triangle)\n\n    flat_check = check.flatten()\n    flat_normal = normal.view(-1, normal.size(-2), normal.size(-1))\n    flat_ray = ray.repeat(normal.size(0), 1, 1)\n    flat_distance = distance.flatten()\n\n    filtered_normal = torch.masked_select(flat_normal, flat_check.unsqueeze(-1).unsqueeze(-1).repeat(1, 2, 3))\n    filtered_ray = torch.masked_select(flat_ray, flat_check.unsqueeze(-1).unsqueeze(-1).repeat(1, 2, 3))\n    filtered_distnace = torch.masked_select(flat_distance, flat_check)\n\n    check_count = check.sum(dim=1).tolist()\n    split_size_ray_and_normal = [count * 2 * 3 for count in check_count]\n    split_size_distance = [count for count in check_count]\n\n    normal_grouped = torch.split(filtered_normal, split_size_ray_and_normal)\n    ray_grouped = torch.split(filtered_ray, split_size_ray_and_normal)\n    distance_grouped = torch.split(filtered_distnace, split_size_distance)\n\n    intersecting_normal = [g.view(-1, 2, 3) for g in normal_grouped if g.numel() > 0]\n    intersecting_ray = [g.view(-1, 2, 3) for g in ray_grouped if g.numel() > 0]\n    new_distance = [g for g in distance_grouped if g.numel() > 0]\n\n    return normal, new_distance, intersecting_ray, intersecting_normal, check\n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.is_it_on_triangle","title":"is_it_on_triangle(point_to_check, triangle)","text":"

Definition to check if a given point is inside a triangle. If the given point is inside a defined triangle, this definition returns True. For more details, visit: https://blackpawn.com/texts/pointinpoly/.

Parameters:

  • point_to_check \u2013
              Point(s) to check.\n          Expected size is [3], [1 x 3] or [m x 3].\n
  • triangle \u2013
              Triangle described with three points.\n          Expected size is [3 x 3], [1 x 3 x 3] or [m x 3 x3].\n

Returns:

  • result ( tensor ) \u2013

    Is it on a triangle? Returns True where the point lies on the triangle and False otherwise. Expected size is [1] or [m] depending on the input.

Source code in odak/learn/raytracing/primitives.py
def is_it_on_triangle(point_to_check, triangle):\n    \"\"\"\n    Definition to check if a given point is inside a triangle. \n    If the given point is inside a defined triangle, this definition returns True.\n    For more details, visit: [https://blackpawn.com/texts/pointinpoly/](https://blackpawn.com/texts/pointinpoly/).\n\n    Parameters\n    ----------\n    point_to_check  : torch.tensor\n                      Point(s) to check.\n                      Expected size is [3], [1 x 3] or [m x 3].\n    triangle        : torch.tensor\n                      Triangle described with three points.\n                      Expected size is [3 x 3], [1 x 3 x 3] or [m x 3 x3].\n\n    Returns\n    -------\n    result          : torch.tensor\n                      Is it on a triangle? Returns NaN if condition not satisfied.\n                      Expected size is [1] or [m] depending on the input.\n    \"\"\"\n    if len(point_to_check.shape) == 1:\n        point_to_check = point_to_check.unsqueeze(0)\n    if len(triangle.shape) == 2:\n        triangle = triangle.unsqueeze(0)\n    v0 = triangle[:, 2] - triangle[:, 0]\n    v1 = triangle[:, 1] - triangle[:, 0]\n    v2 = point_to_check - triangle[:, 0]\n    if len(v0.shape) == 1:\n        v0 = v0.unsqueeze(0)\n    if len(v1.shape) == 1:\n        v1 = v1.unsqueeze(0)\n    if len(v2.shape) == 1:\n        v2 = v2.unsqueeze(0)\n    dot00 = torch.mm(v0, v0.T)\n    dot01 = torch.mm(v0, v1.T)\n    dot02 = torch.mm(v0, v2.T) \n    dot11 = torch.mm(v1, v1.T)\n    dot12 = torch.mm(v1, v2.T)\n    invDenom = 1. / (dot00 * dot11 - dot01 * dot01)\n    u = (dot11 * dot02 - dot01 * dot12) * invDenom\n    v = (dot00 * dot12 - dot01 * dot02) * invDenom\n    result = (u >= 0.) & (v >= 0.) & ((u + v) < 1)\n    return result\n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.is_it_on_triangle_batch","title":"is_it_on_triangle_batch(point_to_check, triangle)","text":"

Definition to check if given points are inside triangles. If the given points are inside defined triangles, this definition returns True.

Parameters:

  • point_to_check \u2013
              Points to check (m x n x 3).\n
  • triangle \u2013
              Triangles (m x 3 x 3).\n

Returns:

  • result ( torch.tensor (m x n) ) –

    Boolean tensor (m x n) indicating whether each point lies on the corresponding triangle.

Source code in odak/learn/raytracing/primitives.py
def is_it_on_triangle_batch(point_to_check, triangle):\n    \"\"\"\n    Definition to check if given points are inside triangles. If the given points are inside defined triangles, this definition returns True.\n\n    Parameters\n    ----------\n    point_to_check  : torch.tensor\n                      Points to check (m x n x 3).\n    triangle        : torch.tensor \n                      Triangles (m x 3 x 3).\n\n    Returns\n    ----------\n    result          : torch.tensor (m x n)\n\n    \"\"\"\n    if len(point_to_check.shape) == 1:\n        point_to_check = point_to_check.unsqueeze(0)\n    if len(triangle.shape) == 2:\n        triangle = triangle.unsqueeze(0)\n    v0 = triangle[:, 2] - triangle[:, 0]\n    v1 = triangle[:, 1] - triangle[:, 0]\n    v2 = point_to_check - triangle[:, None, 0]\n    if len(v0.shape) == 1:\n        v0 = v0.unsqueeze(0)\n    if len(v1.shape) == 1:\n        v1 = v1.unsqueeze(0)\n    if len(v2.shape) == 1:\n        v2 = v2.unsqueeze(0)\n\n    dot00 = torch.bmm(v0.unsqueeze(1), v0.unsqueeze(1).permute(0, 2, 1)).squeeze(1)\n    dot01 = torch.bmm(v0.unsqueeze(1), v1.unsqueeze(1).permute(0, 2, 1)).squeeze(1)\n    dot02 = torch.bmm(v0.unsqueeze(1), v2.permute(0, 2, 1)).squeeze(1)\n    dot11 = torch.bmm(v1.unsqueeze(1), v1.unsqueeze(1).permute(0, 2, 1)).squeeze(1)\n    dot12 = torch.bmm(v1.unsqueeze(1), v2.permute(0, 2, 1)).squeeze(1)\n    invDenom = 1. / (dot00 * dot11 - dot01 * dot01)\n    u = (dot11 * dot02 - dot01 * dot12) * invDenom\n    v = (dot00 * dot12 - dot01 * dot02) * invDenom\n    result = (u >= 0.) & (v >= 0.) & ((u + v) < 1)\n\n    return result\n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.propagate_ray","title":"propagate_ray(ray, distance)","text":"

Definition to propagate a ray at a certain given distance.

Parameters:

  • ray \u2013
         A ray with a size of [2 x 3], [1 x 2 x 3] or a batch of rays with [m x 2 x 3].\n
  • distance \u2013
         Distance with a size of [1], [1, m] or distances with a size of [m], [1, m].\n

Returns:

  • new_ray ( tensor ) \u2013

    Propagated ray with a size of [1 x 2 x 3] or batch of rays with [m x 2 x 3].

Source code in odak/learn/raytracing/ray.py
def propagate_ray(ray, distance):\n    \"\"\"\n    Definition to propagate a ray at a certain given distance.\n\n    Parameters\n    ----------\n    ray        : torch.tensor\n                 A ray with a size of [2 x 3], [1 x 2 x 3] or a batch of rays with [m x 2 x 3].\n    distance   : torch.tensor\n                 Distance with a size of [1], [1, m] or distances with a size of [m], [1, m].\n\n    Returns\n    ----------\n    new_ray    : torch.tensor\n                 Propagated ray with a size of [1 x 2 x 3] or batch of rays with [m x 2 x 3].\n    \"\"\"\n    if len(ray.shape) == 2:\n        ray = ray.unsqueeze(0)\n    if len(distance.shape) == 2:\n        distance = distance.squeeze(-1)\n    new_ray = torch.zeros_like(ray)\n    new_ray[:, 0, 0] = distance * ray[:, 1, 0] + ray[:, 0, 0]\n    new_ray[:, 0, 1] = distance * ray[:, 1, 1] + ray[:, 0, 1]\n    new_ray[:, 0, 2] = distance * ray[:, 1, 2] + ray[:, 0, 2]\n    return new_ray\n
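
For example, propagating a single ray by a fixed distance (a small sketch with illustrative values):

import torch
from odak.learn.raytracing import create_ray_from_two_points, propagate_ray

ray = create_ray_from_two_points(torch.tensor([0., 0., 0.]), torch.tensor([0., 0., 1.]))
new_ray = propagate_ray(ray, torch.tensor([10.]))
# new_ray[:, 0] holds the start point moved ten units along the ray direction.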
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.reflect","title":"reflect(input_ray, normal)","text":"

Definition to reflect an incoming ray from a surface defined by a surface normal. Uses the method described in G.H. Spencer and M.V.R.K. Murty, "General Ray-Tracing Procedure", 1961.

Parameters:

  • input_ray \u2013
           A ray or rays.\n       Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].\n
  • normal \u2013
           A surface normal(s).\n       Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].\n

Returns:

  • output_ray ( tensor ) \u2013

    Array that contains starting points and cosines of a reflected ray. Expected size is [1 x 2 x 3] or [m x 2 x 3].

Source code in odak/learn/raytracing/boundary.py
def reflect(input_ray, normal):\n    \"\"\" \n    Definition to reflect an incoming ray from a surface defined by a surface normal. \n    Used method described in G.H. Spencer and M.V.R.K. Murty, \"General Ray-Tracing Procedure\", 1961.\n\n\n    Parameters\n    ----------\n    input_ray    : torch.tensor\n                   A ray or rays.\n                   Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].\n    normal       : torch.tensor\n                   A surface normal(s).\n                   Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].\n\n    Returns\n    ----------\n    output_ray   : torch.tensor\n                   Array that contains starting points and cosines of a reflected ray.\n                   Expected size is [1 x 2 x 3] or [m x 2 x 3].\n    \"\"\"\n    if len(input_ray.shape) == 2:\n        input_ray = input_ray.unsqueeze(0)\n    if len(normal.shape) == 2:\n        normal = normal.unsqueeze(0)\n    mu = 1\n    div = normal[:, 1, 0]**2 + normal[:, 1, 1]**2 + normal[:, 1, 2]**2 + 1e-8\n    a = mu * (input_ray[:, 1, 0] * normal[:, 1, 0] + input_ray[:, 1, 1] * normal[:, 1, 1] + input_ray[:, 1, 2] * normal[:, 1, 2]) / div\n    a = a.unsqueeze(1)\n    n = int(torch.amax(torch.tensor([normal.shape[0], input_ray.shape[0]])))\n    output_ray = torch.zeros((n, 2, 3)).to(input_ray.device)\n    output_ray[:, 0] = normal[:, 0]\n    output_ray[:, 1] = input_ray[:, 1] - 2 * a * normal[:, 1]\n    return output_ray\n
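
A short sketch chaining surface intersection and reflection (illustrative values):

import torch
from odak.learn.raytracing import create_ray_from_two_points, define_plane, intersect_w_surface, reflect

plane = define_plane(torch.tensor([0., 0., 10.]), angles = torch.tensor([0., 10., 0.]))
ray = create_ray_from_two_points(torch.tensor([0., 0., 0.]), torch.tensor([0., 0., 10.]))
normal, distance = intersect_w_surface(ray, plane)
reflected_ray = reflect(ray, normal)
# reflected_ray starts at the intersection point with the mirrored direction cosines.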
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.refract","title":"refract(vector, normvector, n1, n2, error=0.01)","text":"

Definition to refract an incoming ray. Uses the method described in G.H. Spencer and M.V.R.K. Murty, "General Ray-Tracing Procedure", 1961.

Parameters:

  • vector \u2013
             Incoming ray.\n         Expected size is [2, 3], [1, 2, 3] or [m, 2, 3].\n
  • normvector \u2013
             Normal vector.\n         Expected size is [2, 3], [1, 2, 3] or [m, 2, 3]].\n
  • n1 \u2013
             Refractive index of the incoming medium.\n
  • n2 \u2013
             Refractive index of the outgoing medium.\n
  • error \u2013
             Desired error.\n

Returns:

  • output ( tensor ) \u2013

    Refracted ray. Expected size is [1, 2, 3] or [m, 2, 3].

Source code in odak/learn/raytracing/boundary.py
def refract(vector, normvector, n1, n2, error = 0.01):\n    \"\"\"\n    Definition to refract an incoming ray.\n    Used method described in G.H. Spencer and M.V.R.K. Murty, \"General Ray-Tracing Procedure\", 1961.\n\n\n    Parameters\n    ----------\n    vector         : torch.tensor\n                     Incoming ray.\n                     Expected size is [2, 3], [1, 2, 3] or [m, 2, 3].\n    normvector     : torch.tensor\n                     Normal vector.\n                     Expected size is [2, 3], [1, 2, 3] or [m, 2, 3]].\n    n1             : float\n                     Refractive index of the incoming medium.\n    n2             : float\n                     Refractive index of the outgoing medium.\n    error          : float \n                     Desired error.\n\n    Returns\n    -------\n    output         : torch.tensor\n                     Refracted ray.\n                     Expected size is [1, 2, 3]\n    \"\"\"\n    if len(vector.shape) == 2:\n        vector = vector.unsqueeze(0)\n    if len(normvector.shape) == 2:\n        normvector = normvector.unsqueeze(0)\n    mu    = n1 / n2\n    div   = normvector[:, 1, 0] ** 2  + normvector[:, 1, 1] ** 2 + normvector[:, 1, 2] ** 2\n    a     = mu * (vector[:, 1, 0] * normvector[:, 1, 0] + vector[:, 1, 1] * normvector[:, 1, 1] + vector[:, 1, 2] * normvector[:, 1, 2]) / div\n    b     = (mu ** 2 - 1) / div\n    to    = - b * 0.5 / a\n    num   = 0\n    eps   = torch.ones(vector.shape[0], device = vector.device) * error * 2\n    while len(eps[eps > error]) > 0:\n       num   += 1\n       oldto  = to\n       v      = to ** 2 + 2 * a * to + b\n       deltav = 2 * (to + a)\n       to     = to - v / deltav\n       eps    = abs(oldto - to)\n    output = torch.zeros_like(vector)\n    output[:, 0, 0] = normvector[:, 0, 0]\n    output[:, 0, 1] = normvector[:, 0, 1]\n    output[:, 0, 2] = normvector[:, 0, 2]\n    output[:, 1, 0] = mu * vector[:, 1, 0] + to * normvector[:, 1, 0]\n    output[:, 1, 1] = mu * vector[:, 1, 1] + to * normvector[:, 1, 1]\n    output[:, 1, 2] = mu * vector[:, 1, 2] + to * normvector[:, 1, 2]\n    return output\n
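
A sketch of refracting a ray at an air-glass boundary (illustrative refractive indices):

import torch
from odak.learn.raytracing import create_ray_from_two_points, define_plane, intersect_w_surface, refract

plane = define_plane(torch.tensor([0., 0., 10.]))
ray = create_ray_from_two_points(torch.tensor([0., 0., 0.]), torch.tensor([1., 0., 10.]))
normal, distance = intersect_w_surface(ray, plane)
refracted_ray = refract(ray, normal, n1 = 1.0, n2 = 1.51)   # from air into glass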
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.rotate_points","title":"rotate_points(point, angles=torch.tensor([[0, 0, 0]]), mode='XYZ', origin=torch.tensor([[0, 0, 0]]), offset=torch.tensor([[0, 0, 0]]))","text":"

Definition to rotate a given point. Note that rotation is always with respect to 0,0,0.

Parameters:

  • point \u2013
           A point with size of [3] or [1, 3] or [m, 3].\n
  • angles \u2013
           Rotation angles in degrees.\n
  • mode \u2013
            Rotation mode determines ordering of the rotations at each axis.\n       There are XYZ, XZY, YXZ, ZXY and ZYX modes.\n
  • origin \u2013
           Reference point for a rotation.\n       Expected size is [3] or [1, 3].\n
  • offset \u2013
           Shift with the given offset.\n       Expected size is [3] or [1, 3] or [m, 3].\n

Returns:

  • result ( tensor ) \u2013

    Result of the rotation [1 x 3] or [m x 3].

  • rotx ( tensor ) \u2013

    Rotation matrix along X axis [3 x 3].

  • roty ( tensor ) \u2013

    Rotation matrix along Y axis [3 x 3].

  • rotz ( tensor ) \u2013

    Rotation matrix along Z axis [3 x 3].

Source code in odak/learn/tools/transformation.py
def rotate_points(\n                 point,\n                 angles = torch.tensor([[0, 0, 0]]), \n                 mode='XYZ', \n                 origin = torch.tensor([[0, 0, 0]]), \n                 offset = torch.tensor([[0, 0, 0]])\n                ):\n    \"\"\"\n    Definition to rotate a given point. Note that rotation is always with respect to 0,0,0.\n\n    Parameters\n    ----------\n    point        : torch.tensor\n                   A point with size of [3] or [1, 3] or [m, 3].\n    angles       : torch.tensor\n                   Rotation angles in degrees. \n    mode         : str\n                   Rotation mode determines ordering of the rotations at each axis.\n                   There are XYZ,YXZ,ZXY and ZYX modes.\n    origin       : torch.tensor\n                   Reference point for a rotation.\n                   Expected size is [3] or [1, 3].\n    offset       : torch.tensor\n                   Shift with the given offset.\n                   Expected size is [3] or [1, 3] or [m, 3].\n\n    Returns\n    ----------\n    result       : torch.tensor\n                   Result of the rotation [1 x 3] or [m x 3].\n    rotx         : torch.tensor\n                   Rotation matrix along X axis [3 x 3].\n    roty         : torch.tensor\n                   Rotation matrix along Y axis [3 x 3].\n    rotz         : torch.tensor\n                   Rotation matrix along Z axis [3 x 3].\n    \"\"\"\n    origin = origin.to(point.device)\n    offset = offset.to(point.device)\n    if len(point.shape) == 1:\n        point = point.unsqueeze(0)\n    if len(angles.shape) == 1:\n        angles = angles.unsqueeze(0)\n    rotx = rotmatx(angles[:, 0])\n    roty = rotmaty(angles[:, 1])\n    rotz = rotmatz(angles[:, 2])\n    new_point = (point - origin).T\n    if mode == 'XYZ':\n        result = torch.mm(rotz, torch.mm(roty, torch.mm(rotx, new_point))).T\n    elif mode == 'XZY':\n        result = torch.mm(roty, torch.mm(rotz, torch.mm(rotx, new_point))).T\n    elif mode == 'YXZ':\n        result = torch.mm(rotz, torch.mm(rotx, torch.mm(roty, new_point))).T\n    elif mode == 'ZXY':\n        result = torch.mm(roty, torch.mm(rotx, torch.mm(rotz, new_point))).T\n    elif mode == 'ZYX':\n        result = torch.mm(rotx, torch.mm(roty, torch.mm(rotz, new_point))).T\n    result += origin\n    result += offset\n    return result, rotx, roty, rotz\n
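
A tiny sketch (assuming rotate_points is importable from odak.learn.tools, matching the source path above):

import torch
from odak.learn.tools import rotate_points

points = torch.tensor([[1., 0., 0.]])
rotated, rotx, roty, rotz = rotate_points(points, angles = torch.tensor([0., 0., 90.]))
# rotated is approximately [[0., 1., 0.]] after a 90 degree rotation about the Z axis.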
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.same_side","title":"same_side(p1, p2, a, b)","text":"

Definition to figure out which side of a line a point is on with respect to a reference point. See http://www.blackpawn.com/texts/pointinpoly/ for more. If p1 and p2 are on the same side, this definition returns True.

Parameters:

  • p1 \u2013
          Point(s) to check.\n
  • p2 \u2013
          This is the point check against.\n
  • a \u2013
          First point that forms the line.\n
  • b \u2013
          Second point that forms the line.\n
Source code in odak/learn/tools/vector.py
def same_side(p1, p2, a, b):\n    \"\"\"\n    Definition to figure which side a point is on with respect to a line and a point. See http://www.blackpawn.com/texts/pointinpoly/ for more. If p1 and p2 are on the sameside, this definition returns True.\n\n    Parameters\n    ----------\n    p1          : list\n                  Point(s) to check.\n    p2          : list\n                  This is the point check against.\n    a           : list\n                  First point that forms the line.\n    b           : list\n                  Second point that forms the line.\n    \"\"\"\n    ba = torch.subtract(b, a)\n    p1a = torch.subtract(p1, a)\n    p2a = torch.subtract(p2, a)\n    cp1 = torch.cross(ba, p1a)\n    cp2 = torch.cross(ba, p2a)\n    test = torch.dot(cp1, cp2)\n    if len(p1.shape) > 1:\n        return test >= 0\n    if test >= 0:\n        return True\n    return False\n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.save_torch_tensor","title":"save_torch_tensor(fn, tensor)","text":"

Definition to save a torch tensor.

Parameters:

  • fn \u2013
           Filename.\n
  • tensor \u2013
           Torch tensor to be saved.\n
Source code in odak/learn/tools/file.py
def save_torch_tensor(fn, tensor):\n    \"\"\"\n    Definition to save a torch tensor.\n\n\n    Parameters\n    ----------\n    fn           : str\n                   Filename.\n    tensor       : torch.tensor\n                   Torch tensor to be saved.\n    \"\"\" \n    torch.save(tensor, expanduser(fn))\n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.write_PLY","title":"write_PLY(triangles, savefn='output.ply')","text":"

Definition to generate a PLY file from given points.

Parameters:

  • triangles \u2013
          List of triangles with the size of Mx3x3.\n
  • savefn \u2013
          Filename for a PLY file.\n
Source code in odak/tools/asset.py
def write_PLY(triangles, savefn = 'output.ply'):\n    \"\"\"\n    Definition to generate a PLY file from given points.\n\n    Parameters\n    ----------\n    triangles   : ndarray\n                  List of triangles with the size of Mx3x3.\n    savefn      : string\n                  Filename for a PLY file.\n    \"\"\"\n    tris = []\n    pnts = []\n    color = [255, 255, 255]\n    for tri_id in range(triangles.shape[0]):\n        tris.append(\n            (\n                [3*tri_id, 3*tri_id+1, 3*tri_id+2],\n                color[0],\n                color[1],\n                color[2]\n            )\n        )\n        for i in range(0, 3):\n            pnts.append(\n                (\n                    float(triangles[tri_id][i][0]),\n                    float(triangles[tri_id][i][1]),\n                    float(triangles[tri_id][i][2])\n                )\n            )\n    tris = np.asarray(tris, dtype=[\n                          ('vertex_indices', 'i4', (3,)), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')])\n    pnts = np.asarray(pnts, dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])\n    # Save mesh.\n    el1 = PlyElement.describe(pnts, 'vertex', comments=['Vertex data'])\n    el2 = PlyElement.describe(tris, 'face', comments=['Face data'])\n    PlyData([el1, el2], text=\"True\").write(savefn)\n
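
A small sketch that exports a meshed plane to a PLY file (assuming write_PLY is importable from odak.tools, matching the source path above, and that the plyfile dependency is installed):

import torch
from odak.learn.raytracing import define_plane_mesh
from odak.tools import write_PLY

triangles = define_plane_mesh(number_of_meshes = [10, 10], size = [5., 5.])
write_PLY(triangles, savefn = 'plane_mesh.ply')   # writes vertices and faces of the triangulated plane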
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.boundary.get_sphere_normal_torch","title":"get_sphere_normal_torch(point, sphere)","text":"

Definition to get a normal of a point on a given sphere.

Parameters:

  • point \u2013
            Point on sphere in X,Y,Z.\n
  • sphere \u2013
            Center defined in X,Y,Z and radius.\n

Returns:

  • normal_vector ( tensor ) \u2013

    Normal vector.

Source code in odak/learn/raytracing/boundary.py
def get_sphere_normal_torch(point, sphere):\n    \"\"\"\n    Definition to get a normal of a point on a given sphere.\n\n    Parameters\n    ----------\n    point         : torch.tensor\n                    Point on sphere in X,Y,Z.\n    sphere        : torch.tensor\n                    Center defined in X,Y,Z and radius.\n\n    Returns\n    ----------\n    normal_vector : torch.tensor\n                    Normal vector.\n    \"\"\"\n    if len(point.shape) == 1:\n        point = point.reshape((1, 3))\n    normal_vector = create_ray_from_two_points(point, sphere[0:3])\n    return normal_vector\n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.boundary.get_triangle_normal","title":"get_triangle_normal(triangle, triangle_center=None)","text":"

Definition to calculate surface normal of a triangle.

Parameters:

  • triangle \u2013
              Set of points in X,Y and Z to define a planar surface (3,3). It can also be list of triangles (mx3x3).\n
  • triangle_center (tensor, default: None ) \u2013
              Center point of the given triangle. See odak.learn.raytracing.center_of_triangle for more. In many scenarios you can accelerate things by precomputing triangle centers.\n

Returns:

  • normal ( tensor ) \u2013

    Surface normal at the point of intersection.

Source code in odak/learn/raytracing/boundary.py
def get_triangle_normal(triangle, triangle_center=None):\n    \"\"\"\n    Definition to calculate surface normal of a triangle.\n\n    Parameters\n    ----------\n    triangle        : torch.tensor\n                      Set of points in X,Y and Z to define a planar surface (3,3). It can also be list of triangles (mx3x3).\n    triangle_center : torch.tensor\n                      Center point of the given triangle. See odak.learn.raytracing.center_of_triangle for more. In many scenarios you can accelerate things by precomputing triangle centers.\n\n    Returns\n    ----------\n    normal          : torch.tensor\n                      Surface normal at the point of intersection.\n    \"\"\"\n    if len(triangle.shape) == 2:\n        triangle = triangle.view((1, 3, 3))\n    normal = torch.zeros((triangle.shape[0], 2, 3)).to(triangle.device)\n    direction = torch.linalg.cross(\n                                   triangle[:, 0] - triangle[:, 1], \n                                   triangle[:, 2] - triangle[:, 1]\n                                  )\n    if type(triangle_center) == type(None):\n        normal[:, 0] = center_of_triangle(triangle)\n    else:\n        normal[:, 0] = triangle_center\n    normal[:, 1] = direction / torch.sum(direction, axis=1)[0]\n    if normal.shape[0] == 1:\n        normal = normal.view((2, 3))\n    return normal\n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.boundary.intersect_w_circle","title":"intersect_w_circle(ray, circle)","text":"

Definition to find intersection point of a ray with a circle. Returns distance as zero if there isn't an intersection.

Parameters:

  • ray \u2013
           A vector/ray.\n
  • circle \u2013
           A list that contains (0) Set of points in X,Y and Z to define plane of a circle, (1) circle center, and (2) circle radius.\n

Returns:

  • normal ( Tensor ) \u2013

    Surface normal at the point of intersection.

  • distance ( Tensor ) \u2013

    Distance between the starting point of a ray and the intersection point with the given circle.

Source code in odak/learn/raytracing/boundary.py
def intersect_w_circle(ray, circle):\n    \"\"\"\n    Definition to find intersection point of a ray with a circle. \n    Returns distance as zero if there isn't an intersection.\n\n    Parameters\n    ----------\n    ray          : torch.Tensor\n                   A vector/ray.\n    circle       : list\n                   A list that contains (0) Set of points in X,Y and Z to define plane of a circle, (1) circle center, and (2) circle radius.\n\n    Returns\n    ----------\n    normal       : torch.Tensor\n                   Surface normal at the point of intersection.\n    distance     : torch.Tensor\n                   Distance in between a starting point of a ray and the intersection point with a given triangle.\n    \"\"\"\n    normal, distance = intersect_w_surface(ray, circle[0])\n\n    if len(normal.shape) == 2:\n        normal = normal.unsqueeze(0)\n\n    distance_to_center = distance_between_two_points(normal[:, 0], circle[1])\n    mask = distance_to_center > circle[2]\n    distance[mask] = 0\n\n    if len(ray.shape) == 2:\n        normal = normal.squeeze(0)\n\n    return normal, distance\n
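A short sketch pairing define_circle (documented further below) with this routine; as in the other sketches on this page, the import paths are assumed from the source files listed here:

import torch
from odak.learn.raytracing.primitives import define_circle  # assumed import paths
from odak.learn.raytracing.boundary import intersect_w_circle

# A circle of radius 2 in the z = 0 plane and a ray travelling along +z towards it.
circle = define_circle(
                       center = torch.tensor([0., 0., 0.]),
                       radius = 2.,
                       angles = torch.tensor([0., 0., 0.])
                      )
ray = torch.tensor([[[0., 0., -5.], [0., 0., 1.]]])  # [start point, direction cosines]
normal, distance = intersect_w_circle(ray, circle)
# distance is zero wherever a ray misses the circle.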
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.boundary.intersect_w_sphere","title":"intersect_w_sphere(ray, sphere, learning_rate=0.2, number_of_steps=5000, error_threshold=0.01)","text":"

Definition to find the intersection between ray(s) and sphere(s).

Parameters:

  • ray \u2013
                  Input ray(s).\n              Expected size is [1 x 2 x 3] or [m x 2 x 3].\n
  • sphere \u2013
                  Input sphere.\n              Expected size is [1 x 4].\n
  • learning_rate \u2013
                  Learning rate used in the optimizer for finding the propagation distances of the rays.\n
  • number_of_steps \u2013
                  Number of steps used in the optimizer.\n
  • error_threshold \u2013
                   The error threshold used to decide whether an intersection occurred.\n

Returns:

  • intersecting_ray ( tensor ) \u2013

    Ray(s) intersecting with the given sphere. Expected size is [n x 2 x 3], where n is the number of rays found to intersect the sphere.

  • intersecting_normal ( tensor ) \u2013

    Normal(s) for the ray(s) intersecting with the given sphere. Expected size is [n x 2 x 3], where n is the number of rays found to intersect the sphere.

  • distance ( tensor ) \u2013

    Propagation distances optimized for all input rays. Expected size is [m].

  • check ( tensor ) \u2013

    Boolean tensor of size [m] indicating which rays intersect the sphere within the given error threshold.

Source code in odak/learn/raytracing/boundary.py
def intersect_w_sphere(ray, sphere, learning_rate = 2e-1, number_of_steps = 5000, error_threshold = 1e-2):\n    \"\"\"\n    Definition to find the intersection between ray(s) and sphere(s).\n\n    Parameters\n    ----------\n    ray                 : torch.tensor\n                          Input ray(s).\n                          Expected size is [1 x 2 x 3] or [m x 2 x 3].\n    sphere              : torch.tensor\n                          Input sphere.\n                          Expected size is [1 x 4].\n    learning_rate       : float\n                          Learning rate used in the optimizer for finding the propagation distances of the rays.\n    number_of_steps     : int\n                          Number of steps used in the optimizer.\n    error_threshold     : float\n                          The error threshold that will help deciding intersection or no intersection.\n\n    Returns\n    -------\n    intersecting_ray    : torch.tensor\n                          Ray(s) that intersecting with the given sphere.\n                          Expected size is [n x 2 x 3], where n could be any real number.\n    intersecting_normal : torch.tensor\n                          Normal(s) for the ray(s) intersecting with the given sphere\n                          Expected size is [n x 2 x 3], where n could be any real number.\n\n    \"\"\"\n    if len(ray.shape) == 2:\n        ray = ray.unsqueeze(0)\n    if len(sphere.shape) == 1:\n        sphere = sphere.unsqueeze(0)\n    distance = torch.zeros(ray.shape[0], device = ray.device, requires_grad = True)\n    loss_l2 = torch.nn.MSELoss(reduction = 'sum')\n    optimizer = torch.optim.AdamW([distance], lr = learning_rate)    \n    t = tqdm(range(number_of_steps), leave = False, dynamic_ncols = True)\n    for step in t:\n        optimizer.zero_grad()\n        propagated_ray = propagate_ray(ray, distance)\n        test = torch.abs((propagated_ray[:, 0, 0] - sphere[:, 0]) ** 2 + (propagated_ray[:, 0, 1] - sphere[:, 1]) ** 2 + (propagated_ray[:, 0, 2] - sphere[:, 2]) ** 2 - sphere[:, 3] ** 2)\n        loss = loss_l2(\n                       test,\n                       torch.zeros_like(test)\n                      )\n        loss.backward(retain_graph = True)\n        optimizer.step()\n        t.set_description('Sphere intersection loss: {}'.format(loss.item()))\n    check = test < error_threshold\n    intersecting_ray = propagate_ray(ray[check == True], distance[check == True])\n    intersecting_normal = create_ray_from_two_points(\n                                                     sphere[:, 0:3],\n                                                     intersecting_ray[:, 0]\n                                                    )\n    return intersecting_ray, intersecting_normal, distance, check\n
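Since this routine solves for the propagation distances with an optimizer, it is comparatively slow and its outputs depend on the chosen learning rate, step count and error threshold. A minimal sketch, with import paths assumed from the source files listed on this page:

import torch
from odak.learn.raytracing.primitives import define_sphere  # assumed import paths
from odak.learn.raytracing.boundary import intersect_w_sphere

# A unit sphere centered at z = 5 and a single ray travelling along +z.
sphere = define_sphere(
                       center = torch.tensor([[0., 0., 5.]]),
                       radius = torch.tensor([1.])
                      )
rays = torch.tensor([[[0., 0., 0.], [0., 0., 1.]]])
intersecting_ray, intersecting_normal, distance, check = intersect_w_sphere(rays, sphere)
# check flags the rays whose optimized distance satisfies the sphere equation within error_threshold.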
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.boundary.intersect_w_surface","title":"intersect_w_surface(ray, points)","text":"

Definition to find the intersection point between a ray and a planar surface. For more see: http://geomalgorithms.com/a06-_intersect-2.html

Parameters:

  • ray \u2013
           A vector/ray.\n
  • points \u2013
           Set of points in X,Y and Z to define a planar surface.\n

Returns:

  • normal ( tensor ) \u2013

    Surface normal at the point of intersection.

  • distance ( float ) \u2013

    Distance between the starting point of a ray and its intersection with the planar surface.

Source code in odak/learn/raytracing/boundary.py
def intersect_w_surface(ray, points):\n    \"\"\"\n    Definition to find intersection point inbetween a surface and a ray. For more see: http://geomalgorithms.com/a06-_intersect-2.html\n\n    Parameters\n    ----------\n    ray          : torch.tensor\n                   A vector/ray.\n    points       : torch.tensor\n                   Set of points in X,Y and Z to define a planar surface.\n\n    Returns\n    ----------\n    normal       : torch.tensor\n                   Surface normal at the point of intersection.\n    distance     : float\n                   Distance in between starting point of a ray with it's intersection with a planar surface.\n    \"\"\"\n    normal = get_triangle_normal(points)\n    if len(ray.shape) == 2:\n        ray = ray.unsqueeze(0)\n    if len(points.shape) == 2:\n        points = points.unsqueeze(0)\n    if len(normal.shape) == 2:\n        normal = normal.unsqueeze(0)\n    f = normal[:, 0] - ray[:, 0]\n    distance = (torch.mm(normal[:, 1], f.T) / torch.mm(normal[:, 1], ray[:, 1].T)).T\n    new_normal = torch.zeros_like(ray)\n    new_normal[:, 0] = ray[:, 0] + distance * ray[:, 1]\n    new_normal[:, 1] = normal[:, 1]\n    new_normal = torch.nan_to_num(\n                                  new_normal,\n                                  nan = float('nan'),\n                                  posinf = float('nan'),\n                                  neginf = float('nan')\n                                 )\n    distance = torch.nan_to_num(\n                                distance,\n                                nan = float('nan'),\n                                posinf = float('nan'),\n                                neginf = float('nan')\n                               )\n    return new_normal, distance\n
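A minimal sketch pairing define_plane (documented further below) with this routine, with the same assumed import paths:

import torch
from odak.learn.raytracing.primitives import define_plane  # assumed import paths
from odak.learn.raytracing.boundary import intersect_w_surface

# A plane centered at the origin and a ray starting 10 units below it, travelling along +z.
plane = define_plane(point = torch.tensor([0., 0., 0.]))
ray = torch.tensor([[[0., 0., -10.], [0., 0., 1.]]])
normal, distance = intersect_w_surface(ray, plane)
# normal[:, 0] is the intersection point; distance evaluates to 10 for this geometry.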
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.boundary.intersect_w_surface_batch","title":"intersect_w_surface_batch(ray, triangle)","text":"

Parameters:

  • ray \u2013
           A vector/ray (2 x 3). It can also be a list of rays (n x 2 x 3).\n
  • triangle \u2013
           Set of points in X,Y and Z to define a planar surface. It can also be a list of triangles (m x 3 x 3).\n

Returns:

  • normal ( tensor ) \u2013

    Surface normal at the point of intersection (m x n x 2 x 3).

  • distance ( tensor ) \u2013

    Distance between the starting point of a ray and its intersection with the planar surface (m x n).

Source code in odak/learn/raytracing/boundary.py
def intersect_w_surface_batch(ray, triangle):\n    \"\"\"\n    Parameters\n    ----------\n    ray          : torch.tensor\n                   A vector/ray (2 x 3). It can also be a list of rays (n x 2 x 3).\n    triangle     : torch.tensor\n                   Set of points in X,Y and Z to define a planar surface. It can also be a list of triangles (m x 3 x 3).\n\n    Returns\n    ----------\n    normal       : torch.tensor\n                   Surface normal at the point of intersection (m x n x 2 x 3).\n    distance     : torch.tensor\n                   Distance in between starting point of a ray with it's intersection with a planar surface (m x n).\n    \"\"\"\n    normal = get_triangle_normal(triangle)\n    if len(ray.shape) == 2:\n        ray = ray.unsqueeze(0)\n    if len(triangle.shape) == 2:\n        triangle = triangle.unsqueeze(0)\n    if len(normal.shape) == 2:\n        normal = normal.unsqueeze(0)\n\n    f = normal[:, None, 0] - ray[None, :, 0]\n    distance = (torch.bmm(normal[:, None, 1], f.permute(0, 2, 1)).squeeze(1) / torch.mm(normal[:, 1], ray[:, 1].T)).T\n\n    new_normal = torch.zeros((triangle.shape[0], )+ray.shape)\n    new_normal[:, :, 0] = ray[None, :, 0] + (distance[:, :, None] * ray[:, None, 1]).permute(1, 0, 2)\n    new_normal[:, :, 1] = normal[:, None, 1]\n    new_normal = torch.nan_to_num(\n                                  new_normal,\n                                  nan = float('nan'),\n                                  posinf = float('nan'),\n                                  neginf = float('nan')\n                                 )\n    distance = torch.nan_to_num(\n                                distance,\n                                nan = float('nan'),\n                                posinf = float('nan'),\n                                neginf = float('nan')\n                               )\n    return new_normal, distance.T\n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.boundary.intersect_w_triangle","title":"intersect_w_triangle(ray, triangle)","text":"

Definition to find intersection point of a ray with a triangle.

Parameters:

  • ray \u2013
                   A ray [1 x 2 x 3] or a batch of rays [m x 2 x 3].\n
  • triangle \u2013
                  Set of points in X,Y and Z to define a single triangle [1 x 3 x 3].\n

Returns:

  • normal ( tensor ) \u2013

    Surface normal at the point of intersection with the surface of the triangle. This could also include surface normals whose intersection points are not on the triangle. Expected size is [1 x 2 x 3] or [m x 2 x 3] depending on the input.

  • distance ( float ) \u2013

    Distance between the starting point of a ray and the intersection point with the given triangle. Expected size is [1 x 1] or [m x 1] depending on the input.

  • intersecting_ray ( tensor ) \u2013

    Rays that intersect with the triangle plane and on the triangle. Expected size is [1 x 2 x 3] or [m x 2 x 3] depending on the input.

  • intersecting_normal ( tensor ) \u2013

    Normals that intersect with the triangle plane and on the triangle. Expected size is [1 x 2 x 3] or [m x 2 x 3] depending on the input.

  • check ( tensor ) \u2013

    A boolean for each input ray, True when the ray's intersection point lies on the given triangle. Expected size is [1] or [m].

Source code in odak/learn/raytracing/boundary.py
def intersect_w_triangle(ray, triangle):\n    \"\"\"\n    Definition to find intersection point of a ray with a triangle. \n\n    Parameters\n    ----------\n    ray                 : torch.tensor\n                          A ray [1 x 2 x 3] or a batch of ray [m x 2 x 3].\n    triangle            : torch.tensor\n                          Set of points in X,Y and Z to define a single triangle [1 x 3 x 3].\n\n    Returns\n    ----------\n    normal              : torch.tensor\n                          Surface normal at the point of intersection with the surface of triangle.\n                          This could also involve surface normals that are not on the triangle.\n                          Expected size is [1 x 2 x 3] or [m x 2 x 3] depending on the input.\n    distance            : float\n                          Distance in between a starting point of a ray and the intersection point with a given triangle.\n                          Expected size is [1 x 1] or [m x 1] depending on the input.\n    intersecting_ray    : torch.tensor\n                          Rays that intersect with the triangle plane and on the triangle.\n                          Expected size is [1 x 2 x 3] or [m x 2 x 3] depending on the input.\n    intersecting_normal : torch.tensor\n                          Normals that intersect with the triangle plane and on the triangle.\n                          Expected size is [1 x 2 x 3] or [m x 2 x 3] depending on the input.\n    check               : torch.tensor\n                          A list that provides a bool as True or False for each ray used as input.\n                          A test to see is a ray could be on the given triangle.\n                          Expected size is [1] or [m].\n    \"\"\"\n    if len(triangle.shape) == 2:\n       triangle = triangle.unsqueeze(0)\n    if len(ray.shape) == 2:\n       ray = ray.unsqueeze(0)\n    normal, distance = intersect_w_surface(ray, triangle)\n    check = is_it_on_triangle(normal[:, 0], triangle)\n    intersecting_ray = ray.unsqueeze(0)\n    intersecting_ray = intersecting_ray.repeat(triangle.shape[0], 1, 1, 1)\n    intersecting_ray = intersecting_ray[check == True]\n    intersecting_normal = normal.unsqueeze(0)\n    intersecting_normal = intersecting_normal.repeat(triangle.shape[0], 1, 1, 1)\n    intersecting_normal = intersecting_normal[check ==  True]\n    return normal, distance, intersecting_ray, intersecting_normal, check\n
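A short sketch for a single ray and a single triangle, again assuming the import path from the source file above:

import torch
from odak.learn.raytracing.boundary import intersect_w_triangle  # assumed import path

triangle = torch.tensor([[-1., -1., 0.], [1., -1., 0.], [0., 1., 0.]])
ray = torch.tensor([[[0., 0., -5.], [0., 0., 1.]]])  # [start point, direction cosines]
normal, distance, intersecting_ray, intersecting_normal, check = intersect_w_triangle(ray, triangle)
# check is True here and the intersection point, normal[:, 0], is the origin.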
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.boundary.intersect_w_triangle_batch","title":"intersect_w_triangle_batch(ray, triangle)","text":"

Definition to find intersection points of rays with triangles. Returns False for each variable if the rays don't intersect with the given triangles.

Parameters:

  • ray \u2013
           vectors/rays (n x 2 x 3).\n
  • triangle \u2013
           Set of points in X,Y and Z to define triangles (m x 3 x 3).\n

Returns:

  • normal ( tensor ) \u2013

    Surface normal at the point of intersection (m x n x 2 x 3).

  • distance ( List ) \u2013

    Distance between the starting point of a ray and its intersection with the planar surface (m x n).

  • intersect_ray ( List ) \u2013

    List of intersecting rays (k x 2 x 3) where k <= n.

  • intersect_normal ( List ) \u2013

    List of intersecting normals (k x 2 x 3) where k <= n*m.

  • check ( tensor ) \u2013

    Boolean tensor (m x n) indicating whether each ray intersects with a triangle or not.

Source code in odak/learn/raytracing/boundary.py
def intersect_w_triangle_batch(ray, triangle):\n    \"\"\"\n    Definition to find intersection points of rays with triangles. Returns False for each variable if the rays doesn't intersect with given triangles.\n\n    Parameters\n    ----------\n    ray          : torch.tensor\n                   vectors/rays (n x 2 x 3).\n    triangle     : torch.tensor\n                   Set of points in X,Y and Z to define triangles (m x 3 x 3).\n\n    Returns\n    ----------\n    normal          : torch.tensor\n                      Surface normal at the point of intersection (m x n x 2 x 3).\n    distance        : List\n                      Distance in between starting point of a ray with it's intersection with a planar surface (m x n).\n    intersect_ray   : List\n                      List of intersecting rays (k x 2 x 3) where k <= n.\n    intersect_normal: List\n                      List of intersecting normals (k x 2 x 3) where k <= n*m.\n    check           : torch.tensor\n                      Boolean tensor (m x n) indicating whether each ray intersects with a triangle or not.\n    \"\"\"\n    if len(triangle.shape) == 2:\n       triangle = triangle.unsqueeze(0)\n    if len(ray.shape) == 2:\n       ray = ray.unsqueeze(0)\n\n    normal, distance = intersect_w_surface_batch(ray, triangle)\n\n    check = is_it_on_triangle_batch(normal[:, :, 0], triangle)\n\n    flat_check = check.flatten()\n    flat_normal = normal.view(-1, normal.size(-2), normal.size(-1))\n    flat_ray = ray.repeat(normal.size(0), 1, 1)\n    flat_distance = distance.flatten()\n\n    filtered_normal = torch.masked_select(flat_normal, flat_check.unsqueeze(-1).unsqueeze(-1).repeat(1, 2, 3))\n    filtered_ray = torch.masked_select(flat_ray, flat_check.unsqueeze(-1).unsqueeze(-1).repeat(1, 2, 3))\n    filtered_distnace = torch.masked_select(flat_distance, flat_check)\n\n    check_count = check.sum(dim=1).tolist()\n    split_size_ray_and_normal = [count * 2 * 3 for count in check_count]\n    split_size_distance = [count for count in check_count]\n\n    normal_grouped = torch.split(filtered_normal, split_size_ray_and_normal)\n    ray_grouped = torch.split(filtered_ray, split_size_ray_and_normal)\n    distance_grouped = torch.split(filtered_distnace, split_size_distance)\n\n    intersecting_normal = [g.view(-1, 2, 3) for g in normal_grouped if g.numel() > 0]\n    intersecting_ray = [g.view(-1, 2, 3) for g in ray_grouped if g.numel() > 0]\n    new_distance = [g for g in distance_grouped if g.numel() > 0]\n\n    return normal, new_distance, intersecting_ray, intersecting_normal, check\n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.boundary.reflect","title":"reflect(input_ray, normal)","text":"

Definition to reflect an incoming ray from a surface defined by a surface normal. Uses the method described in G.H. Spencer and M.V.R.K. Murty, \"General Ray-Tracing Procedure\", 1961.

Parameters:

  • input_ray \u2013
           A ray or rays.\n       Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].\n
  • normal \u2013
           A surface normal(s).\n       Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].\n

Returns:

  • output_ray ( tensor ) \u2013

    Array that contains starting points and cosines of a reflected ray. Expected size is [1 x 2 x 3] or [m x 2 x 3].

Source code in odak/learn/raytracing/boundary.py
def reflect(input_ray, normal):\n    \"\"\" \n    Definition to reflect an incoming ray from a surface defined by a surface normal. \n    Used method described in G.H. Spencer and M.V.R.K. Murty, \"General Ray-Tracing Procedure\", 1961.\n\n\n    Parameters\n    ----------\n    input_ray    : torch.tensor\n                   A ray or rays.\n                   Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].\n    normal       : torch.tensor\n                   A surface normal(s).\n                   Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].\n\n    Returns\n    ----------\n    output_ray   : torch.tensor\n                   Array that contains starting points and cosines of a reflected ray.\n                   Expected size is [1 x 2 x 3] or [m x 2 x 3].\n    \"\"\"\n    if len(input_ray.shape) == 2:\n        input_ray = input_ray.unsqueeze(0)\n    if len(normal.shape) == 2:\n        normal = normal.unsqueeze(0)\n    mu = 1\n    div = normal[:, 1, 0]**2 + normal[:, 1, 1]**2 + normal[:, 1, 2]**2 + 1e-8\n    a = mu * (input_ray[:, 1, 0] * normal[:, 1, 0] + input_ray[:, 1, 1] * normal[:, 1, 1] + input_ray[:, 1, 2] * normal[:, 1, 2]) / div\n    a = a.unsqueeze(1)\n    n = int(torch.amax(torch.tensor([normal.shape[0], input_ray.shape[0]])))\n    output_ray = torch.zeros((n, 2, 3)).to(input_ray.device)\n    output_ray[:, 0] = normal[:, 0]\n    output_ray[:, 1] = input_ray[:, 1] - 2 * a * normal[:, 1]\n    return output_ray\n
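A sketch reflecting one ray off a mirror whose normal points along +z; the normal follows the same [surface point, direction cosines] convention as the rays (import path assumed as before):

import torch
from odak.learn.raytracing.boundary import reflect  # assumed import path

input_ray = torch.tensor([[[0., 0., 1.], [0.7071, 0., -0.7071]]])  # 45 degree incidence
normal = torch.tensor([[[0., 0., 0.], [0., 0., 1.]]])              # [surface point, normal direction]
output_ray = reflect(input_ray, normal)
# output_ray[:, 1] is approximately [0.7071, 0., 0.7071], mirrored about the surface normal.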
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.boundary.refract","title":"refract(vector, normvector, n1, n2, error=0.01)","text":"

Definition to refract an incoming ray. Uses the method described in G.H. Spencer and M.V.R.K. Murty, \"General Ray-Tracing Procedure\", 1961.

Parameters:

  • vector \u2013
             Incoming ray.\n         Expected size is [2, 3], [1, 2, 3] or [m, 2, 3].\n
  • normvector \u2013
              Normal vector.\n         Expected size is [2, 3], [1, 2, 3] or [m, 2, 3].\n
  • n1 \u2013
             Refractive index of the incoming medium.\n
  • n2 \u2013
             Refractive index of the outgoing medium.\n
  • error \u2013
             Desired error.\n

Returns:

  • output ( tensor ) \u2013

    Refracted ray. Expected size is [1, 2, 3] or [m, 2, 3] depending on the input.

Source code in odak/learn/raytracing/boundary.py
def refract(vector, normvector, n1, n2, error = 0.01):\n    \"\"\"\n    Definition to refract an incoming ray.\n    Used method described in G.H. Spencer and M.V.R.K. Murty, \"General Ray-Tracing Procedure\", 1961.\n\n\n    Parameters\n    ----------\n    vector         : torch.tensor\n                     Incoming ray.\n                     Expected size is [2, 3], [1, 2, 3] or [m, 2, 3].\n    normvector     : torch.tensor\n                     Normal vector.\n                     Expected size is [2, 3], [1, 2, 3] or [m, 2, 3]].\n    n1             : float\n                     Refractive index of the incoming medium.\n    n2             : float\n                     Refractive index of the outgoing medium.\n    error          : float \n                     Desired error.\n\n    Returns\n    -------\n    output         : torch.tensor\n                     Refracted ray.\n                     Expected size is [1, 2, 3]\n    \"\"\"\n    if len(vector.shape) == 2:\n        vector = vector.unsqueeze(0)\n    if len(normvector.shape) == 2:\n        normvector = normvector.unsqueeze(0)\n    mu    = n1 / n2\n    div   = normvector[:, 1, 0] ** 2  + normvector[:, 1, 1] ** 2 + normvector[:, 1, 2] ** 2\n    a     = mu * (vector[:, 1, 0] * normvector[:, 1, 0] + vector[:, 1, 1] * normvector[:, 1, 1] + vector[:, 1, 2] * normvector[:, 1, 2]) / div\n    b     = (mu ** 2 - 1) / div\n    to    = - b * 0.5 / a\n    num   = 0\n    eps   = torch.ones(vector.shape[0], device = vector.device) * error * 2\n    while len(eps[eps > error]) > 0:\n       num   += 1\n       oldto  = to\n       v      = to ** 2 + 2 * a * to + b\n       deltav = 2 * (to + a)\n       to     = to - v / deltav\n       eps    = abs(oldto - to)\n    output = torch.zeros_like(vector)\n    output[:, 0, 0] = normvector[:, 0, 0]\n    output[:, 0, 1] = normvector[:, 0, 1]\n    output[:, 0, 2] = normvector[:, 0, 2]\n    output[:, 1, 0] = mu * vector[:, 1, 0] + to * normvector[:, 1, 0]\n    output[:, 1, 1] = mu * vector[:, 1, 1] + to * normvector[:, 1, 1]\n    output[:, 1, 2] = mu * vector[:, 1, 2] + to * normvector[:, 1, 2]\n    return output\n
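A sketch refracting a ray from air (n1 = 1.0) into glass (n2 = 1.5) at a surface whose normal points along +z (import path assumed as before):

import torch
from odak.learn.raytracing.boundary import refract  # assumed import path

vector = torch.tensor([[[0., 0., 1.], [0.5, 0., -0.8660]]])  # incident at 30 degrees
normvector = torch.tensor([[[0., 0., 0.], [0., 0., 1.]]])
output = refract(vector, normvector, n1 = 1.0, n2 = 1.5)
# output[:, 1] bends towards the normal, consistent with Snell's law (about 19.5 degrees here).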
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.detector.__init__","title":"__init__(colors=3, center=torch.tensor([0.0, 0.0, 0.0]), tilt=torch.tensor([0.0, 0.0, 0.0]), size=torch.tensor([10.0, 10.0]), resolution=torch.tensor([100, 100]), device=torch.device('cpu'))","text":"

Parameters:

  • colors \u2013
             Number of color channels to register (e.g., RGB).\n
  • center \u2013
             Center point of the detector [3].\n
  • tilt \u2013
             Tilt angles of the surface in degrees [3].\n
  • size \u2013
             Size of the detector [2].\n
  • resolution \u2013
             Resolution of the detector.\n
  • device \u2013
             Device for computation (e.g., cuda, cpu).\n
Source code in odak/learn/raytracing/detector.py
def __init__(\n             self,\n             colors = 3,\n             center = torch.tensor([0., 0., 0.]),\n             tilt = torch.tensor([0., 0., 0.]),\n             size = torch.tensor([10., 10.]),\n             resolution = torch.tensor([100, 100]),\n             device = torch.device('cpu')\n            ):\n    \"\"\"\n    Parameters\n    ----------\n    colors         : int\n                     Number of color channels to register (e.g., RGB).\n    center         : torch.tensor\n                     Center point of the detector [3].\n    tilt           : torch.tensor\n                     Tilt angles of the surface in degrees [3].\n    size           : torch.tensor\n                     Size of the detector [2].\n    resolution     : torch.tensor\n                     Resolution of the detector.\n    device         : torch.device\n                     Device for computation (e.g., cuda, cpu).\n    \"\"\"\n    self.device = device\n    self.colors = colors\n    self.resolution = resolution.to(self.device)\n    self.surface_center = center.to(self.device)\n    self.surface_tilt = tilt.to(self.device)\n    self.size = size.to(self.device)\n    self.pixel_size = torch.tensor([\n                                    self.size[0] / self.resolution[0],\n                                    self.size[1] / self.resolution[1]\n                                   ], device  = self.device)\n    self.pixel_diagonal_size = torch.sqrt(self.pixel_size[0] ** 2 + self.pixel_size[1] ** 2)\n    self.pixel_diagonal_half_size = self.pixel_diagonal_size / 2.\n    self.threshold = torch.nn.Threshold(self.pixel_diagonal_size, 1)\n    self.plane = define_plane(\n                              point = self.surface_center,\n                              angles = self.surface_tilt\n                             )\n    self.pixel_locations, _, _, _ = grid_sample(\n                                                size = self.size.tolist(),\n                                                no = self.resolution.tolist(),\n                                                center = self.surface_center.tolist(),\n                                                angles = self.surface_tilt.tolist()\n                                               )\n    self.pixel_locations = self.pixel_locations.to(self.device)\n    self.relu = torch.nn.ReLU()\n    self.clear()\n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.detector.clear","title":"clear()","text":"

Internal function to clear a detector.

Source code in odak/learn/raytracing/detector.py
def clear(self):\n    \"\"\"\n    Internal function to clear a detector.\n    \"\"\"\n    self.image = torch.zeros(\n\n                             self.colors,\n                             self.resolution[0],\n                             self.resolution[1],\n                             device = self.device,\n                            )\n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.detector.get_image","title":"get_image()","text":"

Function to return the detector image.

Returns:

  • image ( tensor ) \u2013

    Detector image.

Source code in odak/learn/raytracing/detector.py
def get_image(self):\n    \"\"\"\n    Function to return the detector image.\n\n    Returns\n    -------\n    image           : torch.tensor\n                      Detector image.\n    \"\"\"\n    image = (self.image - self.image.min()) / (self.image.max() - self.image.min())\n    return image\n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.detector.intersect","title":"intersect(rays, color=0)","text":"

Function to intersect rays with the detector.

Parameters:

  • rays \u2013
              Rays to be intersected with a detector.\n          Expected size is [1 x 2 x 3] or [m x 2 x 3].\n
  • color \u2013
              Color channel to register.\n

Returns:

  • points ( tensor ) \u2013

    Intersection points with the image detector [k x 3].

Source code in odak/learn/raytracing/detector.py
    def intersect(self, rays, color = 0):\n        \"\"\"\n        Function to intersect rays with the detector\n\n\n        Parameters\n        ----------\n        rays            : torch.tensor\n                          Rays to be intersected with a detector.\n                          Expected size is [1 x 2 x 3] or [m x 2 x 3].\n        color           : int\n                          Color channel to register.\n\n        Returns\n        -------\n        points          : torch.tensor\n                          Intersection points with the image detector [k x 3].\n        \"\"\"\n        normals, _ = intersect_w_surface(rays, self.plane)\n        points = normals[:, 0]\n        distances_xyz = torch.abs(points.unsqueeze(1) - self.pixel_locations.unsqueeze(0))\n        distances_x = 1e6 * self.relu( - (distances_xyz[:, :, 0] - self.pixel_size[0]))\n        distances_y = 1e6 * self.relu( - (distances_xyz[:, :, 1] - self.pixel_size[1]))\n        hit_x = torch.clamp(distances_x, min = 0., max = 1.)\n        hit_y = torch.clamp(distances_y, min = 0., max = 1.)\n        hit = hit_x * hit_y\n        image = torch.sum(hit, dim = 0)\n        self.image[color] += image.reshape(\n                                           self.image.shape[-2], \n                                           self.image.shape[-1]\n                                          )\n        distances = torch.sum((points.unsqueeze(1) - self.pixel_locations.unsqueeze(0)) ** 2, dim = 2)\n        distance_image = distances\n#        distance_image = distances.reshape(\n#                                           -1,\n#                                           self.image.shape[-2],\n#                                           self.image.shape[-1]\n#                                          )\n        return points, image, distance_image\n
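A sketch tracing two rays onto a detector and reading the accumulated image back; it assumes the class is exposed as odak.learn.raytracing.detector, as the anchors in this reference suggest:

import torch
from odak.learn.raytracing import detector  # assumed export path

sensor = detector(
                  colors = 1,
                  center = torch.tensor([0., 0., 10.]),
                  size = torch.tensor([10., 10.]),
                  resolution = torch.tensor([100, 100])
                 )
rays = torch.tensor([
                     [[0., 0., 0.], [0., 0., 1.]],
                     [[1., 1., 0.], [0., 0., 1.]]
                    ])
points, image, distance_image = sensor.intersect(rays, color = 0)
# get_image() returns the accumulated detector image, min-max normalized.
normalized_image = sensor.get_image()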
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.mesh.planar_mesh","title":"planar_mesh","text":"Source code in odak/learn/raytracing/mesh.py
class planar_mesh():\n\n\n    def __init__(\n                 self,\n                 size = [1., 1.],\n                 number_of_meshes = [10, 10],\n                 angles = torch.tensor([0., 0., 0.]),\n                 offset = torch.tensor([0., 0., 0.]),\n                 device = torch.device('cpu'),\n                 heights = None\n                ):\n        \"\"\"\n        Definition to generate a plane with meshes.\n\n\n        Parameters\n        -----------\n        number_of_meshes  : torch.tensor\n                            Number of squares over plane.\n                            There are two triangles at each square.\n        size              : torch.tensor\n                            Size of the plane.\n        angles            : torch.tensor\n                            Rotation angles in degrees.\n        offset            : torch.tensor\n                            Offset along XYZ axes.\n                            Expected dimension is [1 x 3] or offset for each triangle [m x 3].\n                            m here refers to `2 * number_of_meshes[0]` times  `number_of_meshes[1]`.\n        device            : torch.device\n                            Computational resource to be used (e.g., cpu, cuda).\n        heights           : torch.tensor\n                            Load surface heights from a tensor.\n        \"\"\"\n        self.device = device\n        self.angles = angles.to(self.device)\n        self.offset = offset.to(self.device)\n        self.size = size.to(self.device)\n        self.number_of_meshes = number_of_meshes.to(self.device)\n        self.init_heights(heights)\n\n\n    def init_heights(self, heights = None):\n        \"\"\"\n        Internal function to initialize a height map.\n        Note that self.heights is a differentiable variable, and can be optimized or learned.\n        See unit test `test/test_learn_ray_detector.py` or `test/test_learn_ray_mesh.py` as examples.\n        \"\"\"\n        if not isinstance(heights, type(None)):\n            self.heights = heights.to(self.device)\n            self.heights.requires_grad = True\n        else:\n            self.heights = torch.zeros(\n                                       (self.number_of_meshes[0], self.number_of_meshes[1], 1),\n                                       requires_grad = True,\n                                       device = self.device,\n                                      )\n        x = torch.linspace(-self.size[0] / 2., self.size[0] / 2., self.number_of_meshes[0], device = self.device) \n        y = torch.linspace(-self.size[1] / 2., self.size[1] / 2., self.number_of_meshes[1], device = self.device)\n        X, Y = torch.meshgrid(x, y, indexing = 'ij')\n        self.X = X.unsqueeze(-1)\n        self.Y = Y.unsqueeze(-1)\n\n\n    def save_heights(self, filename = 'heights.pt'):\n        \"\"\"\n        Function to save heights to a file.\n\n        Parameters\n        ----------\n        filename          : str\n                            Filename.\n        \"\"\"\n        save_torch_tensor(filename, self.heights.detach().clone())\n\n\n    def save_heights_as_PLY(self, filename = 'mesh.ply'):\n        \"\"\"\n        Function to save mesh to a PLY file.\n\n        Parameters\n        ----------\n        filename          : str\n                            Filename.\n        \"\"\"\n        triangles = self.get_triangles()\n        write_PLY(triangles, filename)\n\n\n    def get_squares(self):\n        \"\"\"\n        Internal function to initiate squares over a 
plane.\n\n        Returns\n        -------\n        squares     : torch.tensor\n                      Squares over a plane.\n                      Expected size is [m x n x 3].\n        \"\"\"\n        squares = torch.cat((\n                             self.X,\n                             self.Y,\n                             self.heights\n                            ), dim = -1)\n        return squares\n\n\n    def get_triangles(self):\n        \"\"\"\n        Internal function to get triangles.\n        \"\"\" \n        squares = self.get_squares()\n        triangles = torch.zeros(2, self.number_of_meshes[0], self.number_of_meshes[1], 3, 3, device = self.device)\n        for i in range(0, self.number_of_meshes[0] - 1):\n            for j in range(0, self.number_of_meshes[1] - 1):\n                first_triangle = torch.cat((\n                                            squares[i + 1, j].unsqueeze(0),\n                                            squares[i + 1, j + 1].unsqueeze(0),\n                                            squares[i, j + 1].unsqueeze(0),\n                                           ), dim = 0)\n                second_triangle = torch.cat((\n                                             squares[i + 1, j].unsqueeze(0),\n                                             squares[i, j + 1].unsqueeze(0),\n                                             squares[i, j].unsqueeze(0),\n                                            ), dim = 0)\n                triangles[0, i, j], _, _, _ = rotate_points(first_triangle, angles = self.angles)\n                triangles[1, i, j], _, _, _ = rotate_points(second_triangle, angles = self.angles)\n        triangles = triangles.view(-1, 3, 3) + self.offset\n        return triangles \n\n\n    def mirror(self, rays):\n        \"\"\"\n        Function to bounce light rays off the meshes.\n\n        Parameters\n        ----------\n        rays              : torch.tensor\n                            Rays to be bounced.\n                            Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].\n\n        Returns\n        -------\n        reflected_rays    : torch.tensor\n                            Reflected rays.\n                            Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].\n        reflected_normals : torch.tensor\n                            Reflected normals.\n                            Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].\n\n        \"\"\"\n        if len(rays.shape) == 2:\n            rays = rays.unsqueeze(0)\n        triangles = self.get_triangles()\n        reflected_rays = torch.empty((0, 2, 3), requires_grad = True, device = self.device)\n        reflected_normals = torch.empty((0, 2, 3), requires_grad = True, device = self.device)\n        for triangle in triangles:\n            _, _, intersecting_rays, intersecting_normals, check = intersect_w_triangle(\n                                                                                        rays,\n                                                                                        triangle\n                                                                                       ) \n            triangle_reflected_rays = reflect(intersecting_rays, intersecting_normals)\n            if triangle_reflected_rays.shape[0] > 0:\n                reflected_rays = torch.cat((\n                                            reflected_rays,\n                                            triangle_reflected_rays\n                                          ))\n        
        reflected_normals = torch.cat((\n                                               reflected_normals,\n                                               intersecting_normals\n                                              ))\n        return reflected_rays, reflected_normals\n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.mesh.planar_mesh.__init__","title":"__init__(size=[1.0, 1.0], number_of_meshes=[10, 10], angles=torch.tensor([0.0, 0.0, 0.0]), offset=torch.tensor([0.0, 0.0, 0.0]), device=torch.device('cpu'), heights=None)","text":"

Definition to generate a plane with meshes.

Parameters:

  • number_of_meshes \u2013
                Number of squares over plane.\n            There are two triangles at each square.\n
  • size \u2013
                Size of the plane.\n
  • angles \u2013
                Rotation angles in degrees.\n
  • offset \u2013
                Offset along XYZ axes.\n            Expected dimension is [1 x 3] or offset for each triangle [m x 3].\n            m here refers to `2 * number_of_meshes[0]` times  `number_of_meshes[1]`.\n
  • device \u2013
                Computational resource to be used (e.g., cpu, cuda).\n
  • heights \u2013
                Load surface heights from a tensor.\n
Source code in odak/learn/raytracing/mesh.py
def __init__(\n             self,\n             size = [1., 1.],\n             number_of_meshes = [10, 10],\n             angles = torch.tensor([0., 0., 0.]),\n             offset = torch.tensor([0., 0., 0.]),\n             device = torch.device('cpu'),\n             heights = None\n            ):\n    \"\"\"\n    Definition to generate a plane with meshes.\n\n\n    Parameters\n    -----------\n    number_of_meshes  : torch.tensor\n                        Number of squares over plane.\n                        There are two triangles at each square.\n    size              : torch.tensor\n                        Size of the plane.\n    angles            : torch.tensor\n                        Rotation angles in degrees.\n    offset            : torch.tensor\n                        Offset along XYZ axes.\n                        Expected dimension is [1 x 3] or offset for each triangle [m x 3].\n                        m here refers to `2 * number_of_meshes[0]` times  `number_of_meshes[1]`.\n    device            : torch.device\n                        Computational resource to be used (e.g., cpu, cuda).\n    heights           : torch.tensor\n                        Load surface heights from a tensor.\n    \"\"\"\n    self.device = device\n    self.angles = angles.to(self.device)\n    self.offset = offset.to(self.device)\n    self.size = size.to(self.device)\n    self.number_of_meshes = number_of_meshes.to(self.device)\n    self.init_heights(heights)\n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.mesh.planar_mesh.get_squares","title":"get_squares()","text":"

Internal function to initiate squares over a plane.

Returns:

  • squares ( tensor ) \u2013

    Squares over a plane. Expected size is [m x n x 3].

Source code in odak/learn/raytracing/mesh.py
def get_squares(self):\n    \"\"\"\n    Internal function to initiate squares over a plane.\n\n    Returns\n    -------\n    squares     : torch.tensor\n                  Squares over a plane.\n                  Expected size is [m x n x 3].\n    \"\"\"\n    squares = torch.cat((\n                         self.X,\n                         self.Y,\n                         self.heights\n                        ), dim = -1)\n    return squares\n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.mesh.planar_mesh.get_triangles","title":"get_triangles()","text":"

Internal function to get triangles.

Source code in odak/learn/raytracing/mesh.py
def get_triangles(self):\n    \"\"\"\n    Internal function to get triangles.\n    \"\"\" \n    squares = self.get_squares()\n    triangles = torch.zeros(2, self.number_of_meshes[0], self.number_of_meshes[1], 3, 3, device = self.device)\n    for i in range(0, self.number_of_meshes[0] - 1):\n        for j in range(0, self.number_of_meshes[1] - 1):\n            first_triangle = torch.cat((\n                                        squares[i + 1, j].unsqueeze(0),\n                                        squares[i + 1, j + 1].unsqueeze(0),\n                                        squares[i, j + 1].unsqueeze(0),\n                                       ), dim = 0)\n            second_triangle = torch.cat((\n                                         squares[i + 1, j].unsqueeze(0),\n                                         squares[i, j + 1].unsqueeze(0),\n                                         squares[i, j].unsqueeze(0),\n                                        ), dim = 0)\n            triangles[0, i, j], _, _, _ = rotate_points(first_triangle, angles = self.angles)\n            triangles[1, i, j], _, _, _ = rotate_points(second_triangle, angles = self.angles)\n    triangles = triangles.view(-1, 3, 3) + self.offset\n    return triangles \n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.mesh.planar_mesh.init_heights","title":"init_heights(heights=None)","text":"

Internal function to initialize a height map. Note that self.heights is a differentiable variable, and can be optimized or learned. See unit test test/test_learn_ray_detector.py or test/test_learn_ray_mesh.py as examples.

Source code in odak/learn/raytracing/mesh.py
def init_heights(self, heights = None):\n    \"\"\"\n    Internal function to initialize a height map.\n    Note that self.heights is a differentiable variable, and can be optimized or learned.\n    See unit test `test/test_learn_ray_detector.py` or `test/test_learn_ray_mesh.py` as examples.\n    \"\"\"\n    if not isinstance(heights, type(None)):\n        self.heights = heights.to(self.device)\n        self.heights.requires_grad = True\n    else:\n        self.heights = torch.zeros(\n                                   (self.number_of_meshes[0], self.number_of_meshes[1], 1),\n                                   requires_grad = True,\n                                   device = self.device,\n                                  )\n    x = torch.linspace(-self.size[0] / 2., self.size[0] / 2., self.number_of_meshes[0], device = self.device) \n    y = torch.linspace(-self.size[1] / 2., self.size[1] / 2., self.number_of_meshes[1], device = self.device)\n    X, Y = torch.meshgrid(x, y, indexing = 'ij')\n    self.X = X.unsqueeze(-1)\n    self.Y = Y.unsqueeze(-1)\n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.mesh.planar_mesh.mirror","title":"mirror(rays)","text":"

Function to bounce light rays off the meshes.

Parameters:

  • rays \u2013
                Rays to be bounced.\n            Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].\n

Returns:

  • reflected_rays ( tensor ) \u2013

    Reflected rays. Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].

  • reflected_normals ( tensor ) \u2013

    Reflected normals. Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].

Source code in odak/learn/raytracing/mesh.py
def mirror(self, rays):\n    \"\"\"\n    Function to bounce light rays off the meshes.\n\n    Parameters\n    ----------\n    rays              : torch.tensor\n                        Rays to be bounced.\n                        Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].\n\n    Returns\n    -------\n    reflected_rays    : torch.tensor\n                        Reflected rays.\n                        Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].\n    reflected_normals : torch.tensor\n                        Reflected normals.\n                        Expected size is [2 x 3], [1 x 2 x 3] or [m x 2 x 3].\n\n    \"\"\"\n    if len(rays.shape) == 2:\n        rays = rays.unsqueeze(0)\n    triangles = self.get_triangles()\n    reflected_rays = torch.empty((0, 2, 3), requires_grad = True, device = self.device)\n    reflected_normals = torch.empty((0, 2, 3), requires_grad = True, device = self.device)\n    for triangle in triangles:\n        _, _, intersecting_rays, intersecting_normals, check = intersect_w_triangle(\n                                                                                    rays,\n                                                                                    triangle\n                                                                                   ) \n        triangle_reflected_rays = reflect(intersecting_rays, intersecting_normals)\n        if triangle_reflected_rays.shape[0] > 0:\n            reflected_rays = torch.cat((\n                                        reflected_rays,\n                                        triangle_reflected_rays\n                                      ))\n            reflected_normals = torch.cat((\n                                           reflected_normals,\n                                           intersecting_normals\n                                          ))\n    return reflected_rays, reflected_normals\n
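Putting the class together, a sketch that builds a small planar mesh acting as a mirror and bounces a single ray off it; size and number_of_meshes are passed as tensors since the constructor moves them to the chosen device (import path assumed from the source file above):

import torch
from odak.learn.raytracing.mesh import planar_mesh  # assumed import path

mesh = planar_mesh(
                   size = torch.tensor([1., 1.]),
                   number_of_meshes = torch.tensor([10, 10]),
                   offset = torch.tensor([0., 0., 5.])
                  )
rays = torch.tensor([[[0.01, 0.02, 0.], [0., 0., 1.]]])
reflected_rays, reflected_normals = mesh.mirror(rays)
# The heights of the mesh are differentiable, so the reflected rays can drive an optimization loop.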
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.mesh.planar_mesh.save_heights","title":"save_heights(filename='heights.pt')","text":"

Function to save heights to a file.

Parameters:

  • filename \u2013
                Filename.\n
Source code in odak/learn/raytracing/mesh.py
def save_heights(self, filename = 'heights.pt'):\n    \"\"\"\n    Function to save heights to a file.\n\n    Parameters\n    ----------\n    filename          : str\n                        Filename.\n    \"\"\"\n    save_torch_tensor(filename, self.heights.detach().clone())\n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.mesh.planar_mesh.save_heights_as_PLY","title":"save_heights_as_PLY(filename='mesh.ply')","text":"

Function to save mesh to a PLY file.

Parameters:

  • filename \u2013
                Filename.\n
Source code in odak/learn/raytracing/mesh.py
def save_heights_as_PLY(self, filename = 'mesh.ply'):\n    \"\"\"\n    Function to save mesh to a PLY file.\n\n    Parameters\n    ----------\n    filename          : str\n                        Filename.\n    \"\"\"\n    triangles = self.get_triangles()\n    write_PLY(triangles, filename)\n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.primitives.center_of_triangle","title":"center_of_triangle(triangle)","text":"

Definition to calculate center of a triangle.

Parameters:

  • triangle \u2013
            An array that contains three points defining a triangle (Mx3). \n        It can also parallel process many triangles (NxMx3).\n

Returns:

  • centers ( tensor ) \u2013

    Triangle centers.

Source code in odak/learn/raytracing/primitives.py
def center_of_triangle(triangle):\n    \"\"\"\n    Definition to calculate center of a triangle.\n\n    Parameters\n    ----------\n    triangle      : torch.tensor\n                    An array that contains three points defining a triangle (Mx3). \n                    It can also parallel process many triangles (NxMx3).\n\n    Returns\n    -------\n    centers       : torch.tensor\n                    Triangle centers.\n    \"\"\"\n    if len(triangle.shape) == 2:\n        triangle = triangle.view((1, 3, 3))\n    center = torch.mean(triangle, axis=1)\n    return center\n
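For example, assuming the import path from the source file above:

import torch
from odak.learn.raytracing.primitives import center_of_triangle  # assumed import path

triangle = torch.tensor([[0., 0., 0.], [3., 0., 0.], [0., 3., 0.]])
center = center_of_triangle(triangle)
# center is [[1., 1., 0.]], the mean of the three vertices.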
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.primitives.define_circle","title":"define_circle(center, radius, angles)","text":"

Definition to describe a circle in a single variable packed form.

Parameters:

  • center \u2013
      Center of a circle to be defined in 3D space.\n
  • radius \u2013
      Radius of a circle to be defined.\n
  • angles \u2013
      Angular tilt of a circle represented by rotations about x, y, and z axes.\n

Returns:

  • circle ( list ) \u2013

    Single variable packed form.

Source code in odak/learn/raytracing/primitives.py
def define_circle(center, radius, angles):\n    \"\"\"\n    Definition to describe a circle in a single variable packed form.\n\n    Parameters\n    ----------\n    center  : torch.Tensor\n              Center of a circle to be defined in 3D space.\n    radius  : float\n              Radius of a circle to be defined.\n    angles  : torch.Tensor\n              Angular tilt of a circle represented by rotations about x, y, and z axes.\n\n    Returns\n    ----------\n    circle  : list\n              Single variable packed form.\n    \"\"\"\n    points = define_plane(center, angles=angles)\n    circle = [\n        points,\n        center,\n        torch.tensor([radius])\n    ]\n    return circle\n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.primitives.define_plane","title":"define_plane(point, angles=torch.tensor([0.0, 0.0, 0.0]))","text":"

Definition to define a plane from a center point and rotation angles.

Parameters:

  • point \u2013
           A point that is at the center of a plane.\n
  • angles \u2013
           Rotation angles in degrees.\n

Returns:

  • plane ( tensor ) \u2013

    Points defining plane.

Source code in odak/learn/raytracing/primitives.py
def define_plane(point, angles = torch.tensor([0., 0., 0.])):\n    \"\"\" \n    Definition to generate a rotation matrix along X axis.\n\n    Parameters\n    ----------\n    point        : torch.tensor\n                   A point that is at the center of a plane.\n    angles       : torch.tensor\n                   Rotation angles in degrees.\n\n    Returns\n    ----------\n    plane        : torch.tensor\n                   Points defining plane.\n    \"\"\"\n    plane = torch.tensor([\n                          [10., 10., 0.],\n                          [0., 10., 0.],\n                          [0.,  0., 0.]\n                         ], device = point.device)\n    for i in range(0, plane.shape[0]):\n        plane[i], _, _, _ = rotate_points(plane[i], angles = angles.to(point.device))\n        plane[i] = plane[i] + point\n    return plane\n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.primitives.define_plane_mesh","title":"define_plane_mesh(number_of_meshes=[10, 10], size=[1.0, 1.0], angles=torch.tensor([0.0, 0.0, 0.0]), offset=torch.tensor([[0.0, 0.0, 0.0]]))","text":"

Definition to generate a plane with meshes.

Parameters:

  • number_of_meshes \u2013
                Number of squares over plane.\n            There are two triangles at each square.\n
  • size \u2013
                Size of the plane.\n
  • angles \u2013
                Rotation angles in degrees.\n
  • offset \u2013
                Offset along XYZ axes.\n            Expected dimension is [1 x 3] or offset for each triangle [m x 3].\n            m here refers to `2 * number_of_meshes[0]` times  `number_of_meshes[1]`.\n

Returns:

  • triangles ( tensor ) \u2013

    Triangles [m x 3 x 3], where m is 2 * number_of_meshes[0] times number_of_meshes[1].

Source code in odak/learn/raytracing/primitives.py
def define_plane_mesh(\n                      number_of_meshes = [10, 10], \n                      size = [1., 1.], \n                      angles = torch.tensor([0., 0., 0.]), \n                      offset = torch.tensor([[0., 0., 0.]])\n                     ):\n    \"\"\"\n    Definition to generate a plane with meshes.\n\n\n    Parameters\n    -----------\n    number_of_meshes  : torch.tensor\n                        Number of squares over plane.\n                        There are two triangles at each square.\n    size              : list\n                        Size of the plane.\n    angles            : torch.tensor\n                        Rotation angles in degrees.\n    offset            : torch.tensor\n                        Offset along XYZ axes.\n                        Expected dimension is [1 x 3] or offset for each triangle [m x 3].\n                        m here refers to `2 * number_of_meshes[0]` times  `number_of_meshes[1]`. \n\n    Returns\n    -------\n    triangles         : torch.tensor\n                        Triangles [m x 3 x 3], where m is `2 * number_of_meshes[0]` times  `number_of_meshes[1]`.\n    \"\"\"\n    triangles = torch.zeros(2, number_of_meshes[0], number_of_meshes[1], 3, 3)\n    step = [size[0] / number_of_meshes[0], size[1] / number_of_meshes[1]]\n    for i in range(0, number_of_meshes[0] - 1):\n        for j in range(0, number_of_meshes[1] - 1):\n            first_triangle = torch.tensor([\n                                           [       -size[0] / 2. + step[0] * i,       -size[1] / 2. + step[0] * j, 0.],\n                                           [ -size[0] / 2. + step[0] * (i + 1),       -size[1] / 2. + step[0] * j, 0.],\n                                           [       -size[0] / 2. + step[0] * i, -size[1] / 2. + step[0] * (j + 1), 0.]\n                                          ])\n            second_triangle = torch.tensor([\n                                            [ -size[0] / 2. + step[0] * (i + 1), -size[1] / 2. + step[0] * (j + 1), 0.],\n                                            [ -size[0] / 2. + step[0] * (i + 1),       -size[1] / 2. + step[0] * j, 0.],\n                                            [       -size[0] / 2. + step[0] * i, -size[1] / 2. + step[0] * (j + 1), 0.]\n                                           ])\n            triangles[0, i, j], _, _, _ = rotate_points(first_triangle, angles = angles)\n            triangles[1, i, j], _, _, _ = rotate_points(second_triangle, angles = angles)\n    triangles = triangles.view(-1, 3, 3) + offset\n    return triangles\n
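A sketch generating the triangles of a 10 x 10 mesh (import path assumed as before):

import torch
from odak.learn.raytracing.primitives import define_plane_mesh  # assumed import path

triangles = define_plane_mesh(
                              number_of_meshes = [10, 10],
                              size = [1., 1.],
                              angles = torch.tensor([0., 0., 0.]),
                              offset = torch.tensor([[0., 0., 0.]])
                             )
# triangles has size [200 x 3 x 3]: two triangles per square of the 10 x 10 grid.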
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.primitives.define_sphere","title":"define_sphere(center=torch.tensor([[0.0, 0.0, 0.0]]), radius=torch.tensor([1.0]))","text":"

Definition to define a sphere.

Parameters:

  • center \u2013
          Center of the sphere(s) along XYZ axes.\n      Expected size is [3], [1, 3] or [m, 3].\n
  • radius \u2013
          Radius of that sphere(s).\n      Expected size is [1], [1, 1], [m] or [m, 1].\n

Returns:

  • parameters ( tensor ) \u2013

    Parameters of the defined sphere(s), packing the center and the radius. Expected size is [1 x 4] or [m x 4].

Source code in odak/learn/raytracing/primitives.py
def define_sphere(center = torch.tensor([[0., 0., 0.]]), radius = torch.tensor([1.])):\n    \"\"\"\n    Definition to define a sphere.\n\n    Parameters\n    ----------\n    center      : torch.tensor\n                  Center of the sphere(s) along XYZ axes.\n                  Expected size is [3], [1, 3] or [m, 3].\n    radius      : torch.tensor\n                  Radius of that sphere(s).\n                  Expected size is [1], [1, 1], [m] or [m, 1].\n\n    Returns\n    -------\n    parameters  : torch.tensor\n                  Parameters of defined sphere(s).\n                  Expected size is [1, 3] or [m x 3].\n    \"\"\"\n    if len(radius.shape) == 1:\n        radius = radius.unsqueeze(0)\n    if len(center.shape) == 1:\n        center = center.unsqueeze(1)\n    parameters = torch.cat((center, radius), dim = 1)\n    return parameters\n
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.primitives.is_it_on_triangle","title":"is_it_on_triangle(point_to_check, triangle)","text":"

Definition to check if a given point is inside a triangle. If the given point is inside a defined triangle, this definition returns True. For more details, visit: https://blackpawn.com/texts/pointinpoly/.

Parameters:

  • point_to_check \u2013
              Point(s) to check.\n          Expected size is [3], [1 x 3] or [m x 3].\n
  • triangle \u2013
              Triangle described with three points.\n          Expected size is [3 x 3], [1 x 3 x 3] or [m x 3 x3].\n

Returns:

  • result ( tensor ) \u2013

    Boolean tensor indicating whether the point lies on the triangle. Expected size is [1] or [m] depending on the input.

Source code in odak/learn/raytracing/primitives.py
def is_it_on_triangle(point_to_check, triangle):\n    \"\"\"\n    Definition to check if a given point is inside a triangle. \n    If the given point is inside a defined triangle, this definition returns True.\n    For more details, visit: [https://blackpawn.com/texts/pointinpoly/](https://blackpawn.com/texts/pointinpoly/).\n\n    Parameters\n    ----------\n    point_to_check  : torch.tensor\n                      Point(s) to check.\n                      Expected size is [3], [1 x 3] or [m x 3].\n    triangle        : torch.tensor\n                      Triangle described with three points.\n                      Expected size is [3 x 3], [1 x 3 x 3] or [m x 3 x3].\n\n    Returns\n    -------\n    result          : torch.tensor\n                      Is it on a triangle? Returns NaN if condition not satisfied.\n                      Expected size is [1] or [m] depending on the input.\n    \"\"\"\n    if len(point_to_check.shape) == 1:\n        point_to_check = point_to_check.unsqueeze(0)\n    if len(triangle.shape) == 2:\n        triangle = triangle.unsqueeze(0)\n    v0 = triangle[:, 2] - triangle[:, 0]\n    v1 = triangle[:, 1] - triangle[:, 0]\n    v2 = point_to_check - triangle[:, 0]\n    if len(v0.shape) == 1:\n        v0 = v0.unsqueeze(0)\n    if len(v1.shape) == 1:\n        v1 = v1.unsqueeze(0)\n    if len(v2.shape) == 1:\n        v2 = v2.unsqueeze(0)\n    dot00 = torch.mm(v0, v0.T)\n    dot01 = torch.mm(v0, v1.T)\n    dot02 = torch.mm(v0, v2.T) \n    dot11 = torch.mm(v1, v1.T)\n    dot12 = torch.mm(v1, v2.T)\n    invDenom = 1. / (dot00 * dot11 - dot01 * dot01)\n    u = (dot11 * dot02 - dot01 * dot12) * invDenom\n    v = (dot00 * dot12 - dot01 * dot02) * invDenom\n    result = (u >= 0.) & (v >= 0.) & ((u + v) < 1)\n    return result\n
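A minimal usage sketch for is_it_on_triangle, assuming it is importable from odak.learn.raytracing.primitives as the source path above indicates; the triangle and test point are arbitrary examples.

import torch
from odak.learn.raytracing.primitives import is_it_on_triangle

# A single triangle in the XY plane and a point inside it.
triangle = torch.tensor([[
                          [0., 0., 0.],
                          [1., 0., 0.],
                          [0., 1., 0.]
                        ]])
point = torch.tensor([[0.2, 0.2, 0.]])
print(is_it_on_triangle(point, triangle))  # a tensor containing True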
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.primitives.is_it_on_triangle_batch","title":"is_it_on_triangle_batch(point_to_check, triangle)","text":"

Definition to check if given points are inside triangles. If the given points are inside defined triangles, this definition returns True.

Parameters:

  • point_to_check \u2013
              Points to check (m x n x 3).\n
  • triangle \u2013
              Triangles (m x 3 x 3).\n

Returns:

  • result ( torch.tensor (m x n) ) \u2013

    Boolean tensor indicating, for each of the m triangles, whether each of the n points lies inside it.

Source code in odak/learn/raytracing/primitives.py
def is_it_on_triangle_batch(point_to_check, triangle):\n    \"\"\"\n    Definition to check if given points are inside triangles. If the given points are inside defined triangles, this definition returns True.\n\n    Parameters\n    ----------\n    point_to_check  : torch.tensor\n                      Points to check (m x n x 3).\n    triangle        : torch.tensor \n                      Triangles (m x 3 x 3).\n\n    Returns\n    ----------\n    result          : torch.tensor (m x n)\n\n    \"\"\"\n    if len(point_to_check.shape) == 1:\n        point_to_check = point_to_check.unsqueeze(0)\n    if len(triangle.shape) == 2:\n        triangle = triangle.unsqueeze(0)\n    v0 = triangle[:, 2] - triangle[:, 0]\n    v1 = triangle[:, 1] - triangle[:, 0]\n    v2 = point_to_check - triangle[:, None, 0]\n    if len(v0.shape) == 1:\n        v0 = v0.unsqueeze(0)\n    if len(v1.shape) == 1:\n        v1 = v1.unsqueeze(0)\n    if len(v2.shape) == 1:\n        v2 = v2.unsqueeze(0)\n\n    dot00 = torch.bmm(v0.unsqueeze(1), v0.unsqueeze(1).permute(0, 2, 1)).squeeze(1)\n    dot01 = torch.bmm(v0.unsqueeze(1), v1.unsqueeze(1).permute(0, 2, 1)).squeeze(1)\n    dot02 = torch.bmm(v0.unsqueeze(1), v2.permute(0, 2, 1)).squeeze(1)\n    dot11 = torch.bmm(v1.unsqueeze(1), v1.unsqueeze(1).permute(0, 2, 1)).squeeze(1)\n    dot12 = torch.bmm(v1.unsqueeze(1), v2.permute(0, 2, 1)).squeeze(1)\n    invDenom = 1. / (dot00 * dot11 - dot01 * dot01)\n    u = (dot11 * dot02 - dot01 * dot12) * invDenom\n    v = (dot00 * dot12 - dot01 * dot02) * invDenom\n    result = (u >= 0.) & (v >= 0.) & ((u + v) < 1)\n\n    return result\n
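A minimal usage sketch for is_it_on_triangle_batch, assuming it is importable from odak.learn.raytracing.primitives as the source path above indicates; one of the two test points is deliberately placed outside the triangle.

import torch
from odak.learn.raytracing.primitives import is_it_on_triangle_batch

# One triangle (m = 1) tested against two points (n = 2): one inside, one outside.
triangles = torch.tensor([[
                           [0., 0., 0.],
                           [1., 0., 0.],
                           [0., 1., 0.]
                         ]])
points = torch.tensor([[[0.2, 0.2, 0.], [0.9, 0.9, 0.]]])
print(is_it_on_triangle_batch(points, triangles))  # tensor([[ True, False]])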
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.ray.create_ray","title":"create_ray(xyz, abg, direction=False)","text":"

Definition to create a ray.

Parameters:

  • xyz \u2013
           List that contains X,Y and Z start locations of a ray.\n       Size could be [1 x 3], [3], [m x 3].\n
  • abg \u2013
           List that contains angles in degrees with respect to the X,Y and Z axes.\n       Size could be [1 x 3], [3], [m x 3].\n
  • direction \u2013
           If set to True, the cosines of `abg` are not calculated; `abg` is used directly as the direction cosines.\n

Returns:

  • ray ( tensor ) \u2013

    Array that contains starting points and cosines of a created ray. Size will be either [1 x 3] or [m x 3].

Source code in odak/learn/raytracing/ray.py
def create_ray(xyz, abg, direction = False):\n    \"\"\"\n    Definition to create a ray.\n\n    Parameters\n    ----------\n    xyz          : torch.tensor\n                   List that contains X,Y and Z start locations of a ray.\n                   Size could be [1 x 3], [3], [m x 3].\n    abg          : torch.tensor\n                   List that contains angles in degrees with respect to the X,Y and Z axes.\n                   Size could be [1 x 3], [3], [m x 3].\n    direction    : bool\n                   If set to True, cosines of `abg` is not calculated.\n\n    Returns\n    ----------\n    ray          : torch.tensor\n                   Array that contains starting points and cosines of a created ray.\n                   Size will be either [1 x 3] or [m x 3].\n    \"\"\"\n    points = xyz\n    angles = abg\n    if len(xyz) == 1:\n        points = xyz.unsqueeze(0)\n    if len(abg) == 1:\n        angles = abg.unsqueeze(0)\n    ray = torch.zeros(points.shape[0], 2, 3, device = points.device)\n    ray[:, 0] = points\n    if direction:\n        ray[:, 1] = abg\n    else:\n        ray[:, 1] = torch.cos(torch.deg2rad(abg))\n    return ray\n
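A minimal usage sketch for create_ray, assuming it is importable from odak.learn.raytracing.ray as the source path above indicates; the start locations and angles are arbitrary examples.

import torch
from odak.learn.raytracing.ray import create_ray

# Two rays starting at the origin, described by angles in degrees with respect to the X, Y and Z axes.
starting_points = torch.tensor([[0., 0., 0.], [0., 0., 0.]])
angles = torch.tensor([[90., 90., 0.], [90., 80., 0.]])
rays = create_ray(starting_points, angles)
print(rays.shape)  # torch.Size([2, 2, 3]) -> [start point, direction cosines] per ray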
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.ray.create_ray_from_all_pairs","title":"create_ray_from_all_pairs(x0y0z0, x1y1z1)","text":"

Creates rays from all possible pairs of points in x0y0z0 and x1y1z1.

Parameters:

  • x0y0z0 \u2013
           Tensor that contains X, Y, and Z start locations of rays.\n       Size should be [m x 3].\n
  • x1y1z1 \u2013
           Tensor that contains X, Y, and Z end locations of rays.\n       Size should be [n x 3].\n

Returns:

  • rays ( tensor ) \u2013

    Array that contains starting points and cosines of a created ray(s). Size of [n*m x 2 x 3]

Source code in odak/learn/raytracing/ray.py
def create_ray_from_all_pairs(x0y0z0, x1y1z1):\n    \"\"\"\n    Creates rays from all possible pairs of points in x0y0z0 and x1y1z1.\n\n    Parameters\n    ----------\n    x0y0z0       : torch.tensor\n                   Tensor that contains X, Y, and Z start locations of rays.\n                   Size should be [m x 3].\n    x1y1z1       : torch.tensor\n                   Tensor that contains X, Y, and Z end locations of rays.\n                   Size should be [n x 3].\n\n    Returns\n    ----------\n    rays         : torch.tensor\n                   Array that contains starting points and cosines of a created ray(s). Size of [n*m x 2 x 3]\n    \"\"\"\n\n    if len(x0y0z0.shape) == 1:\n        x0y0z0 = x0y0z0.unsqueeze(0)\n    if len(x1y1z1.shape) == 1:\n        x1y1z1 = x1y1z1.unsqueeze(0)\n\n    m, n = x0y0z0.shape[0], x1y1z1.shape[0]\n    start_points = x0y0z0.unsqueeze(1).expand(-1, n, -1).reshape(-1, 3)\n    end_points = x1y1z1.unsqueeze(0).expand(m, -1, -1).reshape(-1, 3)\n\n    directions = end_points - start_points\n    norms = torch.norm(directions, p=2, dim=1, keepdim=True)\n    norms[norms == 0] = float('nan')\n\n    normalized_directions = directions / norms\n\n    rays = torch.zeros(m * n, 2, 3, device=x0y0z0.device)\n    rays[:, 0, :] = start_points\n    rays[:, 1, :] = normalized_directions\n\n    return rays\n
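A minimal usage sketch for create_ray_from_all_pairs, assuming it is importable from odak.learn.raytracing.ray as the source path above indicates; the start and end points are arbitrary examples.

import torch
from odak.learn.raytracing.ray import create_ray_from_all_pairs

starts = torch.tensor([[0., 0., 0.], [1., 0., 0.]])                   # m = 2 start points
ends = torch.tensor([[0., 0., 10.], [0., 1., 10.], [1., 1., 10.]])    # n = 3 end points
rays = create_ray_from_all_pairs(starts, ends)
print(rays.shape)  # torch.Size([6, 2, 3]) -> one ray per start-end pair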
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.ray.create_ray_from_grid_w_luminous_angle","title":"create_ray_from_grid_w_luminous_angle(center, size, no, tilt, num_ray_per_light, angle_limit)","text":"

Generate a 2D array of lights, each emitting rays within a specified solid angle and tilt.

Parameters:

  • center \u2013
                          The center point of the light array, shape [3].\n
  • size \u2013
                          The size of the light array [height, width].\n
  • no \u2013
                          The number of lights in the array [number of lights in height, number of lights in width].\n
  • tilt \u2013
                          The tilt angles in degrees along x, y, z axes for the rays, shape [3].\n
  • num_ray_per_light \u2013
                          The number of rays each light should emit.\n
  • angle_limit \u2013
                          The maximum angle in degrees from the initial direction vector within which to emit rays.\n

Returns:

  • rays ( tensor ) \u2013

    Array that contains starting points and cosines of the created rays. Size of [n x 2 x 3].

Source code in odak/learn/raytracing/ray.py
def create_ray_from_grid_w_luminous_angle(center, size, no, tilt, num_ray_per_light, angle_limit):\n    \"\"\"\n    Generate a 2D array of lights, each emitting rays within a specified solid angle and tilt.\n\n    Parameters:\n    ----------\n    center              : torch.tensor\n                          The center point of the light array, shape [3].\n    size                : list[int]\n                          The size of the light array [height, width]\n    no                  : list[int]\n                          The number of the light arary [number of lights in height , number of lights inwidth]\n    tilt                : torch.tensor\n                          The tilt angles in degrees along x, y, z axes for the rays, shape [3].\n    angle_limit         : float\n                          The maximum angle in degrees from the initial direction vector within which to emit rays.\n    num_rays_per_light  : int\n                          The number of rays each light should emit.\n\n    Returns:\n    ----------\n    rays : torch.tensor\n           Array that contains starting points and cosines of a created ray(s). Size of [n x 2 x 3]\n    \"\"\"\n\n    samples = torch.zeros((no[0], no[1], 3))\n\n    x = torch.linspace(-size[0] / 2., size[0] / 2., no[0])\n    y = torch.linspace(-size[1] / 2., size[1] / 2., no[1])\n    X, Y = torch.meshgrid(x, y, indexing='ij')\n\n    samples[:, :, 0] = X.detach().clone()\n    samples[:, :, 1] = Y.detach().clone()\n    samples = samples.reshape((no[0]*no[1], 3))\n\n    samples, *_ = rotate_points(samples, angles=tilt)\n\n    samples = samples + center\n    angle_limit = torch.as_tensor(angle_limit)\n    cos_alpha = torch.cos(angle_limit * torch.pi / 180)\n    tilt = tilt * torch.pi / 180\n\n    theta = torch.acos(1 - 2 * torch.rand(num_ray_per_light*samples.size(0)) * (1-cos_alpha))\n    phi = 2 * torch.pi * torch.rand(num_ray_per_light*samples.size(0))  \n\n    directions = torch.stack([\n        torch.sin(theta) * torch.cos(phi),  \n        torch.sin(theta) * torch.sin(phi),  \n        torch.cos(theta)                    \n    ], dim=1)\n\n    c, s = torch.cos(tilt), torch.sin(tilt)\n\n    Rx = torch.tensor([\n        [1, 0, 0],\n        [0, c[0], -s[0]],\n        [0, s[0], c[0]]\n    ])\n\n    Ry = torch.tensor([\n        [c[1], 0, s[1]],\n        [0, 1, 0],\n        [-s[1], 0, c[1]]\n    ])\n\n    Rz = torch.tensor([\n        [c[2], -s[2], 0],\n        [s[2], c[2], 0],\n        [0, 0, 1]\n    ])\n\n    origins = samples.repeat(num_ray_per_light, 1)\n\n    directions = torch.matmul(directions, (Rz@Ry@Rx).T)\n\n\n    rays = torch.zeros(num_ray_per_light*samples.size(0), 2, 3)\n    rays[:, 0, :] = origins\n    rays[:, 1, :] = directions\n\n    return rays\n
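A minimal usage sketch for create_ray_from_grid_w_luminous_angle, assuming it is importable from odak.learn.raytracing.ray as the source path above indicates; the grid size, tilt and cone angle are arbitrary examples.

import torch
from odak.learn.raytracing.ray import create_ray_from_grid_w_luminous_angle

rays = create_ray_from_grid_w_luminous_angle(
                                             center = torch.tensor([0., 0., 0.]),
                                             size = [10., 10.],
                                             no = [5, 5],
                                             tilt = torch.tensor([0., 0., 0.]),
                                             num_ray_per_light = 20,
                                             angle_limit = 10.
                                            )
print(rays.shape)  # torch.Size([500, 2, 3]) -> 5 * 5 lights, 20 rays each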
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.ray.create_ray_from_point_w_luminous_angle","title":"create_ray_from_point_w_luminous_angle(origin, num_ray, tilt, angle_limit)","text":"

Generate rays from a point, tilted by specific angles along x, y, z axes, within a specified solid angle.

Parameters:

  • origin \u2013
                  The origin point of the rays, shape [3].\n
  • num_ray \u2013
                  The total number of rays to generate.\n
  • tilt \u2013
                  The tilt angles in degrees along x, y, z axes, shape [3].\n
  • angle_limit \u2013
                  The maximum angle in degrees from the initial direction vector within which to emit rays.\n

Returns:

  • rays ( tensor ) \u2013

    Array that contains starting points and cosines of the created rays. Size of [n x 2 x 3].

Source code in odak/learn/raytracing/ray.py
def create_ray_from_point_w_luminous_angle(origin, num_ray, tilt, angle_limit):\n    \"\"\"\n    Generate rays from a point, tilted by specific angles along x, y, z axes, within a specified solid angle.\n\n    Parameters:\n    ----------\n    origin      : torch.tensor\n                  The origin point of the rays, shape [3].\n    num_rays    : int\n                  The total number of rays to generate.\n    tilt        : torch.tensor\n                  The tilt angles in degrees along x, y, z axes, shape [3].\n    angle_limit : float\n                  The maximum angle in degrees from the initial direction vector within which to emit rays.\n\n    Returns:\n    ----------\n    rays : torch.tensor\n           Array that contains starting points and cosines of a created ray(s). Size of [n x 2 x 3]\n    \"\"\"\n    angle_limit = torch.as_tensor(angle_limit) \n    cos_alpha = torch.cos(angle_limit * torch.pi / 180)\n    tilt = tilt * torch.pi / 180\n\n    theta = torch.acos(1 - 2 * torch.rand(num_ray) * (1-cos_alpha))\n    phi = 2 * torch.pi * torch.rand(num_ray)  \n\n\n    directions = torch.stack([\n        torch.sin(theta) * torch.cos(phi),  \n        torch.sin(theta) * torch.sin(phi),  \n        torch.cos(theta)                    \n    ], dim=1)\n\n    c, s = torch.cos(tilt), torch.sin(tilt)\n\n    Rx = torch.tensor([\n        [1, 0, 0],\n        [0, c[0], -s[0]],\n        [0, s[0], c[0]]\n    ])\n\n    Ry = torch.tensor([\n        [c[1], 0, s[1]],\n        [0, 1, 0],\n        [-s[1], 0, c[1]]\n    ])\n\n    Rz = torch.tensor([\n        [c[2], -s[2], 0],\n        [s[2], c[2], 0],\n        [0, 0, 1]\n    ])\n\n    origins = origin.repeat(num_ray, 1)\n    directions = torch.matmul(directions, (Rz@Ry@Rx).T)\n\n\n    rays = torch.zeros(num_ray, 2, 3)\n    rays[:, 0, :] = origins\n    rays[:, 1, :] = directions\n\n    return rays\n
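A minimal usage sketch for create_ray_from_point_w_luminous_angle, assuming it is importable from odak.learn.raytracing.ray as the source path above indicates; the origin, tilt and cone angle are arbitrary examples.

import torch
from odak.learn.raytracing.ray import create_ray_from_point_w_luminous_angle

rays = create_ray_from_point_w_luminous_angle(
                                              origin = torch.tensor([0., 0., 0.]),
                                              num_ray = 100,
                                              tilt = torch.tensor([0., 0., 0.]),
                                              angle_limit = 5.
                                             )
print(rays.shape)  # torch.Size([100, 2, 3])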
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.ray.create_ray_from_two_points","title":"create_ray_from_two_points(x0y0z0, x1y1z1)","text":"

Definition to create a ray from two given points. Note that both inputs must match in shape.

Parameters:

  • x0y0z0 \u2013
           List that contains X,Y and Z start locations of a ray.\n       Size could be [1 x 3], [3], [m x 3].\n
  • x1y1z1 \u2013
           List that contains X,Y and Z ending locations of a ray or batch of rays.\n       Size could be [1 x 3], [3], [m x 3].\n

Returns:

  • ray ( tensor ) \u2013

    Array that contains starting points and cosines of a created ray(s).

Source code in odak/learn/raytracing/ray.py
def create_ray_from_two_points(x0y0z0, x1y1z1):\n    \"\"\"\n    Definition to create a ray from two given points. Note that both inputs must match in shape.\n\n    Parameters\n    ----------\n    x0y0z0       : torch.tensor\n                   List that contains X,Y and Z start locations of a ray.\n                   Size could be [1 x 3], [3], [m x 3].\n    x1y1z1       : torch.tensor\n                   List that contains X,Y and Z ending locations of a ray or batch of rays.\n                   Size could be [1 x 3], [3], [m x 3].\n\n    Returns\n    ----------\n    ray          : torch.tensor\n                   Array that contains starting points and cosines of a created ray(s).\n    \"\"\"\n    if len(x0y0z0.shape) == 1:\n        x0y0z0 = x0y0z0.unsqueeze(0)\n    if len(x1y1z1.shape) == 1:\n        x1y1z1 = x1y1z1.unsqueeze(0)\n    xdiff = x1y1z1[:, 0] - x0y0z0[:, 0]\n    ydiff = x1y1z1[:, 1] - x0y0z0[:, 1]\n    zdiff = x1y1z1[:, 2] - x0y0z0[:, 2]\n    s = (xdiff ** 2 + ydiff ** 2 + zdiff ** 2) ** 0.5\n    s[s == 0] = float('nan')\n    cosines = torch.zeros_like(x0y0z0 * x1y1z1)\n    cosines[:, 0] = xdiff / s\n    cosines[:, 1] = ydiff / s\n    cosines[:, 2] = zdiff / s\n    ray = torch.zeros(xdiff.shape[0], 2, 3, device = x0y0z0.device)\n    ray[:, 0] = x0y0z0\n    ray[:, 1] = cosines\n    return ray\n
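A minimal usage sketch for create_ray_from_two_points, assuming it is importable from odak.learn.raytracing.ray as the source path above indicates.

import torch
from odak.learn.raytracing.ray import create_ray_from_two_points

start = torch.tensor([[0., 0., 0.]])
end = torch.tensor([[0., 0., 10.]])
ray = create_ray_from_two_points(start, end)
print(ray)  # tensor([[[0., 0., 0.], [0., 0., 1.]]]) -> start point and direction cosines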
"},{"location":"odak/learn_raytracing/#odak.learn.raytracing.ray.propagate_ray","title":"propagate_ray(ray, distance)","text":"

Definition to propagate a ray over a given distance.

Parameters:

  • ray \u2013
         A ray with a size of [2 x 3], [1 x 2 x 3] or a batch of rays with [m x 2 x 3].\n
  • distance \u2013
         Distance with a size of [1], [1, m] or distances with a size of [m], [1, m].\n

Returns:

  • new_ray ( tensor ) \u2013

    Propagated ray with a size of [1 x 2 x 3] or batch of rays with [m x 2 x 3].

Source code in odak/learn/raytracing/ray.py
def propagate_ray(ray, distance):\n    \"\"\"\n    Definition to propagate a ray at a certain given distance.\n\n    Parameters\n    ----------\n    ray        : torch.tensor\n                 A ray with a size of [2 x 3], [1 x 2 x 3] or a batch of rays with [m x 2 x 3].\n    distance   : torch.tensor\n                 Distance with a size of [1], [1, m] or distances with a size of [m], [1, m].\n\n    Returns\n    ----------\n    new_ray    : torch.tensor\n                 Propagated ray with a size of [1 x 2 x 3] or batch of rays with [m x 2 x 3].\n    \"\"\"\n    if len(ray.shape) == 2:\n        ray = ray.unsqueeze(0)\n    if len(distance.shape) == 2:\n        distance = distance.squeeze(-1)\n    new_ray = torch.zeros_like(ray)\n    new_ray[:, 0, 0] = distance * ray[:, 1, 0] + ray[:, 0, 0]\n    new_ray[:, 0, 1] = distance * ray[:, 1, 1] + ray[:, 0, 1]\n    new_ray[:, 0, 2] = distance * ray[:, 1, 2] + ray[:, 0, 2]\n    return new_ray\n
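A minimal usage sketch for propagate_ray, assuming it and create_ray_from_two_points are importable from odak.learn.raytracing.ray as the source paths above indicate.

import torch
from odak.learn.raytracing.ray import create_ray_from_two_points, propagate_ray

ray = create_ray_from_two_points(
                                 torch.tensor([[0., 0., 0.]]),
                                 torch.tensor([[0., 0., 10.]])
                                )
new_ray = propagate_ray(ray, torch.tensor([5.]))
print(new_ray[:, 0])  # tensor([[0., 0., 5.]]) -> the starting point moved 5 units along the ray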
"},{"location":"odak/learn_tools/","title":"odak.learn.tools","text":"

odak.learn.tools

Provides necessary definitions for general tools used across the library.

"},{"location":"odak/learn_tools/#odak.learn.tools.blur_gaussian","title":"blur_gaussian(field, kernel_length=[21, 21], nsigma=[3, 3], padding='same')","text":"

A definition to blur a field using a Gaussian kernel.

Parameters:

  • field \u2013
            MxN field.\n
  • kernel_length (list, default: [21, 21] ) \u2013
            Length of the Gaussian kernel along X and Y axes.\n
  • nsigma \u2013
            Sigma of the Gaussian kernel along X and Y axes.\n
  • padding \u2013
            Padding value, see torch.nn.functional.conv2d() for more.\n

Returns:

  • blurred_field ( tensor ) \u2013

    Blurred field.

Source code in odak/learn/tools/matrix.py
def blur_gaussian(field, kernel_length = [21, 21], nsigma = [3, 3], padding = 'same'):\n    \"\"\"\n    A definition to blur a field using a Gaussian kernel.\n\n    Parameters\n    ----------\n    field         : torch.tensor\n                    MxN field.\n    kernel_length : list\n                    Length of the Gaussian kernel along X and Y axes.\n    nsigma        : list\n                    Sigma of the Gaussian kernel along X and Y axes.\n    padding       : int or string\n                    Padding value, see torch.nn.functional.conv2d() for more.\n\n    Returns\n    ----------\n    blurred_field : torch.tensor\n                    Blurred field.\n    \"\"\"\n    kernel = generate_2d_gaussian(kernel_length, nsigma).to(field.device)\n    kernel = kernel.unsqueeze(0).unsqueeze(0)\n    if len(field.shape) == 2:\n        field = field.view(1, 1, field.shape[-2], field.shape[-1])\n    blurred_field = torch.nn.functional.conv2d(field, kernel, padding='same')\n    if field.shape[1] == 1:\n        blurred_field = blurred_field.view(\n                                           blurred_field.shape[-2],\n                                           blurred_field.shape[-1]\n                                          )\n    return blurred_field\n
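A minimal usage sketch for blur_gaussian, assuming it is importable from odak.learn.tools as indexed on this page; the field below is random data purely for illustration.

import torch
from odak.learn.tools import blur_gaussian

field = torch.rand(100, 100)
blurred_field = blur_gaussian(field, kernel_length = [21, 21], nsigma = [3, 3])
print(blurred_field.shape)  # torch.Size([100, 100])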
"},{"location":"odak/learn_tools/#odak.learn.tools.circular_binary_mask","title":"circular_binary_mask(px, py, r)","text":"

Definition to generate a 2D circular binary mask.

Parameters:

  • px \u2013
                   Pixel count in x.\n
  • py \u2013
                   Pixel count in y.\n
  • r \u2013
                   Radius of the circle.\n

Returns:

  • mask ( tensor ) \u2013

    Mask [m x n].

Source code in odak/learn/tools/mask.py
def circular_binary_mask(px, py, r):\n    \"\"\"\n    Definition to generate a 2D circular binary mask.\n\n    Parameter\n    ---------\n    px           : int\n                   Pixel count in x.\n    py           : int\n                   Pixel count in y.\n    r            : int\n                   Radius of the circle.\n\n    Returns\n    -------\n    mask         : torch.tensor\n                   Mask [1 x 1 x m x n].\n    \"\"\"\n    x = torch.linspace(-px / 2., px / 2., px)\n    y = torch.linspace(-py / 2., py / 2., py)\n    X, Y = torch.meshgrid(x, y, indexing='ij')\n    Z = (X ** 2 + Y ** 2) ** 0.5\n    mask = torch.zeros_like(Z)\n    mask[Z < r] = 1\n    return mask\n
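A minimal usage sketch for circular_binary_mask, assuming it is importable from odak.learn.tools as indexed on this page.

from odak.learn.tools import circular_binary_mask

mask = circular_binary_mask(px = 512, py = 512, r = 100)
print(mask.shape)  # torch.Size([512, 512]); ones inside the circle of radius 100, zeros elsewhere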
"},{"location":"odak/learn_tools/#odak.learn.tools.convolve2d","title":"convolve2d(field, kernel)","text":"

Definition to convolve a field with a kernel by multiplying in frequency space.

Parameters:

  • field \u2013
          Input field with MxN shape.\n
  • kernel \u2013
          Input kernel with MxN shape.\n

Returns:

  • new_field ( tensor ) \u2013

    Convolved field.

Source code in odak/learn/tools/matrix.py
def convolve2d(field, kernel):\n    \"\"\"\n    Definition to convolve a field with a kernel by multiplying in frequency space.\n\n    Parameters\n    ----------\n    field       : torch.tensor\n                  Input field with MxN shape.\n    kernel      : torch.tensor\n                  Input kernel with MxN shape.\n\n    Returns\n    ----------\n    new_field   : torch.tensor\n                  Convolved field.\n    \"\"\"\n    fr = torch.fft.fft2(field)\n    fr2 = torch.fft.fft2(torch.flip(torch.flip(kernel, [1, 0]), [0, 1]))\n    m, n = fr.shape\n    new_field = torch.real(torch.fft.ifft2(fr*fr2))\n    new_field = torch.roll(new_field, shifts=(int(n/2+1), 0), dims=(1, 0))\n    new_field = torch.roll(new_field, shifts=(int(m/2+1), 0), dims=(0, 1))\n    return new_field\n
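A minimal usage sketch for convolve2d, assuming it and generate_2d_gaussian are importable from odak.learn.tools as indexed on this page; random data stands in for a real field.

import torch
from odak.learn.tools import convolve2d, generate_2d_gaussian

field = torch.rand(100, 100)
kernel = generate_2d_gaussian(kernel_length = [100, 100], nsigma = [3, 3])  # kernel shares the field's MxN shape
convolved = convolve2d(field, kernel)
print(convolved.shape)  # torch.Size([100, 100])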
"},{"location":"odak/learn_tools/#odak.learn.tools.correlation_2d","title":"correlation_2d(first_tensor, second_tensor)","text":"

Definition to calculate the correlation between two tensors.

Parameters:

  • first_tensor \u2013
            First tensor.\n
  • second_tensor (tensor) \u2013
            Second tensor.\n

Returns:

  • correlation ( tensor ) \u2013

    Correlation between the two tensors.

Source code in odak/learn/tools/matrix.py
def correlation_2d(first_tensor, second_tensor):\n    \"\"\"\n    Definition to calculate the correlation between two tensors.\n\n    Parameters\n    ----------\n    first_tensor  : torch.tensor\n                    First tensor.\n    second_tensor : torch.tensor\n                    Second tensor.\n\n    Returns\n    ----------\n    correlation   : torch.tensor\n                    Correlation between the two tensors.\n    \"\"\"\n    fft_first_tensor = (torch.fft.fft2(first_tensor))\n    fft_second_tensor = (torch.fft.fft2(second_tensor))\n    conjugate_second_tensor = torch.conj(fft_second_tensor)\n    result = torch.fft.ifftshift(torch.fft.ifft2(fft_first_tensor * conjugate_second_tensor))\n    return result\n
"},{"location":"odak/learn_tools/#odak.learn.tools.crop_center","title":"crop_center(field, size=None)","text":"

Definition to crop the center of a field. Without a `size` argument, a 2M x 2N field is cropped to an M x N array; when `size` is given, the field is cropped to the requested dimensions around its center.

Parameters:

  • field \u2013
          Input field 2M x 2N or K x L x 2M x 2N or K x 2M x 2N x L array.\n
  • size \u2013
          Dimensions to crop with respect to center of the image (e.g., M x N or 1 x 1 x M x N).\n

Returns:

  • cropped ( ndarray ) \u2013

    Cropped version of the input field.

Source code in odak/learn/tools/matrix.py
def crop_center(field, size = None):\n    \"\"\"\n    Definition to crop the center of a field with 2Mx2N size. The outcome is a MxN array.\n\n    Parameters\n    ----------\n    field       : ndarray\n                  Input field 2M x 2N or K x L x 2M x 2N or K x 2M x 2N x L array.\n    size        : list\n                  Dimensions to crop with respect to center of the image (e.g., M x N or 1 x 1 x M x N).\n\n    Returns\n    ----------\n    cropped     : ndarray\n                  Cropped version of the input field.\n    \"\"\"\n    orig_resolution = field.shape\n    if len(field.shape) < 3:\n        field = field.unsqueeze(0)\n    if len(field.shape) < 4:\n        field = field.unsqueeze(0)\n    permute_flag = False\n    if field.shape[-1] < 5:\n        permute_flag = True\n        field = field.permute(0, 3, 1, 2)\n    if type(size) == type(None):\n        qx = int(field.shape[-2] // 4)\n        qy = int(field.shape[-1] // 4)\n        cropped_padded = field[:, :, qx: qx + field.shape[-2] // 2, qy:qy + field.shape[-1] // 2]\n    else:\n        cx = int(field.shape[-2] // 2)\n        cy = int(field.shape[-1] // 2)\n        hx = int(size[-2] // 2)\n        hy = int(size[-1] // 2)\n        cropped_padded = field[:, :, cx-hx:cx+hx, cy-hy:cy+hy]\n    cropped = cropped_padded\n    if permute_flag:\n        cropped = cropped.permute(0, 2, 3, 1)\n    if len(orig_resolution) == 2:\n        cropped = cropped_padded.squeeze(0).squeeze(0)\n    if len(orig_resolution) == 3:\n        cropped = cropped_padded.squeeze(0)\n    return cropped\n
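A minimal usage sketch for crop_center, assuming it is importable from odak.learn.tools as indexed on this page; random data stands in for a real field.

import torch
from odak.learn.tools import crop_center

field = torch.rand(1, 1, 512, 512)
cropped = crop_center(field)                                  # default: keep the central 256 x 256 region
cropped_custom = crop_center(field, size = [1, 1, 100, 100])  # explicit crop size around the center
print(cropped.shape, cropped_custom.shape)  # [1, 1, 256, 256] and [1, 1, 100, 100]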
"},{"location":"odak/learn_tools/#odak.learn.tools.cross_product","title":"cross_product(vector1, vector2)","text":"

Definition to compute the cross product of two vectors and return the resultant vector. Uses the method described at: http://en.wikipedia.org/wiki/Cross_product

Parameters:

  • vector1 \u2013
           A vector/ray.\n
  • vector2 \u2013
           A vector/ray.\n

Returns:

  • ray ( tensor ) \u2013

    Array that contains starting points and cosines of a created ray.

Source code in odak/learn/tools/vector.py
def cross_product(vector1, vector2):\n    \"\"\"\n    Definition to cross product two vectors and return the resultant vector. Used method described under: http://en.wikipedia.org/wiki/Cross_product\n\n    Parameters\n    ----------\n    vector1      : torch.tensor\n                   A vector/ray.\n    vector2      : torch.tensor\n                   A vector/ray.\n\n    Returns\n    ----------\n    ray          : torch.tensor\n                   Array that contains starting points and cosines of a created ray.\n    \"\"\"\n    angle = torch.cross(vector1[1].T, vector2[1].T)\n    angle = torch.tensor(angle)\n    ray = torch.tensor([vector1[0], angle], dtype=torch.float32)\n    return ray\n
"},{"location":"odak/learn_tools/#odak.learn.tools.distance_between_two_points","title":"distance_between_two_points(point1, point2)","text":"

Definition to calculate distance between two given points.

Parameters:

  • point1 \u2013
          First point in X,Y,Z.\n
  • point2 \u2013
          Second point in X,Y,Z.\n

Returns:

  • distance ( Tensor ) \u2013

    Distance between the two given points.

Source code in odak/learn/tools/vector.py
def distance_between_two_points(point1, point2):\n    \"\"\"\n    Definition to calculate distance between two given points.\n\n    Parameters\n    ----------\n    point1      : torch.Tensor\n                  First point in X,Y,Z.\n    point2      : torch.Tensor\n                  Second point in X,Y,Z.\n\n    Returns\n    ----------\n    distance    : torch.Tensor\n                  Distance in between given two points.\n    \"\"\"\n    point1 = torch.tensor(point1) if not isinstance(point1, torch.Tensor) else point1\n    point2 = torch.tensor(point2) if not isinstance(point2, torch.Tensor) else point2\n\n    if len(point1.shape) == 1 and len(point2.shape) == 1:\n        distance = torch.sqrt(torch.sum((point1 - point2) ** 2))\n    elif len(point1.shape) == 2 or len(point2.shape) == 2:\n        distance = torch.sqrt(torch.sum((point1 - point2) ** 2, dim=-1))\n\n    return distance\n
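A minimal usage sketch for distance_between_two_points, assuming it is importable from odak.learn.tools as indexed on this page.

import torch
from odak.learn.tools import distance_between_two_points

point1 = torch.tensor([0., 0., 0.])
point2 = torch.tensor([3., 4., 0.])
print(distance_between_two_points(point1, point2))  # tensor(5.)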
"},{"location":"odak/learn_tools/#odak.learn.tools.expanduser","title":"expanduser(filename)","text":"

Definition to expand a filename that uses shortcuts such as the user home directory (e.g., `~`).

Parameters:

  • filename \u2013
            Filename.\n

Returns:

  • new_filename ( str ) \u2013

    Expanded filename.

Source code in odak/tools/file.py
def expanduser(filename):\n    \"\"\"\n    Definition to decode filename using namespaces and shortcuts.\n\n\n    Parameters\n    ----------\n    filename      : str\n                    Filename.\n\n\n    Returns\n    -------\n    new_filename  : str\n                    Filename.\n    \"\"\"\n    new_filename = os.path.expanduser(filename)\n    return new_filename\n
"},{"location":"odak/learn_tools/#odak.learn.tools.generate_2d_dirac_delta","title":"generate_2d_dirac_delta(kernel_length=[21, 21], a=[3, 3], mu=[0, 0], theta=0, normalize=False)","text":"

Generate a 2D Dirac delta function approximation using a Gaussian distribution. Inspired by https://en.wikipedia.org/wiki/Dirac_delta_function

Parameters:

  • kernel_length (list, default: [21, 21] ) \u2013
            Length of the Dirac delta function along X and Y axes.\n
  • a \u2013
            The scale factor in Gaussian distribution to approximate the Dirac delta function. \n        As a approaches zero, the Gaussian distribution becomes infinitely narrow and tall at the center (x=0), approaching the Dirac delta function.\n
  • mu \u2013
            Mu of the Gaussian kernel along X and Y axes.\n
  • theta \u2013
            The rotation angle of the 2D Dirac delta function.\n
  • normalize \u2013
            If set True, normalize the output.\n

Returns:

  • kernel_2d ( tensor ) \u2013

    Generated 2D Dirac delta function.

Source code in odak/learn/tools/matrix.py
def generate_2d_dirac_delta(\n                            kernel_length = [21, 21],\n                            a = [3, 3],\n                            mu = [0, 0],\n                            theta = 0,\n                            normalize = False\n                           ):\n    \"\"\"\n    Generate 2D Dirac delta function by using Gaussian distribution.\n    Inspired from https://en.wikipedia.org/wiki/Dirac_delta_function\n\n    Parameters\n    ----------\n    kernel_length : list\n                    Length of the Dirac delta function along X and Y axes.\n    a             : list\n                    The scale factor in Gaussian distribution to approximate the Dirac delta function. \n                    As a approaches zero, the Gaussian distribution becomes infinitely narrow and tall at the center (x=0), approaching the Dirac delta function.\n    mu            : list\n                    Mu of the Gaussian kernel along X and Y axes.\n    theta         : float\n                    The rotation angle of the 2D Dirac delta function.\n    normalize     : bool\n                    If set True, normalize the output.\n\n    Returns\n    ----------\n    kernel_2d     : torch.tensor\n                    Generated 2D Dirac delta function.\n    \"\"\"\n    x = torch.linspace(-kernel_length[0] / 2., kernel_length[0] / 2., kernel_length[0])\n    y = torch.linspace(-kernel_length[1] / 2., kernel_length[1] / 2., kernel_length[1])\n    X, Y = torch.meshgrid(x, y, indexing='ij')\n    X = X - mu[0]\n    Y = Y - mu[1]\n    theta = torch.as_tensor(theta)\n    X_rot = X * torch.cos(theta) - Y * torch.sin(theta)\n    Y_rot = X * torch.sin(theta) + Y * torch.cos(theta)\n    kernel_2d = (1 / (abs(a[0] * a[1]) * torch.pi)) * torch.exp(-((X_rot / a[0]) ** 2 + (Y_rot / a[1]) ** 2))\n    if normalize:\n        kernel_2d = kernel_2d / kernel_2d.max()\n    return kernel_2d\n
"},{"location":"odak/learn_tools/#odak.learn.tools.generate_2d_gaussian","title":"generate_2d_gaussian(kernel_length=[21, 21], nsigma=[3, 3], mu=[0, 0], normalize=False)","text":"

Generate a 2D Gaussian kernel. Inspired by https://stackoverflow.com/questions/29731726/how-to-calculate-a-gaussian-kernel-matrix-efficiently-in-numpy

Parameters:

  • kernel_length (list, default: [21, 21] ) \u2013
            Length of the Gaussian kernel along X and Y axes.\n
  • nsigma \u2013
            Sigma of the Gaussian kernel along X and Y axes.\n
  • mu \u2013
            Mu of the Gaussian kernel along X and Y axes.\n
  • normalize \u2013
            If set True, normalize the output.\n

Returns:

  • kernel_2d ( tensor ) \u2013

    Generated Gaussian kernel.

Source code in odak/learn/tools/matrix.py
def generate_2d_gaussian(kernel_length = [21, 21], nsigma = [3, 3], mu = [0, 0], normalize = False):\n    \"\"\"\n    Generate 2D Gaussian kernel. Inspired from https://stackoverflow.com/questions/29731726/how-to-calculate-a-gaussian-kernel-matrix-efficiently-in-numpy\n\n    Parameters\n    ----------\n    kernel_length : list\n                    Length of the Gaussian kernel along X and Y axes.\n    nsigma        : list\n                    Sigma of the Gaussian kernel along X and Y axes.\n    mu            : list\n                    Mu of the Gaussian kernel along X and Y axes.\n    normalize     : bool\n                    If set True, normalize the output.\n\n    Returns\n    ----------\n    kernel_2d     : torch.tensor\n                    Generated Gaussian kernel.\n    \"\"\"\n    x = torch.linspace(-kernel_length[0]/2., kernel_length[0]/2., kernel_length[0])\n    y = torch.linspace(-kernel_length[1]/2., kernel_length[1]/2., kernel_length[1])\n    X, Y = torch.meshgrid(x, y, indexing='ij')\n    if nsigma[0] == 0:\n        nsigma[0] = 1e-5\n    if nsigma[1] == 0:\n        nsigma[1] = 1e-5\n    kernel_2d = 1. / (2. * torch.pi * nsigma[0] * nsigma[1]) * torch.exp(-((X - mu[0])**2. / (2. * nsigma[0]**2.) + (Y - mu[1])**2. / (2. * nsigma[1]**2.)))\n    if normalize:\n        kernel_2d = kernel_2d / kernel_2d.max()\n    return kernel_2d\n
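A minimal usage sketch for generate_2d_gaussian, assuming it is importable from odak.learn.tools as indexed on this page.

from odak.learn.tools import generate_2d_gaussian

kernel = generate_2d_gaussian(kernel_length = [21, 21], nsigma = [3, 3], normalize = True)
print(kernel.shape, kernel.max())  # torch.Size([21, 21]) with a peak value of 1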
"},{"location":"odak/learn_tools/#odak.learn.tools.get_rotation_matrix","title":"get_rotation_matrix(tilt_angles=[0.0, 0.0, 0.0], tilt_order='XYZ')","text":"

Function to generate rotation matrix for given tilt angles and tilt order.

Parameters:

  • tilt_angles \u2013
                 Tilt angles in degrees along XYZ axes.\n
  • tilt_order \u2013
                 Rotation order (e.g., XYZ, XZY, ZXY, YXZ, ZYX).\n

Returns:

  • rotmat ( tensor ) \u2013

    Rotation matrix.

Source code in odak/learn/tools/transformation.py
def get_rotation_matrix(tilt_angles = [0., 0., 0.], tilt_order = 'XYZ'):\n    \"\"\"\n    Function to generate rotation matrix for given tilt angles and tilt order.\n\n\n    Parameters\n    ----------\n    tilt_angles        : list\n                         Tilt angles in degrees along XYZ axes.\n    tilt_order         : str\n                         Rotation order (e.g., XYZ, XZY, ZXY, YXZ, ZYX).\n\n    Returns\n    -------\n    rotmat             : torch.tensor\n                         Rotation matrix.\n    \"\"\"\n    rotx = rotmatx(tilt_angles[0])\n    roty = rotmaty(tilt_angles[1])\n    rotz = rotmatz(tilt_angles[2])\n    if tilt_order =='XYZ':\n        rotmat = torch.mm(rotz,torch.mm(roty, rotx))\n    elif tilt_order == 'XZY':\n        rotmat = torch.mm(roty,torch.mm(rotz, rotx))\n    elif tilt_order == 'ZXY':\n        rotmat = torch.mm(roty,torch.mm(rotx, rotz))\n    elif tilt_order == 'YXZ':\n        rotmat = torch.mm(rotz,torch.mm(rotx, roty))\n    elif tilt_order == 'ZYX':\n         rotmat = torch.mm(rotx,torch.mm(roty, rotz))\n    return rotmat\n
"},{"location":"odak/learn_tools/#odak.learn.tools.grid_sample","title":"grid_sample(no=[10, 10], size=[100.0, 100.0], center=[0.0, 0.0, 0.0], angles=[0.0, 0.0, 0.0])","text":"

Definition to generate samples over a surface.

Parameters:

  • no \u2013
          Number of samples.\n
  • size \u2013
          Physical size of the surface.\n
  • center \u2013
          Center location of the surface.\n
  • angles \u2013
          Tilt of the surface.\n

Returns:

  • samples ( tensor ) \u2013

    Samples generated.

  • rotx ( tensor ) \u2013

    Rotation matrix at X axis.

  • roty ( tensor ) \u2013

    Rotation matrix at Y axis.

  • rotz ( tensor ) \u2013

    Rotation matrix at Z axis.

Source code in odak/learn/tools/sample.py
def grid_sample(\n                no = [10, 10],\n                size = [100., 100.], \n                center = [0., 0., 0.], \n                angles = [0., 0., 0.]):\n    \"\"\"\n    Definition to generate samples over a surface.\n\n    Parameters\n    ----------\n    no          : list\n                  Number of samples.\n    size        : list\n                  Physical size of the surface.\n    center      : list\n                  Center location of the surface.\n    angles      : list\n                  Tilt of the surface.\n\n    Returns\n    -------\n    samples     : torch.tensor\n                  Samples generated.\n    rotx        : torch.tensor\n                  Rotation matrix at X axis.\n    roty        : torch.tensor\n                  Rotation matrix at Y axis.\n    rotz        : torch.tensor\n                  Rotation matrix at Z axis.\n    \"\"\"\n    center = torch.tensor(center)\n    angles = torch.tensor(angles)\n    size = torch.tensor(size)\n    samples = torch.zeros((no[0], no[1], 3))\n    x = torch.linspace(-size[0] / 2., size[0] / 2., no[0])\n    y = torch.linspace(-size[1] / 2., size[1] / 2., no[1])\n    X, Y = torch.meshgrid(x, y, indexing='ij')\n    samples[:, :, 0] = X.detach().clone()\n    samples[:, :, 1] = Y.detach().clone()\n    samples = samples.reshape((samples.shape[0] * samples.shape[1], samples.shape[2]))\n    samples, rotx, roty, rotz = rotate_points(samples, angles = angles, offset = center)\n    return samples, rotx, roty, rotz\n
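A minimal usage sketch for grid_sample, assuming it is importable from odak.learn.tools as indexed on this page; the surface size, center and tilt are arbitrary examples.

from odak.learn.tools import grid_sample

samples, rotx, roty, rotz = grid_sample(
                                        no = [5, 5],
                                        size = [10., 10.],
                                        center = [0., 0., 100.],
                                        angles = [0., 0., 0.]
                                       )
print(samples.shape)  # torch.Size([25, 3]) -> one XYZ sample per grid location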
"},{"location":"odak/learn_tools/#odak.learn.tools.histogram_loss","title":"histogram_loss(frame, ground_truth, bins=32, limits=[0.0, 1.0])","text":"

Function for evaluating a frame against a ground truth frame by comparing their histograms.

Parameters:

  • frame \u2013
               Input frame [1 x 3 x m x n]  or [3 x m x n] or [1 x m x n] or [m x n].\n
  • ground_truth \u2013
               Ground truth [1 x 3 x m x n] or  [3 x m x n] or [1 x m x n] or  [m x n].\n
  • bins \u2013
               Number of bins.\n
  • limits \u2013
               Limits.\n

Returns:

  • loss ( float ) \u2013

    Loss from evaluation.

Source code in odak/learn/tools/loss.py
def histogram_loss(frame, ground_truth, bins = 32, limits = [0., 1.]):\n    \"\"\"\n    Function for evaluating a frame against a target using histogram.\n\n    Parameters\n    ----------\n    frame            : torch.tensor\n                       Input frame [1 x 3 x m x n]  or [3 x m x n] or [1 x m x n] or [m x n].\n    ground_truth     : torch.tensor\n                       Ground truth [1 x 3 x m x n] or  [3 x m x n] or [1 x m x n] or  [m x n].\n    bins             : int\n                       Number of bins.\n    limits           : list\n                       Limits.\n\n    Returns\n    -------\n    loss             : float\n                       Loss from evaluation.\n    \"\"\"\n    if len(frame.shape) == 2:\n        frame = frame.unsqueeze(0).unsqueeze(0)\n    elif len(frame.shape) == 3:\n        frame = frame.unsqueeze(0)\n\n    if len(ground_truth.shape) == 2:\n        ground_truth = ground_truth.unsqueeze(0).unsqueeze(0)\n    elif len(ground_truth.shape) == 3:\n        ground_truth = ground_truth.unsqueeze(0)\n\n    histogram_frame = torch.zeros(frame.shape[1], bins).to(frame.device)\n    histogram_ground_truth = torch.zeros(ground_truth.shape[1], bins).to(frame.device)\n\n    l2 = torch.nn.MSELoss()\n\n    for i in range(frame.shape[1]):\n        histogram_frame[i] = torch.histc(frame[:, i].flatten(), bins=bins, min=limits[0], max=limits[1])\n        histogram_ground_truth[i] = torch.histc(ground_truth[:, i].flatten(), bins=bins, min=limits[0], max=limits[1])\n\n    loss = l2(histogram_frame, histogram_ground_truth)\n\n    return loss\n
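A minimal usage sketch for histogram_loss, assuming it is importable from odak.learn.tools as indexed on this page; random frames stand in for a real estimate and target.

import torch
from odak.learn.tools import histogram_loss

estimate = torch.rand(1, 3, 64, 64)
target = torch.rand(1, 3, 64, 64)
loss = histogram_loss(estimate, target, bins = 32, limits = [0., 1.])
print(loss)  # mean squared error between the per-channel histograms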
"},{"location":"odak/learn_tools/#odak.learn.tools.load_image","title":"load_image(fn, normalizeby=0.0, torch_style=False)","text":"

Definition to load an image from a given location as a torch tensor.

Parameters:

  • fn \u2013
           Filename.\n
  • normalizeby \u2013
           Value to normalize images with. Default value of zero will lead to no normalization.\n
  • torch_style \u2013
           If set True, it will load an image mxnx3 as 3xmxn.\n

Returns:

  • image ( tensor ) \u2013

    Image loaded as a torch tensor.

Source code in odak/learn/tools/file.py
def load_image(fn, normalizeby = 0., torch_style = False):\n    \"\"\"\n    Definition to load an image from a given location as a torch tensor.\n\n    Parameters\n    ----------\n    fn           : str\n                   Filename.\n    normalizeby  : float or optional\n                   Value to to normalize images with. Default value of zero will lead to no normalization.\n    torch_style  : bool or optional\n                   If set True, it will load an image mxnx3 as 3xmxn.\n\n    Returns\n    -------\n    image        :  ndarray\n                    Image loaded as a Numpy array.\n\n    \"\"\"\n    image = odak.tools.load_image(fn, normalizeby = normalizeby, torch_style = torch_style)\n    image = torch.from_numpy(image).float()\n    return image\n
"},{"location":"odak/learn_tools/#odak.learn.tools.michelson_contrast","title":"michelson_contrast(image, roi_high, roi_low)","text":"

A function to calculate the Michelson contrast ratio between given regions of interest of an image.

Parameters:

  • image \u2013
            Image to be tested [1 x 3 x m x n] or [3 x m x n] or [m x n].\n
  • roi_high \u2013
            Corner locations of the roi for high intensity area [m_start, m_end, n_start, n_end].\n
  • roi_low \u2013
            Corner locations of the roi for low intensity area [m_start, m_end, n_start, n_end].\n

Returns:

  • result ( tensor ) \u2013

    Michelson contrast for the given regions. [1] or [3] depending on input image.

Source code in odak/learn/tools/loss.py
def michelson_contrast(image, roi_high, roi_low):\n    \"\"\"\n    A function to calculate michelson contrast ratio of given region of interests of the image.\n\n    Parameters\n    ----------\n    image         : torch.tensor\n                    Image to be tested [1 x 3 x m x n] or [3 x m x n] or [m x n].\n    roi_high      : torch.tensor\n                    Corner locations of the roi for high intensity area [m_start, m_end, n_start, n_end].\n    roi_low       : torch.tensor\n                    Corner locations of the roi for low intensity area [m_start, m_end, n_start, n_end].\n\n    Returns\n    -------\n    result        : torch.tensor\n                    Michelson contrast for the given regions. [1] or [3] depending on input image.\n    \"\"\"\n    if len(image.shape) == 2:\n        image = image.unsqueeze(0)\n    if len(image.shape) == 3:\n        image = image.unsqueeze(0)\n    region_low = image[:, :, roi_low[0]:roi_low[1], roi_low[2]:roi_low[3]]\n    region_high = image[:, :, roi_high[0]:roi_high[1], roi_high[2]:roi_high[3]]\n    high = torch.mean(region_high, dim = (2, 3))\n    low = torch.mean(region_low, dim = (2, 3))\n    result = (high - low) / (high + low)\n    return result.squeeze(0)\n
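A minimal usage sketch for michelson_contrast, assuming it is importable from odak.learn.tools as indexed on this page; the synthetic half-bright image and regions of interest are illustrative only.

import torch
from odak.learn.tools import michelson_contrast

image = torch.zeros(1, 3, 100, 100)
image[:, :, 0:50, :] = 1.                  # bright upper half, dark lower half
roi_high = [0, 50, 0, 100]                 # [m_start, m_end, n_start, n_end]
roi_low = [50, 100, 0, 100]
print(michelson_contrast(image, roi_high, roi_low))  # tensor([1., 1., 1.])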
"},{"location":"odak/learn_tools/#odak.learn.tools.multi_scale_total_variation_loss","title":"multi_scale_total_variation_loss(frame, levels=3)","text":"

Function for evaluating a frame using a multi-scale total variation approach. Here, multi-scale refers to an image pyramid of the input frame, where at each level the image resolution is half of the previous level.

Parameters:

  • frame \u2013
            Input frame [1 x 3 x m x n] or [3 x m x n] or [m x n].\n
  • levels \u2013
             Number of levels in the image pyramid.\n

Returns:

  • loss ( float ) \u2013

    Loss from evaluation.

Source code in odak/learn/tools/loss.py
def multi_scale_total_variation_loss(frame, levels = 3):\n    \"\"\"\n    Function for evaluating a frame against a target using multi scale total variation approach. Here, multi scale refers to image pyramid of an input frame, where at each level image resolution is half of the previous level.\n\n    Parameters\n    ----------\n    frame         : torch.tensor\n                    Input frame [1 x 3 x m x n] or [3 x m x n] or [m x n].\n    levels        : int\n                    Number of levels to go in the image pyriamid.\n\n    Returns\n    -------\n    loss          : float\n                    Loss from evaluation.\n    \"\"\"\n    if len(frame.shape) == 2:\n        frame = frame.unsqueeze(0)\n    if len(frame.shape) == 3:\n        frame = frame.unsqueeze(0)\n    scale = torch.nn.Upsample(scale_factor = 0.5, mode = 'nearest')\n    level = frame\n    loss = 0\n    for i in range(levels):\n        if i != 0:\n           level = scale(level)\n        loss += total_variation_loss(level) \n    return loss\n
"},{"location":"odak/learn_tools/#odak.learn.tools.quantize","title":"quantize(image_field, bits=8, limits=[0.0, 1.0])","text":"

Definition to quantize an image field to a certain bit level (by default 8 bits, i.e., 0-255 levels).

Parameters:

  • image_field (tensor) \u2013
          Input image field between any range.\n
  • bits \u2013
          A value in between one to eight.\n
  • limits \u2013
          The minimum and maximum of the image_field variable.\n

Returns:

  • new_field ( tensor ) \u2013

    Quantized image field.

Source code in odak/learn/tools/matrix.py
def quantize(image_field, bits = 8, limits = [0., 1.]):\n    \"\"\" \n    Definition to quantize a image field (0-255, 8 bit) to a certain bits level.\n\n    Parameters\n    ----------\n    image_field : torch.tensor\n                  Input image field between any range.\n    bits        : int\n                  A value in between one to eight.\n    limits      : list\n                  The minimum and maximum of the image_field variable.\n\n    Returns\n    ----------\n    new_field   : torch.tensor\n                  Quantized image field.\n    \"\"\"\n    normalized_field = (image_field - limits[0]) / (limits[1] - limits[0])\n    divider = 2 ** bits\n    new_field = normalized_field * divider\n    new_field = new_field.int()\n    return new_field\n
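A minimal usage sketch for quantize, assuming it is importable from odak.learn.tools as indexed on this page; random data stands in for a real image field.

import torch
from odak.learn.tools import quantize

image_field = torch.rand(100, 100)                              # values within [0., 1.]
quantized = quantize(image_field, bits = 4, limits = [0., 1.])
print(quantized.dtype, quantized.max())  # integer levels between 0 and 2 ** 4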
"},{"location":"odak/learn_tools/#odak.learn.tools.radial_basis_function","title":"radial_basis_function(value, epsilon=0.5)","text":"

Function to pass a value through a radial basis function with a Gaussian profile.

Parameters:

  • value \u2013
               Value(s) to pass to the radial basis function.\n
  • epsilon \u2013
               Epsilon used in the Gaussian radial basis function (e.g., y = e^(-(epsilon * value)^2)).\n

Returns:

  • output ( tensor ) \u2013

    Output values.

Source code in odak/learn/tools/loss.py
def radial_basis_function(value, epsilon = 0.5):\n    \"\"\"\n    Function to pass a value into radial basis function with Gaussian description.\n\n    Parameters\n    ----------\n    value            : torch.tensor\n                       Value(s) to pass to the radial basis function. \n    epsilon          : float\n                       Epsilon used in the Gaussian radial basis function (e.g., y=e^(-(epsilon x value)^2).\n\n    Returns\n    -------\n    output           : torch.tensor\n                       Output values.\n    \"\"\"\n    output = torch.exp((-(epsilon * value)**2))\n    return output\n
"},{"location":"odak/learn_tools/#odak.learn.tools.resize","title":"resize(image, multiplier=0.5, mode='nearest')","text":"

Definition to resize an image.

Parameters:

  • image \u2013
            Image with MxNx3 resolution.\n
  • multiplier \u2013
            Multiplier used in resizing operation (e.g., 0.5 is half size in one axis).\n
  • mode \u2013
            Mode to be used in scaling, nearest, bilinear, etc.\n

Returns:

  • new_image ( tensor ) \u2013

    Resized image.

Source code in odak/learn/tools/file.py
def resize(image, multiplier = 0.5, mode = 'nearest'):\n    \"\"\"\n    Definition to resize an image.\n\n    Parameters\n    ----------\n    image         : torch.tensor\n                    Image with MxNx3 resolution.\n    multiplier    : float\n                    Multiplier used in resizing operation (e.g., 0.5 is half size in one axis).\n    mode          : str\n                    Mode to be used in scaling, nearest, bilinear, etc.\n\n    Returns\n    -------\n    new_image     : torch.tensor\n                    Resized image.\n\n    \"\"\"\n    scale = torch.nn.Upsample(scale_factor = multiplier, mode = mode)\n    new_image = torch.zeros((int(image.shape[0] * multiplier), int(image.shape[1] * multiplier), 3)).to(image.device)\n    for i in range(3):\n        cache = image[:,:,i].unsqueeze(0)\n        cache = cache.unsqueeze(0)\n        new_cache = scale(cache).unsqueeze(0)\n        new_image[:,:,i] = new_cache.unsqueeze(0)\n    return new_image\n
"},{"location":"odak/learn_tools/#odak.learn.tools.rotate_points","title":"rotate_points(point, angles=torch.tensor([[0, 0, 0]]), mode='XYZ', origin=torch.tensor([[0, 0, 0]]), offset=torch.tensor([[0, 0, 0]]))","text":"

Definition to rotate a given point. Note that the rotation is applied with respect to the given origin, which defaults to 0,0,0.

Parameters:

  • point \u2013
           A point with size of [3] or [1, 3] or [m, 3].\n
  • angles \u2013
           Rotation angles in degrees.\n
  • mode \u2013
           Rotation mode determines ordering of the rotations at each axis.\n       There are XYZ,YXZ,ZXY and ZYX modes.\n
  • origin \u2013
           Reference point for a rotation.\n       Expected size is [3] or [1, 3].\n
  • offset \u2013
           Shift with the given offset.\n       Expected size is [3] or [1, 3] or [m, 3].\n

Returns:

  • result ( tensor ) \u2013

    Result of the rotation [1 x 3] or [m x 3].

  • rotx ( tensor ) \u2013

    Rotation matrix along X axis [3 x 3].

  • roty ( tensor ) \u2013

    Rotation matrix along Y axis [3 x 3].

  • rotz ( tensor ) \u2013

    Rotation matrix along Z axis [3 x 3].

Source code in odak/learn/tools/transformation.py
def rotate_points(\n                 point,\n                 angles = torch.tensor([[0, 0, 0]]), \n                 mode='XYZ', \n                 origin = torch.tensor([[0, 0, 0]]), \n                 offset = torch.tensor([[0, 0, 0]])\n                ):\n    \"\"\"\n    Definition to rotate a given point. Note that rotation is always with respect to 0,0,0.\n\n    Parameters\n    ----------\n    point        : torch.tensor\n                   A point with size of [3] or [1, 3] or [m, 3].\n    angles       : torch.tensor\n                   Rotation angles in degrees. \n    mode         : str\n                   Rotation mode determines ordering of the rotations at each axis.\n                   There are XYZ,YXZ,ZXY and ZYX modes.\n    origin       : torch.tensor\n                   Reference point for a rotation.\n                   Expected size is [3] or [1, 3].\n    offset       : torch.tensor\n                   Shift with the given offset.\n                   Expected size is [3] or [1, 3] or [m, 3].\n\n    Returns\n    ----------\n    result       : torch.tensor\n                   Result of the rotation [1 x 3] or [m x 3].\n    rotx         : torch.tensor\n                   Rotation matrix along X axis [3 x 3].\n    roty         : torch.tensor\n                   Rotation matrix along Y axis [3 x 3].\n    rotz         : torch.tensor\n                   Rotation matrix along Z axis [3 x 3].\n    \"\"\"\n    origin = origin.to(point.device)\n    offset = offset.to(point.device)\n    if len(point.shape) == 1:\n        point = point.unsqueeze(0)\n    if len(angles.shape) == 1:\n        angles = angles.unsqueeze(0)\n    rotx = rotmatx(angles[:, 0])\n    roty = rotmaty(angles[:, 1])\n    rotz = rotmatz(angles[:, 2])\n    new_point = (point - origin).T\n    if mode == 'XYZ':\n        result = torch.mm(rotz, torch.mm(roty, torch.mm(rotx, new_point))).T\n    elif mode == 'XZY':\n        result = torch.mm(roty, torch.mm(rotz, torch.mm(rotx, new_point))).T\n    elif mode == 'YXZ':\n        result = torch.mm(rotz, torch.mm(rotx, torch.mm(roty, new_point))).T\n    elif mode == 'ZXY':\n        result = torch.mm(roty, torch.mm(rotx, torch.mm(rotz, new_point))).T\n    elif mode == 'ZYX':\n        result = torch.mm(rotx, torch.mm(roty, torch.mm(rotz, new_point))).T\n    result += origin\n    result += offset\n    return result, rotx, roty, rotz\n
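A minimal usage sketch for rotate_points, assuming it is importable from odak.learn.tools as indexed on this page; the points and rotation angle are arbitrary examples.

import torch
from odak.learn.tools import rotate_points

points = torch.tensor([[1., 0., 0.], [0., 1., 0.]])
rotated, rotx, roty, rotz = rotate_points(points, angles = torch.tensor([[0., 0., 90.]]))
print(rotated)  # approximately [[0., 1., 0.], [-1., 0., 0.]] after a 90 degree rotation about Z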
"},{"location":"odak/learn_tools/#odak.learn.tools.rotmatx","title":"rotmatx(angle)","text":"

Definition to generate a rotation matrix along X axis.

Parameters:

  • angle \u2013
           Rotation angles in degrees.\n

Returns:

  • rotx ( tensor ) \u2013

    Rotation matrix along X axis.

Source code in odak/learn/tools/transformation.py
def rotmatx(angle):\n    \"\"\"\n    Definition to generate a rotation matrix along X axis.\n\n    Parameters\n    ----------\n    angle        : torch.tensor\n                   Rotation angles in degrees.\n\n    Returns\n    ----------\n    rotx         : torch.tensor\n                   Rotation matrix along X axis.\n    \"\"\"\n    angle = torch.deg2rad(angle)\n    one = torch.ones(1, device = angle.device)\n    zero = torch.zeros(1, device = angle.device)\n    rotx = torch.stack([\n                        torch.stack([ one,              zero,              zero]),\n                        torch.stack([zero,  torch.cos(angle), -torch.sin(angle)]),\n                        torch.stack([zero,  torch.sin(angle),  torch.cos(angle)])\n                       ]).reshape(3, 3)\n    return rotx\n
"},{"location":"odak/learn_tools/#odak.learn.tools.rotmaty","title":"rotmaty(angle)","text":"

Definition to generate a rotation matrix along Y axis.

Parameters:

  • angle \u2013
           Rotation angles in degrees.\n

Returns:

  • roty ( tensor ) \u2013

    Rotation matrix along Y axis.

Source code in odak/learn/tools/transformation.py
def rotmaty(angle):\n    \"\"\"\n    Definition to generate a rotation matrix along Y axis.\n\n    Parameters\n    ----------\n    angle        : torch.tensor\n                   Rotation angles in degrees.\n\n    Returns\n    ----------\n    roty         : torch.tensor\n                   Rotation matrix along Y axis.\n    \"\"\"\n    angle = torch.deg2rad(angle)\n    one = torch.ones(1, device = angle.device)\n    zero = torch.zeros(1, device = angle.device)\n    roty = torch.stack([\n                        torch.stack([ torch.cos(angle), zero, torch.sin(angle)]),\n                        torch.stack([             zero,  one,             zero]),\n                        torch.stack([-torch.sin(angle), zero, torch.cos(angle)])\n                       ]).reshape(3, 3)\n    return roty\n
"},{"location":"odak/learn_tools/#odak.learn.tools.rotmatz","title":"rotmatz(angle)","text":"

Definition to generate a rotation matrix along Z axis.

Parameters:

  • angle \u2013
           Rotation angles in degrees.\n

Returns:

  • rotz ( tensor ) \u2013

    Rotation matrix along Z axis.

Source code in odak/learn/tools/transformation.py
def rotmatz(angle):\n    \"\"\"\n    Definition to generate a rotation matrix along Z axis.\n\n    Parameters\n    ----------\n    angle        : torch.tensor\n                   Rotation angles in degrees.\n\n    Returns\n    ----------\n    rotz         : torch.tensor\n                   Rotation matrix along Z axis.\n    \"\"\"\n    angle = torch.deg2rad(angle)\n    one = torch.ones(1, device = angle.device)\n    zero = torch.zeros(1, device = angle.device)\n    rotz = torch.stack([\n                        torch.stack([torch.cos(angle), -torch.sin(angle), zero]),\n                        torch.stack([torch.sin(angle),  torch.cos(angle), zero]),\n                        torch.stack([            zero,              zero,  one])\n                       ]).reshape(3,3)\n    return rotz\n
"},{"location":"odak/learn_tools/#odak.learn.tools.same_side","title":"same_side(p1, p2, a, b)","text":"

Definition to determine which side of a line a point is on, relative to a reference point. See http://www.blackpawn.com/texts/pointinpoly/ for more. If p1 and p2 are on the same side, this definition returns True.

Parameters:

  • p1 \u2013
          Point(s) to check.\n
  • p2 \u2013
          This is the point to check against.\n
  • a \u2013
          First point that forms the line.\n
  • b \u2013
          Second point that forms the line.\n
Source code in odak/learn/tools/vector.py
def same_side(p1, p2, a, b):\n    \"\"\"\n    Definition to figure which side a point is on with respect to a line and a point. See http://www.blackpawn.com/texts/pointinpoly/ for more. If p1 and p2 are on the sameside, this definition returns True.\n\n    Parameters\n    ----------\n    p1          : list\n                  Point(s) to check.\n    p2          : list\n                  This is the point check against.\n    a           : list\n                  First point that forms the line.\n    b           : list\n                  Second point that forms the line.\n    \"\"\"\n    ba = torch.subtract(b, a)\n    p1a = torch.subtract(p1, a)\n    p2a = torch.subtract(p2, a)\n    cp1 = torch.cross(ba, p1a)\n    cp2 = torch.cross(ba, p2a)\n    test = torch.dot(cp1, cp2)\n    if len(p1.shape) > 1:\n        return test >= 0\n    if test >= 0:\n        return True\n    return False\n
"},{"location":"odak/learn_tools/#odak.learn.tools.save_image","title":"save_image(fn, img, cmin=0, cmax=255, color_depth=8)","text":"

Definition to save a torch tensor as an image.

Parameters:

  • fn \u2013
           Filename.\n
  • img \u2013
           A numpy array with NxMx3 or NxMx1 shapes.\n
  • cmin \u2013
           Minimum value that will be interpreted as 0 level in the final image.\n
  • cmax \u2013
           Maximum value that will be interpreted as 255 level in the final image.\n
  • color_depth \u2013
           Color depth of an image. Default is eight.\n

Returns:

  • bool ( bool ) \u2013

    True if successful.

Source code in odak/learn/tools/file.py
def save_image(fn, img, cmin = 0, cmax = 255, color_depth = 8):\n    \"\"\"\n    Definition to save a torch tensor as an image.\n\n    Parameters\n    ----------\n    fn           : str\n                   Filename.\n    img          : ndarray\n                   A numpy array with NxMx3 or NxMx1 shapes.\n    cmin         : int\n                   Minimum value that will be interpreted as 0 level in the final image.\n    cmax         : int\n                   Maximum value that will be interpreted as 255 level in the final image.\n    color_depth  : int\n                   Color depth of an image. Default is eight.\n\n\n    Returns\n    ----------\n    bool         :  bool\n                    True if successful.\n\n    \"\"\"\n    if len(img.shape) ==  4:\n        img = img.squeeze(0)\n    if len(img.shape) > 2 and torch.argmin(torch.tensor(img.shape)) == 0:\n        new_img = torch.zeros(img.shape[1], img.shape[2], img.shape[0]).to(img.device)\n        for i in range(img.shape[0]):\n            new_img[:, :, i] = img[i].detach().clone()\n        img = new_img.detach().clone()\n    img = img.cpu().detach().numpy()\n    return odak.tools.save_image(fn, img, cmin = cmin, cmax = cmax, color_depth = color_depth)\n
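A minimal usage sketch, assuming a channels-first tensor with values in the [0, 1] range and an illustrative output filename:

import torch
from odak.learn.tools import save_image

image = torch.rand(3, 256, 256) # channels-first tensor with values in [0, 1]
save_image('output.png', image, cmin = 0., cmax = 1.) # map the [0., 1.] input range to the 8-bit output range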
"},{"location":"odak/learn_tools/#odak.learn.tools.save_torch_tensor","title":"save_torch_tensor(fn, tensor)","text":"

Definition to save a torch tensor.

Parameters:

  • fn \u2013
           Filename.\n
  • tensor \u2013
           Torch tensor to be saved.\n
Source code in odak/learn/tools/file.py
def save_torch_tensor(fn, tensor):\n    \"\"\"\n    Definition to save a torch tensor.\n\n\n    Parameters\n    ----------\n    fn           : str\n                   Filename.\n    tensor       : torch.tensor\n                   Torch tensor to be saved.\n    \"\"\" \n    torch.save(tensor, expanduser(fn))\n
"},{"location":"odak/learn_tools/#odak.learn.tools.tilt_towards","title":"tilt_towards(location, lookat)","text":"

Definition to tilt the surface normal of a plane towards a point.

Parameters:

  • location \u2013
           Center of the plane to be tilted.\n
  • lookat \u2013
           Tilt towards this point.\n

Returns:

  • angles ( list ) \u2013

    Rotation angles in degrees.

Source code in odak/learn/tools/transformation.py
def tilt_towards(location, lookat):\n    \"\"\"\n    Definition to tilt surface normal of a plane towards a point.\n\n    Parameters\n    ----------\n    location     : list\n                   Center of the plane to be tilted.\n    lookat       : list\n                   Tilt towards this point.\n\n    Returns\n    ----------\n    angles       : list\n                   Rotation angles in degrees.\n    \"\"\"\n    dx = location[0] - lookat[0]\n    dy = location[1] - lookat[1]\n    dz = location[2] - lookat[2]\n    dist = torch.sqrt(torch.tensor(dx ** 2 + dy ** 2 + dz ** 2))\n    phi = torch.atan2(torch.tensor(dy), torch.tensor(dx))\n    theta = torch.arccos(dz / dist)\n    angles = [0, float(torch.rad2deg(theta)), float(torch.rad2deg(phi))]\n    return angles\n
"},{"location":"odak/learn_tools/#odak.learn.tools.torch_load","title":"torch_load(fn)","text":"

Definition to load a torch file (*.pt).

Parameters:

  • fn \u2013
           Filename.\n

Returns:

  • data ( any ) \u2013

    See torch.load() for more.

Source code in odak/learn/tools/file.py
def torch_load(fn):\n    \"\"\"\n    Definition to load a torch files (*.pt).\n\n    Parameters\n    ----------\n    fn           : str\n                   Filename.\n\n    Returns\n    -------\n    data         : any\n                   See torch.load() for more.\n    \"\"\"  \n    data = torch.load(expanduser(fn))\n    return data\n
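A minimal round-trip sketch combining save_torch_tensor and torch_load; the filename is illustrative:

import torch
from odak.learn.tools import save_torch_tensor, torch_load

weights = torch.randn(10, 10)
save_torch_tensor('weights.pt', weights) # stored with torch.save under the hood
restored = torch_load('weights.pt') # loaded back with torch.load
assert torch.allclose(weights, restored)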
"},{"location":"odak/learn_tools/#odak.learn.tools.total_variation_loss","title":"total_variation_loss(frame)","text":"

Function for evaluating a frame using a total variation approach.

Parameters:

  • frame \u2013
            Input frame [1 x 3 x m x n] or [3 x m x n] or [m x n].\n

Returns:

  • loss ( float ) \u2013

    Loss from evaluation.

Source code in odak/learn/tools/loss.py
def total_variation_loss(frame):\n    \"\"\"\n    Function for evaluating a frame against a target using total variation approach.\n\n    Parameters\n    ----------\n    frame         : torch.tensor\n                    Input frame [1 x 3 x m x n] or [3 x m x n] or [m x n].\n\n    Returns\n    -------\n    loss          : float\n                    Loss from evaluation.\n    \"\"\"\n    if len(frame.shape) == 2:\n        frame = frame.unsqueeze(0)\n    if len(frame.shape) == 3:\n        frame = frame.unsqueeze(0)\n    diff_x = frame[:, :, :, 1:] - frame[:, :, :, :-1]\n    diff_y = frame[:, :, 1:, :] - frame[:, :, :-1, :]\n    pixel_count = frame.shape[0] * frame.shape[1] * frame.shape[2] * frame.shape[3]\n    loss = ((diff_x ** 2).sum() + (diff_y ** 2).sum()) / pixel_count\n    return loss\n
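A minimal usage sketch with a random test frame; any of the documented input shapes would work:

import torch
from odak.learn.tools import total_variation_loss

frame = torch.rand(1, 3, 64, 64) # [3 x m x n] or [m x n] inputs are also accepted
loss = total_variation_loss(frame) # mean of squared horizontal and vertical differences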
"},{"location":"odak/learn_tools/#odak.learn.tools.weber_contrast","title":"weber_contrast(image, roi_high, roi_low)","text":"

A function to calculate the Weber contrast ratio of given regions of interest of the image.

Parameters:

  • image \u2013
            Image to be tested [1 x 3 x m x n] or [3 x m x n] or [1 x m x n] or [m x n].\n
  • roi_high \u2013
            Corner locations of the roi for high intensity area [m_start, m_end, n_start, n_end].\n
  • roi_low \u2013
            Corner locations of the roi for low intensity area [m_start, m_end, n_start, n_end].\n

Returns:

  • result ( tensor ) \u2013

    Weber contrast for given regions. [1] or [3] depending on input image.

Source code in odak/learn/tools/loss.py
def weber_contrast(image, roi_high, roi_low):\n    \"\"\"\n    A function to calculate weber contrast ratio of given region of interests of the image.\n\n    Parameters\n    ----------\n    image         : torch.tensor\n                    Image to be tested [1 x 3 x m x n] or [3 x m x n] or [1 x m x n] or [m x n].\n    roi_high      : torch.tensor\n                    Corner locations of the roi for high intensity area [m_start, m_end, n_start, n_end].\n    roi_low       : torch.tensor\n                    Corner locations of the roi for low intensity area [m_start, m_end, n_start, n_end].\n\n    Returns\n    -------\n    result        : torch.tensor\n                    Weber contrast for given regions. [1] or [3] depending on input image.\n    \"\"\"\n    if len(image.shape) == 2:\n        image = image.unsqueeze(0)\n    if len(image.shape) == 3:\n        image = image.unsqueeze(0)\n    region_low = image[:, :, roi_low[0]:roi_low[1], roi_low[2]:roi_low[3]]\n    region_high = image[:, :, roi_high[0]:roi_high[1], roi_high[2]:roi_high[3]]\n    high = torch.mean(region_high, dim = (2, 3))\n    low = torch.mean(region_low, dim = (2, 3))\n    result = (high - low) / low\n    return result.squeeze(0)\n
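A minimal usage sketch; the region corners are illustrative, and plain lists are assumed to be acceptable for the corner locations:

import torch
from odak.learn.tools import weber_contrast

image = torch.rand(1, 3, 128, 128)
roi_high = [10, 40, 10, 40] # [m_start, m_end, n_start, n_end] of the bright region
roi_low = [80, 110, 80, 110] # [m_start, m_end, n_start, n_end] of the dark region
contrast = weber_contrast(image, roi_high, roi_low) # returns three values for an RGB input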
"},{"location":"odak/learn_tools/#odak.learn.tools.wrapped_mean_squared_error","title":"wrapped_mean_squared_error(image, ground_truth, reduction='mean')","text":"

A function to calculate the wrapped mean squared error between predicted and target angles.

Parameters:

  • image \u2013
            Image to be tested [1 x 3 x m x n]  or [3 x m x n] or [1 x m x n] or [m x n].\n
  • ground_truth \u2013
            Ground truth to be tested [1 x 3 x m x n]  or [3 x m x n] or [1 x m x n] or [m x n].\n
  • reduction \u2013
            Specifies the reduction to apply to the output: 'mean' (default) or 'sum'.\n

Returns:

  • wmse ( tensor ) \u2013

    The calculated wrapped mean squared error.

Source code in odak/learn/tools/loss.py
def wrapped_mean_squared_error(image, ground_truth, reduction = 'mean'):\n    \"\"\"\n    A function to calculate the wrapped mean squared error between predicted and target angles.\n\n    Parameters\n    ----------\n    image         : torch.tensor\n                    Image to be tested [1 x 3 x m x n]  or [3 x m x n] or [1 x m x n] or [m x n].\n    ground_truth  : torch.tensor\n                    Ground truth to be tested [1 x 3 x m x n]  or [3 x m x n] or [1 x m x n] or [m x n].\n    reduction     : str\n                    Specifies the reduction to apply to the output: 'mean' (default) or 'sum'.\n\n    Returns\n    -------\n    wmse        : torch.tensor\n                  The calculated wrapped mean squared error. \n    \"\"\"\n    sin_diff = torch.sin(image) - torch.sin(ground_truth)\n    cos_diff = torch.cos(image) - torch.cos(ground_truth)\n    loss = (sin_diff**2 + cos_diff**2)\n\n    if reduction == 'mean':\n        return loss.mean()\n    elif reduction == 'sum':\n        return loss.sum()\n    else:\n        raise ValueError(\"Invalid reduction type. Choose 'mean' or 'sum'.\")\n
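A minimal usage sketch with two random phase-like maps in radians:

import torch
from odak.learn.tools import wrapped_mean_squared_error

predicted = torch.rand(1, 1, 64, 64) * 2. * torch.pi
target = torch.rand(1, 1, 64, 64) * 2. * torch.pi
wmse = wrapped_mean_squared_error(predicted, target, reduction = 'mean')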
"},{"location":"odak/learn_tools/#odak.learn.tools.zero_pad","title":"zero_pad(field, size=None, method='center')","text":"

Definition to zero pad an MxN array to a 2Mx2N array.

Parameters:

  • field \u2013
                Input field MxN or KxJxMxN or KxMxNxJ array.\n
  • size \u2013
                Size to be zeropadded (e.g., [m, n], last two dimensions only).\n
  • method \u2013
                Zeropad either by placing the content to center or to the left.\n

Returns:

  • field_zero_padded ( ndarray ) \u2013

    Zeropadded version of the input field.

Source code in odak/learn/tools/matrix.py
def zero_pad(field, size = None, method = 'center'):\n    \"\"\"\n    Definition to zero pad a MxN array to 2Mx2N array.\n\n    Parameters\n    ----------\n    field             : ndarray\n                        Input field MxN or KxJxMxN or KxMxNxJ array.\n    size              : list\n                        Size to be zeropadded (e.g., [m, n], last two dimensions only).\n    method            : str\n                        Zeropad either by placing the content to center or to the left.\n\n    Returns\n    ----------\n    field_zero_padded : ndarray\n                        Zeropadded version of the input field.\n    \"\"\"\n    orig_resolution = field.shape\n    if len(field.shape) < 3:\n        field = field.unsqueeze(0)\n    if len(field.shape) < 4:\n        field = field.unsqueeze(0)\n    permute_flag = False\n    if field.shape[-1] < 5:\n        permute_flag = True\n        field = field.permute(0, 3, 1, 2)\n    if type(size) == type(None):\n        resolution = [field.shape[0], field.shape[1], 2 * field.shape[-2], 2 * field.shape[-1]]\n    else:\n        resolution = [field.shape[0], field.shape[1], size[0], size[1]]\n    field_zero_padded = torch.zeros(resolution, device = field.device, dtype = field.dtype)\n    if method == 'center':\n       start = [\n                resolution[-2] // 2 - field.shape[-2] // 2,\n                resolution[-1] // 2 - field.shape[-1] // 2\n               ]\n       field_zero_padded[\n                         :, :,\n                         start[0] : start[0] + field.shape[-2],\n                         start[1] : start[1] + field.shape[-1]\n                         ] = field\n    elif method == 'left':\n       field_zero_padded[\n                         :, :,\n                         0: field.shape[-2],\n                         0: field.shape[-1]\n                        ] = field\n    if permute_flag == True:\n        field_zero_padded = field_zero_padded.permute(0, 2, 3, 1)\n    if len(orig_resolution) == 2:\n        field_zero_padded = field_zero_padded.squeeze(0).squeeze(0)\n    if len(orig_resolution) == 3:\n        field_zero_padded = field_zero_padded.squeeze(0)\n    return field_zero_padded\n
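A minimal sketch pairing zero_pad with crop_center to recover the original field; the sizes are illustrative:

import torch
from odak.learn.tools import zero_pad, crop_center

field = torch.rand(256, 256)
padded = zero_pad(field) # [512 x 512] with the content placed at the center
recovered = crop_center(padded) # back to [256 x 256]
assert torch.allclose(field, recovered)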
"},{"location":"odak/learn_tools/#odak.learn.tools.file.load_image","title":"load_image(fn, normalizeby=0.0, torch_style=False)","text":"

Definition to load an image from a given location as a torch tensor.

Parameters:

  • fn \u2013
           Filename.\n
  • normalizeby \u2013
           Value to normalize images with. The default value of zero will lead to no normalization.\n
  • torch_style \u2013
           If set True, it will load an image mxnx3 as 3xmxn.\n

Returns:

  • image ( tensor ) \u2013

    Image loaded as a torch tensor.

Source code in odak/learn/tools/file.py
def load_image(fn, normalizeby = 0., torch_style = False):\n    \"\"\"\n    Definition to load an image from a given location as a torch tensor.\n\n    Parameters\n    ----------\n    fn           : str\n                   Filename.\n    normalizeby  : float or optional\n                   Value to to normalize images with. Default value of zero will lead to no normalization.\n    torch_style  : bool or optional\n                   If set True, it will load an image mxnx3 as 3xmxn.\n\n    Returns\n    -------\n    image        :  ndarray\n                    Image loaded as a Numpy array.\n\n    \"\"\"\n    image = odak.tools.load_image(fn, normalizeby = normalizeby, torch_style = torch_style)\n    image = torch.from_numpy(image).float()\n    return image\n
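A minimal usage sketch, assuming an 8-bit RGB image at an illustrative path:

from odak.learn.tools.file import load_image

image = load_image('input.png', normalizeby = 255., torch_style = True) # [3 x m x n] tensor scaled to [0, 1]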
"},{"location":"odak/learn_tools/#odak.learn.tools.file.resize","title":"resize(image, multiplier=0.5, mode='nearest')","text":"

Definition to resize an image.

Parameters:

  • image \u2013
            Image with MxNx3 resolution.\n
  • multiplier \u2013
            Multiplier used in resizing operation (e.g., 0.5 is half size in one axis).\n
  • mode \u2013
            Mode to be used in scaling, nearest, bilinear, etc.\n

Returns:

  • new_image ( tensor ) \u2013

    Resized image.

Source code in odak/learn/tools/file.py
def resize(image, multiplier = 0.5, mode = 'nearest'):\n    \"\"\"\n    Definition to resize an image.\n\n    Parameters\n    ----------\n    image         : torch.tensor\n                    Image with MxNx3 resolution.\n    multiplier    : float\n                    Multiplier used in resizing operation (e.g., 0.5 is half size in one axis).\n    mode          : str\n                    Mode to be used in scaling, nearest, bilinear, etc.\n\n    Returns\n    -------\n    new_image     : torch.tensor\n                    Resized image.\n\n    \"\"\"\n    scale = torch.nn.Upsample(scale_factor = multiplier, mode = mode)\n    new_image = torch.zeros((int(image.shape[0] * multiplier), int(image.shape[1] * multiplier), 3)).to(image.device)\n    for i in range(3):\n        cache = image[:,:,i].unsqueeze(0)\n        cache = cache.unsqueeze(0)\n        new_cache = scale(cache).unsqueeze(0)\n        new_image[:,:,i] = new_cache.unsqueeze(0)\n    return new_image\n
"},{"location":"odak/learn_tools/#odak.learn.tools.file.save_image","title":"save_image(fn, img, cmin=0, cmax=255, color_depth=8)","text":"

Definition to save a torch tensor as an image.

Parameters:

  • fn \u2013
           Filename.\n
  • img \u2013
           A torch tensor with NxMx3 or NxMx1 shape.\n
  • cmin \u2013
           Minimum value that will be interpreted as 0 level in the final image.\n
  • cmax \u2013
           Maximum value that will be interpreted as 255 level in the final image.\n
  • color_depth \u2013
           Color depth of an image. Default is eight.\n

Returns:

  • bool ( bool ) \u2013

    True if successful.

Source code in odak/learn/tools/file.py
def save_image(fn, img, cmin = 0, cmax = 255, color_depth = 8):\n    \"\"\"\n    Definition to save a torch tensor as an image.\n\n    Parameters\n    ----------\n    fn           : str\n                   Filename.\n    img          : ndarray\n                   A numpy array with NxMx3 or NxMx1 shapes.\n    cmin         : int\n                   Minimum value that will be interpreted as 0 level in the final image.\n    cmax         : int\n                   Maximum value that will be interpreted as 255 level in the final image.\n    color_depth  : int\n                   Color depth of an image. Default is eight.\n\n\n    Returns\n    ----------\n    bool         :  bool\n                    True if successful.\n\n    \"\"\"\n    if len(img.shape) ==  4:\n        img = img.squeeze(0)\n    if len(img.shape) > 2 and torch.argmin(torch.tensor(img.shape)) == 0:\n        new_img = torch.zeros(img.shape[1], img.shape[2], img.shape[0]).to(img.device)\n        for i in range(img.shape[0]):\n            new_img[:, :, i] = img[i].detach().clone()\n        img = new_img.detach().clone()\n    img = img.cpu().detach().numpy()\n    return odak.tools.save_image(fn, img, cmin = cmin, cmax = cmax, color_depth = color_depth)\n
"},{"location":"odak/learn_tools/#odak.learn.tools.file.save_torch_tensor","title":"save_torch_tensor(fn, tensor)","text":"

Definition to save a torch tensor.

Parameters:

  • fn \u2013
           Filename.\n
  • tensor \u2013
           Torch tensor to be saved.\n
Source code in odak/learn/tools/file.py
def save_torch_tensor(fn, tensor):\n    \"\"\"\n    Definition to save a torch tensor.\n\n\n    Parameters\n    ----------\n    fn           : str\n                   Filename.\n    tensor       : torch.tensor\n                   Torch tensor to be saved.\n    \"\"\" \n    torch.save(tensor, expanduser(fn))\n
"},{"location":"odak/learn_tools/#odak.learn.tools.file.torch_load","title":"torch_load(fn)","text":"

Definition to load a torch file (*.pt).

Parameters:

  • fn \u2013
           Filename.\n

Returns:

  • data ( any ) \u2013

    See torch.load() for more.

Source code in odak/learn/tools/file.py
def torch_load(fn):\n    \"\"\"\n    Definition to load a torch files (*.pt).\n\n    Parameters\n    ----------\n    fn           : str\n                   Filename.\n\n    Returns\n    -------\n    data         : any\n                   See torch.load() for more.\n    \"\"\"  \n    data = torch.load(expanduser(fn))\n    return data\n
"},{"location":"odak/learn_tools/#odak.learn.tools.loss.histogram_loss","title":"histogram_loss(frame, ground_truth, bins=32, limits=[0.0, 1.0])","text":"

Function for evaluating a frame against a target using histograms.

Parameters:

  • frame \u2013
               Input frame [1 x 3 x m x n]  or [3 x m x n] or [1 x m x n] or [m x n].\n
  • ground_truth \u2013
               Ground truth [1 x 3 x m x n] or  [3 x m x n] or [1 x m x n] or  [m x n].\n
  • bins \u2013
               Number of bins.\n
  • limits \u2013
               Limits.\n

Returns:

  • loss ( float ) \u2013

    Loss from evaluation.

Source code in odak/learn/tools/loss.py
def histogram_loss(frame, ground_truth, bins = 32, limits = [0., 1.]):\n    \"\"\"\n    Function for evaluating a frame against a target using histogram.\n\n    Parameters\n    ----------\n    frame            : torch.tensor\n                       Input frame [1 x 3 x m x n]  or [3 x m x n] or [1 x m x n] or [m x n].\n    ground_truth     : torch.tensor\n                       Ground truth [1 x 3 x m x n] or  [3 x m x n] or [1 x m x n] or  [m x n].\n    bins             : int\n                       Number of bins.\n    limits           : list\n                       Limits.\n\n    Returns\n    -------\n    loss             : float\n                       Loss from evaluation.\n    \"\"\"\n    if len(frame.shape) == 2:\n        frame = frame.unsqueeze(0).unsqueeze(0)\n    elif len(frame.shape) == 3:\n        frame = frame.unsqueeze(0)\n\n    if len(ground_truth.shape) == 2:\n        ground_truth = ground_truth.unsqueeze(0).unsqueeze(0)\n    elif len(ground_truth.shape) == 3:\n        ground_truth = ground_truth.unsqueeze(0)\n\n    histogram_frame = torch.zeros(frame.shape[1], bins).to(frame.device)\n    histogram_ground_truth = torch.zeros(ground_truth.shape[1], bins).to(frame.device)\n\n    l2 = torch.nn.MSELoss()\n\n    for i in range(frame.shape[1]):\n        histogram_frame[i] = torch.histc(frame[:, i].flatten(), bins=bins, min=limits[0], max=limits[1])\n        histogram_ground_truth[i] = torch.histc(ground_truth[:, i].flatten(), bins=bins, min=limits[0], max=limits[1])\n\n    loss = l2(histogram_frame, histogram_ground_truth)\n\n    return loss\n
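A minimal usage sketch comparing two random frames:

import torch
from odak.learn.tools.loss import histogram_loss

frame = torch.rand(1, 3, 64, 64)
ground_truth = torch.rand(1, 3, 64, 64)
loss = histogram_loss(frame, ground_truth, bins = 32, limits = [0., 1.])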
"},{"location":"odak/learn_tools/#odak.learn.tools.loss.michelson_contrast","title":"michelson_contrast(image, roi_high, roi_low)","text":"

A function to calculate the Michelson contrast ratio of given regions of interest of the image.

Parameters:

  • image \u2013
            Image to be tested [1 x 3 x m x n] or [3 x m x n] or [m x n].\n
  • roi_high \u2013
            Corner locations of the roi for high intensity area [m_start, m_end, n_start, n_end].\n
  • roi_low \u2013
            Corner locations of the roi for low intensity area [m_start, m_end, n_start, n_end].\n

Returns:

  • result ( tensor ) \u2013

    Michelson contrast for the given regions. [1] or [3] depending on input image.

Source code in odak/learn/tools/loss.py
def michelson_contrast(image, roi_high, roi_low):\n    \"\"\"\n    A function to calculate michelson contrast ratio of given region of interests of the image.\n\n    Parameters\n    ----------\n    image         : torch.tensor\n                    Image to be tested [1 x 3 x m x n] or [3 x m x n] or [m x n].\n    roi_high      : torch.tensor\n                    Corner locations of the roi for high intensity area [m_start, m_end, n_start, n_end].\n    roi_low       : torch.tensor\n                    Corner locations of the roi for low intensity area [m_start, m_end, n_start, n_end].\n\n    Returns\n    -------\n    result        : torch.tensor\n                    Michelson contrast for the given regions. [1] or [3] depending on input image.\n    \"\"\"\n    if len(image.shape) == 2:\n        image = image.unsqueeze(0)\n    if len(image.shape) == 3:\n        image = image.unsqueeze(0)\n    region_low = image[:, :, roi_low[0]:roi_low[1], roi_low[2]:roi_low[3]]\n    region_high = image[:, :, roi_high[0]:roi_high[1], roi_high[2]:roi_high[3]]\n    high = torch.mean(region_high, dim = (2, 3))\n    low = torch.mean(region_low, dim = (2, 3))\n    result = (high - low) / (high + low)\n    return result.squeeze(0)\n
"},{"location":"odak/learn_tools/#odak.learn.tools.loss.multi_scale_total_variation_loss","title":"multi_scale_total_variation_loss(frame, levels=3)","text":"

Function for evaluating a frame using a multi-scale total variation approach. Here, multi-scale refers to an image pyramid of the input frame, where the image resolution at each level is half that of the previous level.

Parameters:

  • frame \u2013
            Input frame [1 x 3 x m x n] or [3 x m x n] or [m x n].\n
  • levels \u2013
            Number of levels to go down in the image pyramid.\n

Returns:

  • loss ( float ) \u2013

    Loss from evaluation.

Source code in odak/learn/tools/loss.py
def multi_scale_total_variation_loss(frame, levels = 3):\n    \"\"\"\n    Function for evaluating a frame against a target using multi scale total variation approach. Here, multi scale refers to image pyramid of an input frame, where at each level image resolution is half of the previous level.\n\n    Parameters\n    ----------\n    frame         : torch.tensor\n                    Input frame [1 x 3 x m x n] or [3 x m x n] or [m x n].\n    levels        : int\n                    Number of levels to go in the image pyriamid.\n\n    Returns\n    -------\n    loss          : float\n                    Loss from evaluation.\n    \"\"\"\n    if len(frame.shape) == 2:\n        frame = frame.unsqueeze(0)\n    if len(frame.shape) == 3:\n        frame = frame.unsqueeze(0)\n    scale = torch.nn.Upsample(scale_factor = 0.5, mode = 'nearest')\n    level = frame\n    loss = 0\n    for i in range(levels):\n        if i != 0:\n           level = scale(level)\n        loss += total_variation_loss(level) \n    return loss\n
"},{"location":"odak/learn_tools/#odak.learn.tools.loss.radial_basis_function","title":"radial_basis_function(value, epsilon=0.5)","text":"

Function to pass a value through a Gaussian radial basis function.

Parameters:

  • value \u2013
               Value(s) to pass to the radial basis function.\n
  • epsilon \u2013
             Epsilon used in the Gaussian radial basis function (i.e., y = e^(-(epsilon * value)^2)).\n

Returns:

  • output ( tensor ) \u2013

    Output values.

Source code in odak/learn/tools/loss.py
def radial_basis_function(value, epsilon = 0.5):\n    \"\"\"\n    Function to pass a value into radial basis function with Gaussian description.\n\n    Parameters\n    ----------\n    value            : torch.tensor\n                       Value(s) to pass to the radial basis function. \n    epsilon          : float\n                       Epsilon used in the Gaussian radial basis function (e.g., y=e^(-(epsilon x value)^2).\n\n    Returns\n    -------\n    output           : torch.tensor\n                       Output values.\n    \"\"\"\n    output = torch.exp((-(epsilon * value)**2))\n    return output\n
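A minimal usage sketch evaluating the Gaussian radial basis function over a range of values:

import torch
from odak.learn.tools.loss import radial_basis_function

values = torch.linspace(-5., 5., 11)
weights = radial_basis_function(values, epsilon = 0.5) # y = e^(-(epsilon * value)^2)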
"},{"location":"odak/learn_tools/#odak.learn.tools.loss.total_variation_loss","title":"total_variation_loss(frame)","text":"

Function for evaluating a frame using a total variation approach.

Parameters:

  • frame \u2013
            Input frame [1 x 3 x m x n] or [3 x m x n] or [m x n].\n

Returns:

  • loss ( float ) \u2013

    Loss from evaluation.

Source code in odak/learn/tools/loss.py
def total_variation_loss(frame):\n    \"\"\"\n    Function for evaluating a frame against a target using total variation approach.\n\n    Parameters\n    ----------\n    frame         : torch.tensor\n                    Input frame [1 x 3 x m x n] or [3 x m x n] or [m x n].\n\n    Returns\n    -------\n    loss          : float\n                    Loss from evaluation.\n    \"\"\"\n    if len(frame.shape) == 2:\n        frame = frame.unsqueeze(0)\n    if len(frame.shape) == 3:\n        frame = frame.unsqueeze(0)\n    diff_x = frame[:, :, :, 1:] - frame[:, :, :, :-1]\n    diff_y = frame[:, :, 1:, :] - frame[:, :, :-1, :]\n    pixel_count = frame.shape[0] * frame.shape[1] * frame.shape[2] * frame.shape[3]\n    loss = ((diff_x ** 2).sum() + (diff_y ** 2).sum()) / pixel_count\n    return loss\n
"},{"location":"odak/learn_tools/#odak.learn.tools.loss.weber_contrast","title":"weber_contrast(image, roi_high, roi_low)","text":"

A function to calculate the Weber contrast ratio of given regions of interest of the image.

Parameters:

  • image \u2013
            Image to be tested [1 x 3 x m x n] or [3 x m x n] or [1 x m x n] or [m x n].\n
  • roi_high \u2013
            Corner locations of the roi for high intensity area [m_start, m_end, n_start, n_end].\n
  • roi_low \u2013
            Corner locations of the roi for low intensity area [m_start, m_end, n_start, n_end].\n

Returns:

  • result ( tensor ) \u2013

    Weber contrast for given regions. [1] or [3] depending on input image.

Source code in odak/learn/tools/loss.py
def weber_contrast(image, roi_high, roi_low):\n    \"\"\"\n    A function to calculate weber contrast ratio of given region of interests of the image.\n\n    Parameters\n    ----------\n    image         : torch.tensor\n                    Image to be tested [1 x 3 x m x n] or [3 x m x n] or [1 x m x n] or [m x n].\n    roi_high      : torch.tensor\n                    Corner locations of the roi for high intensity area [m_start, m_end, n_start, n_end].\n    roi_low       : torch.tensor\n                    Corner locations of the roi for low intensity area [m_start, m_end, n_start, n_end].\n\n    Returns\n    -------\n    result        : torch.tensor\n                    Weber contrast for given regions. [1] or [3] depending on input image.\n    \"\"\"\n    if len(image.shape) == 2:\n        image = image.unsqueeze(0)\n    if len(image.shape) == 3:\n        image = image.unsqueeze(0)\n    region_low = image[:, :, roi_low[0]:roi_low[1], roi_low[2]:roi_low[3]]\n    region_high = image[:, :, roi_high[0]:roi_high[1], roi_high[2]:roi_high[3]]\n    high = torch.mean(region_high, dim = (2, 3))\n    low = torch.mean(region_low, dim = (2, 3))\n    result = (high - low) / low\n    return result.squeeze(0)\n
"},{"location":"odak/learn_tools/#odak.learn.tools.loss.wrapped_mean_squared_error","title":"wrapped_mean_squared_error(image, ground_truth, reduction='mean')","text":"

A function to calculate the wrapped mean squared error between predicted and target angles.

Parameters:

  • image \u2013
            Image to be tested [1 x 3 x m x n]  or [3 x m x n] or [1 x m x n] or [m x n].\n
  • ground_truth \u2013
            Ground truth to be tested [1 x 3 x m x n]  or [3 x m x n] or [1 x m x n] or [m x n].\n
  • reduction \u2013
            Specifies the reduction to apply to the output: 'mean' (default) or 'sum'.\n

Returns:

  • wmse ( tensor ) \u2013

    The calculated wrapped mean squared error.

Source code in odak/learn/tools/loss.py
def wrapped_mean_squared_error(image, ground_truth, reduction = 'mean'):\n    \"\"\"\n    A function to calculate the wrapped mean squared error between predicted and target angles.\n\n    Parameters\n    ----------\n    image         : torch.tensor\n                    Image to be tested [1 x 3 x m x n]  or [3 x m x n] or [1 x m x n] or [m x n].\n    ground_truth  : torch.tensor\n                    Ground truth to be tested [1 x 3 x m x n]  or [3 x m x n] or [1 x m x n] or [m x n].\n    reduction     : str\n                    Specifies the reduction to apply to the output: 'mean' (default) or 'sum'.\n\n    Returns\n    -------\n    wmse        : torch.tensor\n                  The calculated wrapped mean squared error. \n    \"\"\"\n    sin_diff = torch.sin(image) - torch.sin(ground_truth)\n    cos_diff = torch.cos(image) - torch.cos(ground_truth)\n    loss = (sin_diff**2 + cos_diff**2)\n\n    if reduction == 'mean':\n        return loss.mean()\n    elif reduction == 'sum':\n        return loss.sum()\n    else:\n        raise ValueError(\"Invalid reduction type. Choose 'mean' or 'sum'.\")\n
"},{"location":"odak/learn_tools/#odak.learn.tools.matrix.blur_gaussian","title":"blur_gaussian(field, kernel_length=[21, 21], nsigma=[3, 3], padding='same')","text":"

A definition to blur a field using a Gaussian kernel.

Parameters:

  • field \u2013
            MxN field.\n
  • kernel_length (list, default: [21, 21] ) \u2013
            Length of the Gaussian kernel along X and Y axes.\n
  • nsigma \u2013
            Sigma of the Gaussian kernel along X and Y axes.\n
  • padding \u2013
            Padding value, see torch.nn.functional.conv2d() for more.\n

Returns:

  • blurred_field ( tensor ) \u2013

    Blurred field.

Source code in odak/learn/tools/matrix.py
def blur_gaussian(field, kernel_length = [21, 21], nsigma = [3, 3], padding = 'same'):\n    \"\"\"\n    A definition to blur a field using a Gaussian kernel.\n\n    Parameters\n    ----------\n    field         : torch.tensor\n                    MxN field.\n    kernel_length : list\n                    Length of the Gaussian kernel along X and Y axes.\n    nsigma        : list\n                    Sigma of the Gaussian kernel along X and Y axes.\n    padding       : int or string\n                    Padding value, see torch.nn.functional.conv2d() for more.\n\n    Returns\n    ----------\n    blurred_field : torch.tensor\n                    Blurred field.\n    \"\"\"\n    kernel = generate_2d_gaussian(kernel_length, nsigma).to(field.device)\n    kernel = kernel.unsqueeze(0).unsqueeze(0)\n    if len(field.shape) == 2:\n        field = field.view(1, 1, field.shape[-2], field.shape[-1])\n    blurred_field = torch.nn.functional.conv2d(field, kernel, padding='same')\n    if field.shape[1] == 1:\n        blurred_field = blurred_field.view(\n                                           blurred_field.shape[-2],\n                                           blurred_field.shape[-1]\n                                          )\n    return blurred_field\n
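A minimal usage sketch blurring a random field:

import torch
from odak.learn.tools.matrix import blur_gaussian

field = torch.rand(128, 128)
blurred = blur_gaussian(field, kernel_length = [21, 21], nsigma = [3, 3]) # same [128 x 128] shape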
"},{"location":"odak/learn_tools/#odak.learn.tools.matrix.convolve2d","title":"convolve2d(field, kernel)","text":"

Definition to convolve a field with a kernel by multiplying in frequency space.

Parameters:

  • field \u2013
          Input field with MxN shape.\n
  • kernel \u2013
          Input kernel with MxN shape.\n

Returns:

  • new_field ( tensor ) \u2013

    Convolved field.

Source code in odak/learn/tools/matrix.py
def convolve2d(field, kernel):\n    \"\"\"\n    Definition to convolve a field with a kernel by multiplying in frequency space.\n\n    Parameters\n    ----------\n    field       : torch.tensor\n                  Input field with MxN shape.\n    kernel      : torch.tensor\n                  Input kernel with MxN shape.\n\n    Returns\n    ----------\n    new_field   : torch.tensor\n                  Convolved field.\n    \"\"\"\n    fr = torch.fft.fft2(field)\n    fr2 = torch.fft.fft2(torch.flip(torch.flip(kernel, [1, 0]), [0, 1]))\n    m, n = fr.shape\n    new_field = torch.real(torch.fft.ifft2(fr*fr2))\n    new_field = torch.roll(new_field, shifts=(int(n/2+1), 0), dims=(1, 0))\n    new_field = torch.roll(new_field, shifts=(int(m/2+1), 0), dims=(0, 1))\n    return new_field\n
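A minimal usage sketch; the field and the kernel are assumed to share the same MxN shape, since the product is taken in frequency space:

import torch
from odak.learn.tools.matrix import convolve2d, generate_2d_gaussian

field = torch.rand(128, 128)
kernel = generate_2d_gaussian(kernel_length = [128, 128], nsigma = [5, 5])
convolved = convolve2d(field, kernel) # real-valued [128 x 128] result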
"},{"location":"odak/learn_tools/#odak.learn.tools.matrix.correlation_2d","title":"correlation_2d(first_tensor, second_tensor)","text":"

Definition to calculate the correlation between two tensors.

Parameters:

  • first_tensor \u2013
            First tensor.\n
  • second_tensor (tensor) \u2013
            Second tensor.\n

Returns:

  • correlation ( tensor ) \u2013

    Correlation between the two tensors.

Source code in odak/learn/tools/matrix.py
def correlation_2d(first_tensor, second_tensor):\n    \"\"\"\n    Definition to calculate the correlation between two tensors.\n\n    Parameters\n    ----------\n    first_tensor  : torch.tensor\n                    First tensor.\n    second_tensor : torch.tensor\n                    Second tensor.\n\n    Returns\n    ----------\n    correlation   : torch.tensor\n                    Correlation between the two tensors.\n    \"\"\"\n    fft_first_tensor = (torch.fft.fft2(first_tensor))\n    fft_second_tensor = (torch.fft.fft2(second_tensor))\n    conjugate_second_tensor = torch.conj(fft_second_tensor)\n    result = torch.fft.ifftshift(torch.fft.ifft2(fft_first_tensor * conjugate_second_tensor))\n    return result\n
"},{"location":"odak/learn_tools/#odak.learn.tools.matrix.crop_center","title":"crop_center(field, size=None)","text":"

Definition to crop the center of a field with 2Mx2N size. The outcome is an MxN array.

Parameters:

  • field \u2013
          Input field 2M x 2N or K x L x 2M x 2N or K x 2M x 2N x L array.\n
  • size \u2013
          Dimensions to crop with respect to center of the image (e.g., M x N or 1 x 1 x M x N).\n

Returns:

  • cropped ( ndarray ) \u2013

    Cropped version of the input field.

Source code in odak/learn/tools/matrix.py
def crop_center(field, size = None):\n    \"\"\"\n    Definition to crop the center of a field with 2Mx2N size. The outcome is a MxN array.\n\n    Parameters\n    ----------\n    field       : ndarray\n                  Input field 2M x 2N or K x L x 2M x 2N or K x 2M x 2N x L array.\n    size        : list\n                  Dimensions to crop with respect to center of the image (e.g., M x N or 1 x 1 x M x N).\n\n    Returns\n    ----------\n    cropped     : ndarray\n                  Cropped version of the input field.\n    \"\"\"\n    orig_resolution = field.shape\n    if len(field.shape) < 3:\n        field = field.unsqueeze(0)\n    if len(field.shape) < 4:\n        field = field.unsqueeze(0)\n    permute_flag = False\n    if field.shape[-1] < 5:\n        permute_flag = True\n        field = field.permute(0, 3, 1, 2)\n    if type(size) == type(None):\n        qx = int(field.shape[-2] // 4)\n        qy = int(field.shape[-1] // 4)\n        cropped_padded = field[:, :, qx: qx + field.shape[-2] // 2, qy:qy + field.shape[-1] // 2]\n    else:\n        cx = int(field.shape[-2] // 2)\n        cy = int(field.shape[-1] // 2)\n        hx = int(size[-2] // 2)\n        hy = int(size[-1] // 2)\n        cropped_padded = field[:, :, cx-hx:cx+hx, cy-hy:cy+hy]\n    cropped = cropped_padded\n    if permute_flag:\n        cropped = cropped.permute(0, 2, 3, 1)\n    if len(orig_resolution) == 2:\n        cropped = cropped_padded.squeeze(0).squeeze(0)\n    if len(orig_resolution) == 3:\n        cropped = cropped_padded.squeeze(0)\n    return cropped\n
"},{"location":"odak/learn_tools/#odak.learn.tools.matrix.generate_2d_dirac_delta","title":"generate_2d_dirac_delta(kernel_length=[21, 21], a=[3, 3], mu=[0, 0], theta=0, normalize=False)","text":"

Generate a 2D Dirac delta function by using a Gaussian distribution. Inspired by https://en.wikipedia.org/wiki/Dirac_delta_function

Parameters:

  • kernel_length (list, default: [21, 21] ) \u2013
            Length of the Dirac delta function along X and Y axes.\n
  • a \u2013
            The scale factor in Gaussian distribution to approximate the Dirac delta function. \n        As a approaches zero, the Gaussian distribution becomes infinitely narrow and tall at the center (x=0), approaching the Dirac delta function.\n
  • mu \u2013
            Mu of the Gaussian kernel along X and Y axes.\n
  • theta \u2013
            The rotation angle of the 2D Dirac delta function.\n
  • normalize \u2013
            If set True, normalize the output.\n

Returns:

  • kernel_2d ( tensor ) \u2013

    Generated 2D Dirac delta function.

Source code in odak/learn/tools/matrix.py
def generate_2d_dirac_delta(\n                            kernel_length = [21, 21],\n                            a = [3, 3],\n                            mu = [0, 0],\n                            theta = 0,\n                            normalize = False\n                           ):\n    \"\"\"\n    Generate 2D Dirac delta function by using Gaussian distribution.\n    Inspired from https://en.wikipedia.org/wiki/Dirac_delta_function\n\n    Parameters\n    ----------\n    kernel_length : list\n                    Length of the Dirac delta function along X and Y axes.\n    a             : list\n                    The scale factor in Gaussian distribution to approximate the Dirac delta function. \n                    As a approaches zero, the Gaussian distribution becomes infinitely narrow and tall at the center (x=0), approaching the Dirac delta function.\n    mu            : list\n                    Mu of the Gaussian kernel along X and Y axes.\n    theta         : float\n                    The rotation angle of the 2D Dirac delta function.\n    normalize     : bool\n                    If set True, normalize the output.\n\n    Returns\n    ----------\n    kernel_2d     : torch.tensor\n                    Generated 2D Dirac delta function.\n    \"\"\"\n    x = torch.linspace(-kernel_length[0] / 2., kernel_length[0] / 2., kernel_length[0])\n    y = torch.linspace(-kernel_length[1] / 2., kernel_length[1] / 2., kernel_length[1])\n    X, Y = torch.meshgrid(x, y, indexing='ij')\n    X = X - mu[0]\n    Y = Y - mu[1]\n    theta = torch.as_tensor(theta)\n    X_rot = X * torch.cos(theta) - Y * torch.sin(theta)\n    Y_rot = X * torch.sin(theta) + Y * torch.cos(theta)\n    kernel_2d = (1 / (abs(a[0] * a[1]) * torch.pi)) * torch.exp(-((X_rot / a[0]) ** 2 + (Y_rot / a[1]) ** 2))\n    if normalize:\n        kernel_2d = kernel_2d / kernel_2d.max()\n    return kernel_2d\n
"},{"location":"odak/learn_tools/#odak.learn.tools.matrix.generate_2d_gaussian","title":"generate_2d_gaussian(kernel_length=[21, 21], nsigma=[3, 3], mu=[0, 0], normalize=False)","text":"

Generate a 2D Gaussian kernel. Inspired by https://stackoverflow.com/questions/29731726/how-to-calculate-a-gaussian-kernel-matrix-efficiently-in-numpy

Parameters:

  • kernel_length (list, default: [21, 21] ) \u2013
            Length of the Gaussian kernel along X and Y axes.\n
  • nsigma \u2013
            Sigma of the Gaussian kernel along X and Y axes.\n
  • mu \u2013
            Mu of the Gaussian kernel along X and Y axes.\n
  • normalize \u2013
            If set True, normalize the output.\n

Returns:

  • kernel_2d ( tensor ) \u2013

    Generated Gaussian kernel.

Source code in odak/learn/tools/matrix.py
def generate_2d_gaussian(kernel_length = [21, 21], nsigma = [3, 3], mu = [0, 0], normalize = False):\n    \"\"\"\n    Generate 2D Gaussian kernel. Inspired from https://stackoverflow.com/questions/29731726/how-to-calculate-a-gaussian-kernel-matrix-efficiently-in-numpy\n\n    Parameters\n    ----------\n    kernel_length : list\n                    Length of the Gaussian kernel along X and Y axes.\n    nsigma        : list\n                    Sigma of the Gaussian kernel along X and Y axes.\n    mu            : list\n                    Mu of the Gaussian kernel along X and Y axes.\n    normalize     : bool\n                    If set True, normalize the output.\n\n    Returns\n    ----------\n    kernel_2d     : torch.tensor\n                    Generated Gaussian kernel.\n    \"\"\"\n    x = torch.linspace(-kernel_length[0]/2., kernel_length[0]/2., kernel_length[0])\n    y = torch.linspace(-kernel_length[1]/2., kernel_length[1]/2., kernel_length[1])\n    X, Y = torch.meshgrid(x, y, indexing='ij')\n    if nsigma[0] == 0:\n        nsigma[0] = 1e-5\n    if nsigma[1] == 0:\n        nsigma[1] = 1e-5\n    kernel_2d = 1. / (2. * torch.pi * nsigma[0] * nsigma[1]) * torch.exp(-((X - mu[0])**2. / (2. * nsigma[0]**2.) + (Y - mu[1])**2. / (2. * nsigma[1]**2.)))\n    if normalize:\n        kernel_2d = kernel_2d / kernel_2d.max()\n    return kernel_2d\n
"},{"location":"odak/learn_tools/#odak.learn.tools.matrix.quantize","title":"quantize(image_field, bits=8, limits=[0.0, 1.0])","text":"

Definition to quantize an image field (0-255, 8 bit) to a certain bit level.

Parameters:

  • image_field (tensor) \u2013
          Input image field between any range.\n
  • bits \u2013
          A value in between one to eight.\n
  • limits \u2013
          The minimum and maximum of the image_field variable.\n

Returns:

  • new_field ( tensor ) \u2013

    Quantized image field.

Source code in odak/learn/tools/matrix.py
def quantize(image_field, bits = 8, limits = [0., 1.]):\n    \"\"\" \n    Definition to quantize a image field (0-255, 8 bit) to a certain bits level.\n\n    Parameters\n    ----------\n    image_field : torch.tensor\n                  Input image field between any range.\n    bits        : int\n                  A value in between one to eight.\n    limits      : list\n                  The minimum and maximum of the image_field variable.\n\n    Returns\n    ----------\n    new_field   : torch.tensor\n                  Quantized image field.\n    \"\"\"\n    normalized_field = (image_field - limits[0]) / (limits[1] - limits[0])\n    divider = 2 ** bits\n    new_field = normalized_field * divider\n    new_field = new_field.int()\n    return new_field\n
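A minimal usage sketch quantizing a [0, 1] ranged field down to four bits:

import torch
from odak.learn.tools.matrix import quantize

image_field = torch.rand(256, 256) # values in [0, 1]
quantized = quantize(image_field, bits = 4, limits = [0., 1.]) # integer levels in [0, 2 ** 4]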
"},{"location":"odak/learn_tools/#odak.learn.tools.matrix.zero_pad","title":"zero_pad(field, size=None, method='center')","text":"

Definition to zero pad an MxN array to a 2Mx2N array.

Parameters:

  • field \u2013
                Input field MxN or KxJxMxN or KxMxNxJ array.\n
  • size \u2013
                Size to be zeropadded (e.g., [m, n], last two dimensions only).\n
  • method \u2013
                Zeropad either by placing the content to center or to the left.\n

Returns:

  • field_zero_padded ( ndarray ) \u2013

    Zeropadded version of the input field.

Source code in odak/learn/tools/matrix.py
def zero_pad(field, size = None, method = 'center'):\n    \"\"\"\n    Definition to zero pad a MxN array to 2Mx2N array.\n\n    Parameters\n    ----------\n    field             : ndarray\n                        Input field MxN or KxJxMxN or KxMxNxJ array.\n    size              : list\n                        Size to be zeropadded (e.g., [m, n], last two dimensions only).\n    method            : str\n                        Zeropad either by placing the content to center or to the left.\n\n    Returns\n    ----------\n    field_zero_padded : ndarray\n                        Zeropadded version of the input field.\n    \"\"\"\n    orig_resolution = field.shape\n    if len(field.shape) < 3:\n        field = field.unsqueeze(0)\n    if len(field.shape) < 4:\n        field = field.unsqueeze(0)\n    permute_flag = False\n    if field.shape[-1] < 5:\n        permute_flag = True\n        field = field.permute(0, 3, 1, 2)\n    if type(size) == type(None):\n        resolution = [field.shape[0], field.shape[1], 2 * field.shape[-2], 2 * field.shape[-1]]\n    else:\n        resolution = [field.shape[0], field.shape[1], size[0], size[1]]\n    field_zero_padded = torch.zeros(resolution, device = field.device, dtype = field.dtype)\n    if method == 'center':\n       start = [\n                resolution[-2] // 2 - field.shape[-2] // 2,\n                resolution[-1] // 2 - field.shape[-1] // 2\n               ]\n       field_zero_padded[\n                         :, :,\n                         start[0] : start[0] + field.shape[-2],\n                         start[1] : start[1] + field.shape[-1]\n                         ] = field\n    elif method == 'left':\n       field_zero_padded[\n                         :, :,\n                         0: field.shape[-2],\n                         0: field.shape[-1]\n                        ] = field\n    if permute_flag == True:\n        field_zero_padded = field_zero_padded.permute(0, 2, 3, 1)\n    if len(orig_resolution) == 2:\n        field_zero_padded = field_zero_padded.squeeze(0).squeeze(0)\n    if len(orig_resolution) == 3:\n        field_zero_padded = field_zero_padded.squeeze(0)\n    return field_zero_padded\n
"},{"location":"odak/learn_tools/#odak.learn.tools.sample.grid_sample","title":"grid_sample(no=[10, 10], size=[100.0, 100.0], center=[0.0, 0.0, 0.0], angles=[0.0, 0.0, 0.0])","text":"

Definition to generate samples over a surface.

Parameters:

  • no \u2013
          Number of samples.\n
  • size \u2013
          Physical size of the surface.\n
  • center \u2013
          Center location of the surface.\n
  • angles \u2013
          Tilt of the surface.\n

Returns:

  • samples ( tensor ) \u2013

    Samples generated.

  • rotx ( tensor ) \u2013

    Rotation matrix at X axis.

  • roty ( tensor ) \u2013

    Rotation matrix at Y axis.

  • rotz ( tensor ) \u2013

    Rotation matrix at Z axis.

Source code in odak/learn/tools/sample.py
def grid_sample(\n                no = [10, 10],\n                size = [100., 100.], \n                center = [0., 0., 0.], \n                angles = [0., 0., 0.]):\n    \"\"\"\n    Definition to generate samples over a surface.\n\n    Parameters\n    ----------\n    no          : list\n                  Number of samples.\n    size        : list\n                  Physical size of the surface.\n    center      : list\n                  Center location of the surface.\n    angles      : list\n                  Tilt of the surface.\n\n    Returns\n    -------\n    samples     : torch.tensor\n                  Samples generated.\n    rotx        : torch.tensor\n                  Rotation matrix at X axis.\n    roty        : torch.tensor\n                  Rotation matrix at Y axis.\n    rotz        : torch.tensor\n                  Rotation matrix at Z axis.\n    \"\"\"\n    center = torch.tensor(center)\n    angles = torch.tensor(angles)\n    size = torch.tensor(size)\n    samples = torch.zeros((no[0], no[1], 3))\n    x = torch.linspace(-size[0] / 2., size[0] / 2., no[0])\n    y = torch.linspace(-size[1] / 2., size[1] / 2., no[1])\n    X, Y = torch.meshgrid(x, y, indexing='ij')\n    samples[:, :, 0] = X.detach().clone()\n    samples[:, :, 1] = Y.detach().clone()\n    samples = samples.reshape((samples.shape[0] * samples.shape[1], samples.shape[2]))\n    samples, rotx, roty, rotz = rotate_points(samples, angles = angles, offset = center)\n    return samples, rotx, roty, rotz\n
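A minimal usage sketch generating a tilted grid of sample points; the numbers are illustrative:

from odak.learn.tools.sample import grid_sample

samples, rotx, roty, rotz = grid_sample(
                                        no = [10, 10],
                                        size = [5., 5.],
                                        center = [0., 0., 10.],
                                        angles = [0., 45., 0.]
                                       )
# samples is [100 x 3]; rotx, roty and rotz are the [3 x 3] rotation matrices used.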
"},{"location":"odak/learn_tools/#odak.learn.tools.transformation.get_rotation_matrix","title":"get_rotation_matrix(tilt_angles=[0.0, 0.0, 0.0], tilt_order='XYZ')","text":"

Function to generate a rotation matrix for given tilt angles and tilt order.

Parameters:

  • tilt_angles \u2013
                 Tilt angles in degrees along XYZ axes.\n
  • tilt_order \u2013
                 Rotation order (e.g., XYZ, XZY, ZXY, YXZ, ZYX).\n

Returns:

  • rotmat ( tensor ) \u2013

    Rotation matrix.

Source code in odak/learn/tools/transformation.py
def get_rotation_matrix(tilt_angles = [0., 0., 0.], tilt_order = 'XYZ'):\n    \"\"\"\n    Function to generate rotation matrix for given tilt angles and tilt order.\n\n\n    Parameters\n    ----------\n    tilt_angles        : list\n                         Tilt angles in degrees along XYZ axes.\n    tilt_order         : str\n                         Rotation order (e.g., XYZ, XZY, ZXY, YXZ, ZYX).\n\n    Returns\n    -------\n    rotmat             : torch.tensor\n                         Rotation matrix.\n    \"\"\"\n    rotx = rotmatx(tilt_angles[0])\n    roty = rotmaty(tilt_angles[1])\n    rotz = rotmatz(tilt_angles[2])\n    if tilt_order =='XYZ':\n        rotmat = torch.mm(rotz,torch.mm(roty, rotx))\n    elif tilt_order == 'XZY':\n        rotmat = torch.mm(roty,torch.mm(rotz, rotx))\n    elif tilt_order == 'ZXY':\n        rotmat = torch.mm(roty,torch.mm(rotx, rotz))\n    elif tilt_order == 'YXZ':\n        rotmat = torch.mm(rotz,torch.mm(rotx, roty))\n    elif tilt_order == 'ZYX':\n         rotmat = torch.mm(rotx,torch.mm(roty, rotz))\n    return rotmat\n
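A minimal usage sketch; the tilt angles are passed here as single-element tensors so that the underlying rotmatx, rotmaty and rotmatz calls receive tensor inputs:

import torch
from odak.learn.tools.transformation import get_rotation_matrix

tilt_angles = torch.tensor([[10.], [20.], [30.]]) # degrees along X, Y and Z
rotation_matrix = get_rotation_matrix(tilt_angles = tilt_angles, tilt_order = 'XYZ')
point = torch.tensor([[1.], [0.], [0.]])
rotated_point = torch.mm(rotation_matrix, point) # [3 x 1]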
"},{"location":"odak/learn_tools/#odak.learn.tools.transformation.rotate_points","title":"rotate_points(point, angles=torch.tensor([[0, 0, 0]]), mode='XYZ', origin=torch.tensor([[0, 0, 0]]), offset=torch.tensor([[0, 0, 0]]))","text":"

Definition to rotate a given point. Note that rotation is always with respect to 0,0,0.

Parameters:

  • point \u2013
           A point with size of [3] or [1, 3] or [m, 3].\n
  • angles \u2013
           Rotation angles in degrees.\n
  • mode \u2013
           Rotation mode determines the ordering of the rotations along each axis.\n       There are XYZ, XZY, YXZ, ZXY and ZYX modes.\n
  • origin \u2013
           Reference point for a rotation.\n       Expected size is [3] or [1, 3].\n
  • offset \u2013
           Shift with the given offset.\n       Expected size is [3] or [1, 3] or [m, 3].\n

Returns:

  • result ( tensor ) \u2013

    Result of the rotation [1 x 3] or [m x 3].

  • rotx ( tensor ) \u2013

    Rotation matrix along X axis [3 x 3].

  • roty ( tensor ) \u2013

    Rotation matrix along Y axis [3 x 3].

  • rotz ( tensor ) \u2013

    Rotation matrix along Z axis [3 x 3].

Source code in odak/learn/tools/transformation.py
def rotate_points(\n                 point,\n                 angles = torch.tensor([[0, 0, 0]]), \n                 mode='XYZ', \n                 origin = torch.tensor([[0, 0, 0]]), \n                 offset = torch.tensor([[0, 0, 0]])\n                ):\n    \"\"\"\n    Definition to rotate a given point. Note that rotation is always with respect to 0,0,0.\n\n    Parameters\n    ----------\n    point        : torch.tensor\n                   A point with size of [3] or [1, 3] or [m, 3].\n    angles       : torch.tensor\n                   Rotation angles in degrees. \n    mode         : str\n                   Rotation mode determines ordering of the rotations at each axis.\n                   There are XYZ,YXZ,ZXY and ZYX modes.\n    origin       : torch.tensor\n                   Reference point for a rotation.\n                   Expected size is [3] or [1, 3].\n    offset       : torch.tensor\n                   Shift with the given offset.\n                   Expected size is [3] or [1, 3] or [m, 3].\n\n    Returns\n    ----------\n    result       : torch.tensor\n                   Result of the rotation [1 x 3] or [m x 3].\n    rotx         : torch.tensor\n                   Rotation matrix along X axis [3 x 3].\n    roty         : torch.tensor\n                   Rotation matrix along Y axis [3 x 3].\n    rotz         : torch.tensor\n                   Rotation matrix along Z axis [3 x 3].\n    \"\"\"\n    origin = origin.to(point.device)\n    offset = offset.to(point.device)\n    if len(point.shape) == 1:\n        point = point.unsqueeze(0)\n    if len(angles.shape) == 1:\n        angles = angles.unsqueeze(0)\n    rotx = rotmatx(angles[:, 0])\n    roty = rotmaty(angles[:, 1])\n    rotz = rotmatz(angles[:, 2])\n    new_point = (point - origin).T\n    if mode == 'XYZ':\n        result = torch.mm(rotz, torch.mm(roty, torch.mm(rotx, new_point))).T\n    elif mode == 'XZY':\n        result = torch.mm(roty, torch.mm(rotz, torch.mm(rotx, new_point))).T\n    elif mode == 'YXZ':\n        result = torch.mm(rotz, torch.mm(rotx, torch.mm(roty, new_point))).T\n    elif mode == 'ZXY':\n        result = torch.mm(roty, torch.mm(rotx, torch.mm(rotz, new_point))).T\n    elif mode == 'ZYX':\n        result = torch.mm(rotx, torch.mm(roty, torch.mm(rotz, new_point))).T\n    result += origin\n    result += offset\n    return result, rotx, roty, rotz\n
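A minimal usage sketch rotating two points by 90 degrees around the Z axis:

import torch
from odak.learn.tools.transformation import rotate_points

points = torch.tensor([[1., 0., 0.], [0., 1., 0.]])
angles = torch.tensor([[0., 0., 90.]]) # degrees along X, Y and Z
rotated, rotx, roty, rotz = rotate_points(points, angles = angles, mode = 'XYZ')
# rotated is [2 x 3]; the first point maps approximately to [0., 1., 0.].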
"},{"location":"odak/learn_tools/#odak.learn.tools.transformation.rotmatx","title":"rotmatx(angle)","text":"

Definition to generate a rotation matrix along X axis.

Parameters:

  • angle \u2013
           Rotation angles in degrees.\n

Returns:

  • rotx ( tensor ) \u2013

    Rotation matrix along X axis.

Source code in odak/learn/tools/transformation.py
def rotmatx(angle):\n    \"\"\"\n    Definition to generate a rotation matrix along X axis.\n\n    Parameters\n    ----------\n    angle        : torch.tensor\n                   Rotation angles in degrees.\n\n    Returns\n    ----------\n    rotx         : torch.tensor\n                   Rotation matrix along X axis.\n    \"\"\"\n    angle = torch.deg2rad(angle)\n    one = torch.ones(1, device = angle.device)\n    zero = torch.zeros(1, device = angle.device)\n    rotx = torch.stack([\n                        torch.stack([ one,              zero,              zero]),\n                        torch.stack([zero,  torch.cos(angle), -torch.sin(angle)]),\n                        torch.stack([zero,  torch.sin(angle),  torch.cos(angle)])\n                       ]).reshape(3, 3)\n    return rotx\n
"},{"location":"odak/learn_tools/#odak.learn.tools.transformation.rotmaty","title":"rotmaty(angle)","text":"

Definition to generate a rotation matrix along Y axis.

Parameters:

  • angle \u2013
           Rotation angles in degrees.\n

Returns:

  • roty ( tensor ) \u2013

    Rotation matrix along Y axis.

Source code in odak/learn/tools/transformation.py
def rotmaty(angle):\n    \"\"\"\n    Definition to generate a rotation matrix along Y axis.\n\n    Parameters\n    ----------\n    angle        : torch.tensor\n                   Rotation angles in degrees.\n\n    Returns\n    ----------\n    roty         : torch.tensor\n                   Rotation matrix along Y axis.\n    \"\"\"\n    angle = torch.deg2rad(angle)\n    one = torch.ones(1, device = angle.device)\n    zero = torch.zeros(1, device = angle.device)\n    roty = torch.stack([\n                        torch.stack([ torch.cos(angle), zero, torch.sin(angle)]),\n                        torch.stack([             zero,  one,             zero]),\n                        torch.stack([-torch.sin(angle), zero, torch.cos(angle)])\n                       ]).reshape(3, 3)\n    return roty\n
"},{"location":"odak/learn_tools/#odak.learn.tools.transformation.rotmatz","title":"rotmatz(angle)","text":"

Definition to generate a rotation matrix along Z axis.

Parameters:

  • angle \u2013
           Rotation angles in degrees.\n

Returns:

  • rotz ( tensor ) \u2013

    Rotation matrix along Z axis.

Source code in odak/learn/tools/transformation.py
def rotmatz(angle):\n    \"\"\"\n    Definition to generate a rotation matrix along Z axis.\n\n    Parameters\n    ----------\n    angle        : torch.tensor\n                   Rotation angles in degrees.\n\n    Returns\n    ----------\n    rotz         : torch.tensor\n                   Rotation matrix along Z axis.\n    \"\"\"\n    angle = torch.deg2rad(angle)\n    one = torch.ones(1, device = angle.device)\n    zero = torch.zeros(1, device = angle.device)\n    rotz = torch.stack([\n                        torch.stack([torch.cos(angle), -torch.sin(angle), zero]),\n                        torch.stack([torch.sin(angle),  torch.cos(angle), zero]),\n                        torch.stack([            zero,              zero,  one])\n                       ]).reshape(3,3)\n    return rotz\n
"},{"location":"odak/learn_tools/#odak.learn.tools.transformation.tilt_towards","title":"tilt_towards(location, lookat)","text":"

Definition to tilt surface normal of a plane towards a point.

Parameters:

  • location \u2013
           Center of the plane to be tilted.\n
  • lookat \u2013
           Tilt towards this point.\n

Returns:

  • angles ( list ) \u2013

    Rotation angles in degrees.

Source code in odak/learn/tools/transformation.py
def tilt_towards(location, lookat):\n    \"\"\"\n    Definition to tilt surface normal of a plane towards a point.\n\n    Parameters\n    ----------\n    location     : list\n                   Center of the plane to be tilted.\n    lookat       : list\n                   Tilt towards this point.\n\n    Returns\n    ----------\n    angles       : list\n                   Rotation angles in degrees.\n    \"\"\"\n    dx = location[0] - lookat[0]\n    dy = location[1] - lookat[1]\n    dz = location[2] - lookat[2]\n    dist = torch.sqrt(torch.tensor(dx ** 2 + dy ** 2 + dz ** 2))\n    phi = torch.atan2(torch.tensor(dy), torch.tensor(dx))\n    theta = torch.arccos(dz / dist)\n    angles = [0, float(torch.rad2deg(theta)), float(torch.rad2deg(phi))]\n    return angles\n
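A short usage sketch, assuming `tilt_towards` is importable from `odak.learn.tools`; the coordinates are arbitrary illustrations.

```python
from odak.learn.tools import tilt_towards  # assumed import path, per this page

# Rotation angles needed to tilt a plane centered at the origin towards a point on the +Z axis.
location = [0., 0., 0.]
lookat = [0., 0., 10.]
angles = tilt_towards(location, lookat)  # a list of three rotation angles in degrees
```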
"},{"location":"odak/learn_tools/#odak.learn.tools.vector.cross_product","title":"cross_product(vector1, vector2)","text":"

Definition to compute the cross product of two vectors and return the resultant vector. Uses the method described at: http://en.wikipedia.org/wiki/Cross_product

Parameters:

  • vector1 \u2013
           A vector/ray.\n
  • vector2 \u2013
           A vector/ray.\n

Returns:

  • ray ( tensor ) \u2013

    Array that contains starting points and cosines of a created ray.

Source code in odak/learn/tools/vector.py
def cross_product(vector1, vector2):\n    \"\"\"\n    Definition to cross product two vectors and return the resultant vector. Used method described under: http://en.wikipedia.org/wiki/Cross_product\n\n    Parameters\n    ----------\n    vector1      : torch.tensor\n                   A vector/ray.\n    vector2      : torch.tensor\n                   A vector/ray.\n\n    Returns\n    ----------\n    ray          : torch.tensor\n                   Array that contains starting points and cosines of a created ray.\n    \"\"\"\n    angle = torch.cross(vector1[1].T, vector2[1].T)\n    angle = torch.tensor(angle)\n    ray = torch.tensor([vector1[0], angle], dtype=torch.float32)\n    return ray\n
"},{"location":"odak/learn_tools/#odak.learn.tools.vector.distance_between_two_points","title":"distance_between_two_points(point1, point2)","text":"

Definition to calculate distance between two given points.

Parameters:

  • point1 \u2013
          First point in X,Y,Z.\n
  • point2 \u2013
          Second point in X,Y,Z.\n

Returns:

  • distance ( Tensor ) \u2013

    Distance in between given two points.

Source code in odak/learn/tools/vector.py
def distance_between_two_points(point1, point2):\n    \"\"\"\n    Definition to calculate distance between two given points.\n\n    Parameters\n    ----------\n    point1      : torch.Tensor\n                  First point in X,Y,Z.\n    point2      : torch.Tensor\n                  Second point in X,Y,Z.\n\n    Returns\n    ----------\n    distance    : torch.Tensor\n                  Distance in between given two points.\n    \"\"\"\n    point1 = torch.tensor(point1) if not isinstance(point1, torch.Tensor) else point1\n    point2 = torch.tensor(point2) if not isinstance(point2, torch.Tensor) else point2\n\n    if len(point1.shape) == 1 and len(point2.shape) == 1:\n        distance = torch.sqrt(torch.sum((point1 - point2) ** 2))\n    elif len(point1.shape) == 2 or len(point2.shape) == 2:\n        distance = torch.sqrt(torch.sum((point1 - point2) ** 2, dim=-1))\n\n    return distance\n
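A minimal usage sketch, assuming `distance_between_two_points` is importable from `odak.learn.tools`; the points are arbitrary illustrations.

```python
import torch
from odak.learn.tools import distance_between_two_points  # assumed import path, per this page

point1 = torch.tensor([0., 0., 0.])
point2 = torch.tensor([1., 2., 2.])
distance = distance_between_two_points(point1, point2)  # sqrt(1 + 4 + 4) = 3.0

# Batched points also work; distances are computed along the last dimension.
batch1 = torch.zeros(5, 3)
batch2 = torch.ones(5, 3)
distances = distance_between_two_points(batch1, batch2)  # five values of sqrt(3)
```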
"},{"location":"odak/learn_tools/#odak.learn.tools.vector.same_side","title":"same_side(p1, p2, a, b)","text":"

Definition to figure out which side a point is on with respect to a line and a point. See http://www.blackpawn.com/texts/pointinpoly/ for more. If p1 and p2 are on the same side, this definition returns True.

Parameters:

  • p1 \u2013
          Point(s) to check.\n
  • p2 \u2013
          This is the point to check against.\n
  • a \u2013
          First point that forms the line.\n
  • b \u2013
          Second point that forms the line.\n
Source code in odak/learn/tools/vector.py
def same_side(p1, p2, a, b):\n    \"\"\"\n    Definition to figure which side a point is on with respect to a line and a point. See http://www.blackpawn.com/texts/pointinpoly/ for more. If p1 and p2 are on the sameside, this definition returns True.\n\n    Parameters\n    ----------\n    p1          : list\n                  Point(s) to check.\n    p2          : list\n                  This is the point check against.\n    a           : list\n                  First point that forms the line.\n    b           : list\n                  Second point that forms the line.\n    \"\"\"\n    ba = torch.subtract(b, a)\n    p1a = torch.subtract(p1, a)\n    p2a = torch.subtract(p2, a)\n    cp1 = torch.cross(ba, p1a)\n    cp2 = torch.cross(ba, p2a)\n    test = torch.dot(cp1, cp2)\n    if len(p1.shape) > 1:\n        return test >= 0\n    if test >= 0:\n        return True\n    return False\n
"},{"location":"odak/learn_wave/","title":"odak.learn.wave","text":""},{"location":"odak/learn_wave/#odak.learn.wave.classical.angular_spectrum","title":"angular_spectrum(field, k, distance, dx, wavelength, zero_padding=False, aperture=1.0)","text":"

A definition to calculate convolution with the Angular Spectrum method for beam propagation.

Parameters:

  • field \u2013
               Complex field [m x n].\n
  • k \u2013
               Wave number of a wave, see odak.wave.wavenumber for more.\n
  • distance \u2013
               Propagation distance.\n
  • dx \u2013
               Size of one single pixel in the field grid (in meters).\n
  • wavelength \u2013
               Wavelength of the electric field.\n
  • zero_padding \u2013
               Zero pad in Fourier domain.\n
  • aperture \u2013
               Fourier domain aperture (e.g., pinhole in a typical holographic display).\n           The default is one, but an aperture could be as large as input field [m x n].\n

Returns:

  • result ( complex ) \u2013

    Final complex field (MxN).

Source code in odak/learn/wave/classical.py
def angular_spectrum(field, k, distance, dx, wavelength, zero_padding = False, aperture = 1.):\n    \"\"\"\n    A definition to calculate convolution with Angular Spectrum method for beam propagation.\n\n    Parameters\n    ----------\n    field            : torch.complex\n                       Complex field [m x n].\n    k                : odak.wave.wavenumber\n                       Wave number of a wave, see odak.wave.wavenumber for more.\n    distance         : float\n                       Propagation distance.\n    dx               : float\n                       Size of one single pixel in the field grid (in meters).\n    wavelength       : float\n                       Wavelength of the electric field.\n    zero_padding     : bool\n                       Zero pad in Fourier domain.\n    aperture         : torch.tensor\n                       Fourier domain aperture (e.g., pinhole in a typical holographic display).\n                       The default is one, but an aperture could be as large as input field [m x n].\n\n\n    Returns\n    -------\n    result           : torch.complex\n                       Final complex field (MxN).\n\n    \"\"\"\n    H = get_propagation_kernel(\n                               nu = field.shape[-2], \n                               nv = field.shape[-1], \n                               dx = dx, \n                               wavelength = wavelength, \n                               distance = distance, \n                               propagation_type = 'Angular Spectrum',\n                               device = field.device\n                              )\n    result = custom(field, H, zero_padding = zero_padding, aperture = aperture)\n    return result\n
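A minimal propagation sketch, assuming `angular_spectrum` is exposed as `odak.learn.wave.angular_spectrum` as the page path suggests; the field, distance and pixel pitch are arbitrary illustrative values.

```python
import math
import torch
import odak

wavelength = 515e-9
dx = 8e-6
distance = 5e-3
k = 2 * math.pi / wavelength  # wavenumber; odak.wave.wavenumber computes the same quantity

# A single bright pixel as a toy input field.
field = torch.zeros(512, 512, dtype = torch.complex64)
field[256, 256] = 1.
result = odak.learn.wave.angular_spectrum(field, k, distance, dx, wavelength)
intensity = result.abs() ** 2
```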
"},{"location":"odak/learn_wave/#odak.learn.wave.classical.band_limited_angular_spectrum","title":"band_limited_angular_spectrum(field, k, distance, dx, wavelength, zero_padding=False, aperture=1.0)","text":"

A definition to calculate bandlimited angular spectrum-based beam propagation. For more, see Matsushima, Kyoji, and Tomoyoshi Shimobaba. \"Band-limited angular spectrum method for numerical simulation of free-space propagation in far and near fields.\" Optics Express 17.22 (2009): 19662-19673.

Parameters:

  • field \u2013
               A complex field.\n           The expected size is [m x n].\n
  • k \u2013
               Wave number of a wave, see odak.wave.wavenumber for more.\n
  • distance \u2013
               Propagation distance.\n
  • dx \u2013
               Size of one single pixel in the field grid (in meters).\n
  • wavelength \u2013
               Wavelength of the electric field.\n
  • zero_padding \u2013
               Zero pad in Fourier domain.\n
  • aperture \u2013
               Fourier domain aperture (e.g., pinhole in a typical holographic display).\n           The default is one, but an aperture could be as large as input field [m x n].\n

Returns:

  • result ( complex ) \u2013

    Final complex field [m x n].

Source code in odak/learn/wave/classical.py
def band_limited_angular_spectrum(\n                                  field,\n                                  k,\n                                  distance,\n                                  dx,\n                                  wavelength,\n                                  zero_padding = False,\n                                  aperture = 1.\n                                 ):\n    \"\"\"\n    A definition to calculate bandlimited angular spectrum based beam propagation. For more \n    `Matsushima, Kyoji, and Tomoyoshi Shimobaba. \"Band-limited angular spectrum method for numerical simulation of free-space propagation in far and near fields.\" Optics express 17.22 (2009): 19662-19673`.\n\n    Parameters\n    ----------\n    field            : torch.complex\n                       A complex field.\n                       The expected size is [m x n].\n    k                : odak.wave.wavenumber\n                       Wave number of a wave, see odak.wave.wavenumber for more.\n    distance         : float\n                       Propagation distance.\n    dx               : float\n                       Size of one single pixel in the field grid (in meters).\n    wavelength       : float\n                       Wavelength of the electric field.\n    zero_padding     : bool\n                       Zero pad in Fourier domain.\n    aperture         : torch.tensor\n                       Fourier domain aperture (e.g., pinhole in a typical holographic display).\n                       The default is one, but an aperture could be as large as input field [m x n].\n\n\n    Returns\n    -------\n    result           : torch.complex\n                       Final complex field [m x n].\n    \"\"\"\n    H = get_propagation_kernel(\n                               nu = field.shape[-2], \n                               nv = field.shape[-1], \n                               dx = dx, \n                               wavelength = wavelength, \n                               distance = distance, \n                               propagation_type = 'Bandlimited Angular Spectrum',\n                               device = field.device\n                              )\n    result = custom(field, H, zero_padding = zero_padding, aperture = aperture)\n    return result\n
"},{"location":"odak/learn_wave/#odak.learn.wave.classical.custom","title":"custom(field, kernel, zero_padding=False, aperture=1.0)","text":"

A definition to calculate convolution-based beam propagation with a custom kernel.

Parameters:

  • field \u2013
               Complex field [m x n].\n
  • kernel \u2013
               Custom complex kernel for beam propagation.\n
  • zero_padding \u2013
               Zero pad in Fourier domain.\n
  • aperture \u2013
               Fourier domain aperture (e.g., pinhole in a typical holographic display).\n           The default is one, but an aperture could be as large as input field [m x n].\n

Returns:

  • result ( complex ) \u2013

    Final complex field (MxN).

Source code in odak/learn/wave/classical.py
def custom(field, kernel, zero_padding = False, aperture = 1.):\n    \"\"\"\n    A definition to calculate convolution based Fresnel approximation for beam propagation.\n\n    Parameters\n    ----------\n    field            : torch.complex\n                       Complex field [m x n].\n    kernel           : torch.complex\n                       Custom complex kernel for beam propagation.\n    zero_padding     : bool\n                       Zero pad in Fourier domain.\n    aperture         : torch.tensor\n                       Fourier domain aperture (e.g., pinhole in a typical holographic display).\n                       The default is one, but an aperture could be as large as input field [m x n].\n\n    Returns\n    -------\n    result           : torch.complex\n                       Final complex field (MxN).\n\n    \"\"\"\n    if type(kernel) == type(None):\n        H = torch.ones(field.shape).to(field.device)\n    else:\n        H = kernel * aperture\n    U1 = torch.fft.fftshift(torch.fft.fft2(field)) * aperture\n    if zero_padding == False:\n        U2 = H * U1\n    elif zero_padding == True:\n        U2 = zero_pad(H * U1)\n    result = torch.fft.ifft2(torch.fft.ifftshift(U2))\n    return result\n
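A sketch pairing `custom` with `get_propagation_kernel`, assuming both are exposed under `odak.learn.wave` as the page paths suggest; building the kernel once and reusing it is the point of this routine.

```python
import torch
import odak

wavelength = 515e-9
dx = 8e-6
distance = 1e-2
field = torch.zeros(256, 256, dtype = torch.complex64)
field[128, 128] = 1.

# Build a Fourier-domain kernel once, then reuse it for repeated propagations.
H = odak.learn.wave.get_propagation_kernel(
                                            nu = field.shape[-2],
                                            nv = field.shape[-1],
                                            dx = dx,
                                            wavelength = wavelength,
                                            distance = distance,
                                            propagation_type = 'Bandlimited Angular Spectrum',
                                            device = field.device
                                           )
result = odak.learn.wave.custom(field, H)
```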
"},{"location":"odak/learn_wave/#odak.learn.wave.classical.fraunhofer","title":"fraunhofer(field, k, distance, dx, wavelength)","text":"

A definition to calculate light transport using the Fraunhofer approximation.

Parameters:

  • field \u2013
               Complex field (MxN).\n
  • k \u2013
               Wave number of a wave, see odak.wave.wavenumber for more.\n
  • distance \u2013
               Propagation distance.\n
  • dx \u2013
               Size of one single pixel in the field grid (in meters).\n
  • wavelength \u2013
               Wavelength of the electric field.\n

Returns:

  • result ( complex ) \u2013

    Final complex field (MxN).

Source code in odak/learn/wave/classical.py
def fraunhofer(field, k, distance, dx, wavelength):\n    \"\"\"\n    A definition to calculate light transport usin Fraunhofer approximation.\n\n    Parameters\n    ----------\n    field            : torch.complex\n                       Complex field (MxN).\n    k                : odak.wave.wavenumber\n                       Wave number of a wave, see odak.wave.wavenumber for more.\n    distance         : float\n                       Propagation distance.\n    dx               : float\n                       Size of one single pixel in the field grid (in meters).\n    wavelength       : float\n                       Wavelength of the electric field.\n\n    Returns\n    -------\n    result           : torch.complex\n                       Final complex field (MxN).\n    \"\"\"\n    nv, nu = field.shape[-1], field.shape[-2]\n    x = torch.linspace(-nv*dx/2, nv*dx/2, nv, dtype=torch.float32)\n    y = torch.linspace(-nu*dx/2, nu*dx/2, nu, dtype=torch.float32)\n    Y, X = torch.meshgrid(y, x, indexing='ij')\n    Z = torch.pow(X, 2) + torch.pow(Y, 2)\n    c = 1. / (1j * wavelength * distance) * torch.exp(1j * k * 0.5 / distance * Z)\n    c = c.to(field.device)\n    result = c * torch.fft.ifftshift(torch.fft.fft2(torch.fft.fftshift(field))) * dx ** 2\n    return result\n
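A short usage sketch, assuming `fraunhofer` is exposed as `odak.learn.wave.fraunhofer`; the aperture and distance values are arbitrary illustrations.

```python
import math
import torch
import odak

wavelength = 515e-9
dx = 8e-6
distance = 1.  # the Fraunhofer approximation targets comparatively long propagation distances
k = 2 * math.pi / wavelength  # wavenumber

# A small square aperture as the input field.
field = torch.zeros(256, 256, dtype = torch.complex64)
field[118:138, 118:138] = 1.
far_field = odak.learn.wave.fraunhofer(field, k, distance, dx, wavelength)
```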
"},{"location":"odak/learn_wave/#odak.learn.wave.classical.gerchberg_saxton","title":"gerchberg_saxton(field, n_iterations, distance, dx, wavelength, slm_range=6.28, propagation_type='Transfer Function Fresnel')","text":"

Definition to compute a hologram using an iterative method called the Gerchberg-Saxton phase retrieval algorithm. For more on the method, see: Gerchberg, Ralph W. \"A practical algorithm for the determination of phase from image and diffraction plane pictures.\" Optik 35 (1972): 237-246.

Parameters:

  • field \u2013
               Complex field (MxN).\n
  • n_iterations \u2013
               Number of iterations used in the phase retrieval loop.\n
  • distance \u2013
               Propagation distance.\n
  • dx \u2013
               Size of one single pixel in the field grid (in meters).\n
  • wavelength \u2013
               Wavelength of the electric field.\n
  • slm_range \u2013
               Typically this is equal to two pi. See odak.wave.adjust_phase_only_slm_range() for more.\n
  • propagation_type (str, default: 'Transfer Function Fresnel' ) \u2013
               Type of the propagation (see odak.learn.wave.propagate_beam).\n

Returns:

  • hologram ( cfloat ) \u2013

    Calculated complex hologram.

  • reconstruction ( cfloat ) \u2013

    Calculated reconstruction using calculated hologram.

Source code in odak/learn/wave/classical.py
def gerchberg_saxton(field, n_iterations, distance, dx, wavelength, slm_range=6.28, propagation_type='Transfer Function Fresnel'):\n    \"\"\"\n    Definition to compute a hologram using an iterative method called Gerchberg-Saxton phase retrieval algorithm. For more on the method, see: Gerchberg, Ralph W. \"A practical algorithm for the determination of phase from image and diffraction plane pictures.\" Optik 35 (1972): 237-246.\n\n    Parameters\n    ----------\n    field            : torch.cfloat\n                       Complex field (MxN).\n    distance         : float\n                       Propagation distance.\n    dx               : float\n                       Size of one single pixel in the field grid (in meters).\n    wavelength       : float\n                       Wavelength of the electric field.\n    slm_range        : float\n                       Typically this is equal to two pi. See odak.wave.adjust_phase_only_slm_range() for more.\n    propagation_type : str\n                       Type of the propagation (see odak.learn.wave.propagate_beam).\n\n    Returns\n    -------\n    hologram         : torch.cfloat\n                       Calculated complex hologram.\n    reconstruction   : torch.cfloat\n                       Calculated reconstruction using calculated hologram. \n    \"\"\"\n    k = wavenumber(wavelength)\n    reconstruction = field\n    for i in range(n_iterations):\n        hologram = propagate_beam(\n            reconstruction, k, -distance, dx, wavelength, propagation_type)\n        reconstruction = propagate_beam(\n            hologram, k, distance, dx, wavelength, propagation_type)\n        reconstruction = set_amplitude(reconstruction, field)\n    reconstruction = propagate_beam(\n        hologram, k, distance, dx, wavelength, propagation_type)\n    return hologram, reconstruction\n
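A minimal phase-retrieval sketch, assuming `gerchberg_saxton` is exposed as `odak.learn.wave.gerchberg_saxton`; the target pattern, iteration count and propagation settings are arbitrary illustrations.

```python
import torch
import odak

wavelength = 515e-9
dx = 8e-6
distance = 5e-3

# Target amplitude encoded as a complex field with zero phase.
target = torch.zeros(256, 256)
target[96:160, 96:160] = 1.
field = target.to(torch.complex64)
hologram, reconstruction = odak.learn.wave.gerchberg_saxton(
                                                            field,
                                                            n_iterations = 10,
                                                            distance = distance,
                                                            dx = dx,
                                                            wavelength = wavelength,
                                                            propagation_type = 'Bandlimited Angular Spectrum'
                                                           )
```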
"},{"location":"odak/learn_wave/#odak.learn.wave.classical.get_angular_spectrum_kernel","title":"get_angular_spectrum_kernel(nu, nv, dx=8e-06, wavelength=5.15e-07, distance=0.0, device=torch.device('cpu'))","text":"

Helper function for odak.learn.wave.angular_spectrum.

Parameters:

  • nu \u2013
                 Resolution at X axis in pixels.\n
  • nv \u2013
                 Resolution at Y axis in pixels.\n
  • dx \u2013
                 Pixel pitch in meters.\n
  • wavelength \u2013
                 Wavelength in meters.\n
  • distance \u2013
                 Distance in meters.\n
  • device \u2013
                 Device, for more see torch.device().\n

Returns:

  • H ( float ) \u2013

    Complex kernel in Fourier domain.

Source code in odak/learn/wave/classical.py
def get_angular_spectrum_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9, distance = 0., device = torch.device('cpu')):\n    \"\"\"\n    Helper function for odak.learn.wave.angular_spectrum.\n\n    Parameters\n    ----------\n    nu                 : int\n                         Resolution at X axis in pixels.\n    nv                 : int\n                         Resolution at Y axis in pixels.\n    dx                 : float\n                         Pixel pitch in meters.\n    wavelength         : float\n                         Wavelength in meters.\n    distance           : float\n                         Distance in meters.\n    device             : torch.device\n                         Device, for more see torch.device().\n\n\n    Returns\n    -------\n    H                  : float\n                         Complex kernel in Fourier domain.\n    \"\"\"\n    distance = torch.tensor([distance]).to(device)\n    fx = torch.linspace(-1. / 2. / dx, 1. / 2. / dx, nu, dtype = torch.float32, device = device)\n    fy = torch.linspace(-1. / 2. / dx, 1. / 2. / dx, nv, dtype = torch.float32, device = device)\n    FY, FX = torch.meshgrid(fx, fy, indexing='ij')\n    H = torch.exp(1j  * distance * (2 * (torch.pi * (1 / wavelength) * torch.sqrt(1. - (wavelength * FX) ** 2 - (wavelength * FY) ** 2))))\n    H = H.to(device)\n    return H\n
"},{"location":"odak/learn_wave/#odak.learn.wave.classical.get_band_limited_angular_spectrum_kernel","title":"get_band_limited_angular_spectrum_kernel(nu, nv, dx=8e-06, wavelength=5.15e-07, distance=0.0, device=torch.device('cpu'))","text":"

Helper function for odak.learn.wave.band_limited_angular_spectrum.

Parameters:

  • nu \u2013
                 Resolution at X axis in pixels.\n
  • nv \u2013
                 Resolution at Y axis in pixels.\n
  • dx \u2013
                 Pixel pitch in meters.\n
  • wavelength \u2013
                 Wavelength in meters.\n
  • distance \u2013
                 Distance in meters.\n
  • device \u2013
                 Device, for more see torch.device().\n

Returns:

  • H ( complex64 ) \u2013

    Complex kernel in Fourier domain.

Source code in odak/learn/wave/classical.py
def get_band_limited_angular_spectrum_kernel(\n                                             nu,\n                                             nv,\n                                             dx = 8e-6,\n                                             wavelength = 515e-9,\n                                             distance = 0.,\n                                             device = torch.device('cpu')\n                                            ):\n    \"\"\"\n    Helper function for odak.learn.wave.band_limited_angular_spectrum.\n\n    Parameters\n    ----------\n    nu                 : int\n                         Resolution at X axis in pixels.\n    nv                 : int\n                         Resolution at Y axis in pixels.\n    dx                 : float\n                         Pixel pitch in meters.\n    wavelength         : float\n                         Wavelength in meters.\n    distance           : float\n                         Distance in meters.\n    device             : torch.device\n                         Device, for more see torch.device().\n\n\n    Returns\n    -------\n    H                  : torch.complex64\n                         Complex kernel in Fourier domain.\n    \"\"\"\n    x = dx * float(nu)\n    y = dx * float(nv)\n    fx = torch.linspace(\n                        -1 / (2 * dx) + 0.5 / (2 * x),\n                         1 / (2 * dx) - 0.5 / (2 * x),\n                         nu,\n                         dtype = torch.float32,\n                         device = device\n                        )\n    fy = torch.linspace(\n                        -1 / (2 * dx) + 0.5 / (2 * y),\n                        1 / (2 * dx) - 0.5 / (2 * y),\n                        nv,\n                        dtype = torch.float32,\n                        device = device\n                       )\n    FY, FX = torch.meshgrid(fx, fy, indexing='ij')\n    HH_exp = 2 * torch.pi * torch.sqrt(1 / wavelength ** 2 - (FX ** 2 + FY ** 2))\n    distance = torch.tensor([distance], device = device)\n    H_exp = torch.mul(HH_exp, distance)\n    fx_max = 1 / torch.sqrt((2 * distance * (1 / x))**2 + 1) / wavelength\n    fy_max = 1 / torch.sqrt((2 * distance * (1 / y))**2 + 1) / wavelength\n    H_filter = ((torch.abs(FX) < fx_max) & (torch.abs(FY) < fy_max)).clone().detach()\n    H = generate_complex_field(H_filter, H_exp)\n    return H\n
"},{"location":"odak/learn_wave/#odak.learn.wave.classical.get_impulse_response_fresnel_kernel","title":"get_impulse_response_fresnel_kernel(nu, nv, dx=8e-06, wavelength=5.15e-07, distance=0.0, device=torch.device('cpu'), scale=1, aperture_samples=[20, 20, 5, 5])","text":"

Helper function for odak.learn.wave.impulse_response_fresnel.

Parameters:

  • nu \u2013
                 Resolution at X axis in pixels.\n
  • nv \u2013
                 Resolution at Y axis in pixels.\n
  • dx \u2013
                 Pixel pitch in meters.\n
  • wavelength \u2013
                 Wavelength in meters.\n
  • distance \u2013
                 Distance in meters.\n
  • device \u2013
                 Device, for more see torch.device().\n
  • scale \u2013
                 Scale with respect to nu and nv (e.g., scale = 2 leads to  2 x nu and 2 x nv resolution for H).\n
  • aperture_samples \u2013
                  Number of samples to represent a rectangular pixel. The first two are for the XY of hologram plane pixels, and the last two are for image plane pixels.\n

Returns:

  • H ( complex64 ) \u2013

    Complex kernel in Fourier domain.

Source code in odak/learn/wave/classical.py
def get_impulse_response_fresnel_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9, distance = 0., device = torch.device('cpu'), scale = 1, aperture_samples = [20, 20, 5, 5]):\n    \"\"\"\n    Helper function for odak.learn.wave.impulse_response_fresnel.\n\n    Parameters\n    ----------\n    nu                 : int\n                         Resolution at X axis in pixels.\n    nv                 : int\n                         Resolution at Y axis in pixels.\n    dx                 : float\n                         Pixel pitch in meters.\n    wavelength         : float\n                         Wavelength in meters.\n    distance           : float\n                         Distance in meters.\n    device             : torch.device\n                         Device, for more see torch.device().\n    scale              : int\n                         Scale with respect to nu and nv (e.g., scale = 2 leads to  2 x nu and 2 x nv resolution for H).\n    aperture_samples   : list\n                         Number of samples to represent a rectangular pixel. First two is for XY of hologram plane pixels, and second two is for image plane pixels.\n\n    Returns\n    -------\n    H                  : torch.complex64\n                         Complex kernel in Fourier domain.\n    \"\"\"\n    k = wavenumber(wavelength)\n    distance = torch.as_tensor(distance, device = device)\n    length_x, length_y = (torch.tensor(dx * nu, device = device), torch.tensor(dx * nv, device = device))\n    x = torch.linspace(- length_x / 2., length_x / 2., nu * scale, device = device)\n    y = torch.linspace(- length_y / 2., length_y / 2., nv * scale, device = device)\n    X, Y = torch.meshgrid(x, y, indexing = 'ij')\n    wxs = torch.linspace(- dx / 2., dx / 2., aperture_samples[0], device = device)\n    wys = torch.linspace(- dx / 2., dx / 2., aperture_samples[1], device = device)\n    h = torch.zeros(nu * scale, nv * scale, dtype = torch.complex64, device = device)\n    pxs = torch.linspace(- dx / 2., dx / 2., aperture_samples[2], device = device)\n    pys = torch.linspace(- dx / 2., dx / 2., aperture_samples[3], device = device)\n    for wx in tqdm(wxs):\n        for wy in wys:\n            for px in pxs:\n                for py in pys:\n                    r = (X + px - wx) ** 2 + (Y + py - wy) ** 2\n                    h += 1. / (1j * wavelength * distance) * torch.exp(1j * k / (2 * distance) * r) \n    H = torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(h))) * dx ** 2 / aperture_samples[0] / aperture_samples[1] / aperture_samples[2] / aperture_samples[3]\n    return H\n
"},{"location":"odak/learn_wave/#odak.learn.wave.classical.get_incoherent_angular_spectrum_kernel","title":"get_incoherent_angular_spectrum_kernel(nu, nv, dx=8e-06, wavelength=5.15e-07, distance=0.0, device=torch.device('cpu'))","text":"

Helper function for odak.learn.wave.angular_spectrum.

Parameters:

  • nu \u2013
                 Resolution at X axis in pixels.\n
  • nv \u2013
                 Resolution at Y axis in pixels.\n
  • dx \u2013
                 Pixel pitch in meters.\n
  • wavelength \u2013
                 Wavelength in meters.\n
  • distance \u2013
                 Distance in meters.\n
  • device \u2013
                 Device, for more see torch.device().\n

Returns:

  • H ( float ) \u2013

    Complex kernel in Fourier domain.

Source code in odak/learn/wave/classical.py
def get_incoherent_angular_spectrum_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9, distance = 0., device = torch.device('cpu')):\n    \"\"\"\n    Helper function for odak.learn.wave.angular_spectrum.\n\n    Parameters\n    ----------\n    nu                 : int\n                         Resolution at X axis in pixels.\n    nv                 : int\n                         Resolution at Y axis in pixels.\n    dx                 : float\n                         Pixel pitch in meters.\n    wavelength         : float\n                         Wavelength in meters.\n    distance           : float\n                         Distance in meters.\n    device             : torch.device\n                         Device, for more see torch.device().\n\n\n    Returns\n    -------\n    H                  : float\n                         Complex kernel in Fourier domain.\n    \"\"\"\n    distance = torch.tensor([distance]).to(device)\n    fx = torch.linspace(-1. / 2. / dx, 1. / 2. / dx, nu, dtype = torch.float32, device = device)\n    fy = torch.linspace(-1. / 2. / dx, 1. / 2. / dx, nv, dtype = torch.float32, device = device)\n    FY, FX = torch.meshgrid(fx, fy, indexing='ij')\n    H = torch.exp(1j  * distance * (2 * (torch.pi * (1 / wavelength) * torch.sqrt(1. - (wavelength * FX) ** 2 - (wavelength * FY) ** 2))))\n    H_ptime = correlation_2d(H, H)\n    H = H_ptime.to(device)\n    return H\n
"},{"location":"odak/learn_wave/#odak.learn.wave.classical.get_light_kernels","title":"get_light_kernels(wavelengths, distances, pixel_pitches, resolution=[1080, 1920], resolution_factor=1, samples=[50, 50, 5, 5], propagation_type='Bandlimited Angular Spectrum', kernel_type='spatial', device=torch.device('cpu'))","text":"

Utility function to request a tensor filled with light transport kernels according to the given optical configurations.

Parameters:

  • wavelengths \u2013
                 A list of wavelengths.\n
  • distances \u2013
                 A list of propagation distances.\n
  • pixel_pitches \u2013
                 A list of pixel_pitches.\n
  • resolution \u2013
                 Resolution of the light transport kernel.\n
  • resolution_factor \u2013
                  If `Impulse Response Fresnel` propagation is used, this resolution factor could be set larger than one, leading to higher-resolution light transport kernels than the provided native `resolution`. For more, see odak.learn.wave.get_impulse_response_fresnel_kernel().\n
  • samples \u2013
                  If `Impulse Response Fresnel` propagation is used, these sample counts will be used to calculate the light transport kernel. For more, see odak.learn.wave.get_impulse_response_fresnel_kernel().\n
  • propagation_type \u2013
                 Propagation type. For more, see odak.learn.wave.propagate_beam().\n
  • kernel_type \u2013
                  If set to `spatial`, light transport kernels will be provided in space; if set to `fourier`, these kernels will be provided in the Fourier domain.\n
  • device \u2013
                 Device used for computation (i.e., cpu, cuda).\n

Returns:

  • light_kernels_amplitude ( tensor ) \u2013

    Amplitudes of the light kernels generated [w x d x p x m x n].

  • light_kernels_phase ( tensor ) \u2013

    Phases of the light kernels generated [w x d x p x m x n].

  • light_kernels_complex ( tensor ) \u2013

    Complex light kernels generated [w x d x p x m x n].

  • light_parameters ( tensor ) \u2013

    Parameters of each pixel in light_kernels* [w x d x p x m x n x 5]. The last dimension contains wavelengths, distances, pixel pitches, and X and Y locations, in that order.

Source code in odak/learn/wave/classical.py
def get_light_kernels(\n                      wavelengths,\n                      distances,\n                      pixel_pitches,\n                      resolution = [1080, 1920],\n                      resolution_factor = 1,\n                      samples = [50, 50, 5, 5],\n                      propagation_type = 'Bandlimited Angular Spectrum',\n                      kernel_type = 'spatial',\n                      device = torch.device('cpu')\n                     ):\n    \"\"\"\n    Utility function to request a tensor filled with light transport kernels according to the given optical configurations.\n\n    Parameters\n    ----------\n    wavelengths        : list\n                         A list of wavelengths.\n    distances          : list\n                         A list of propagation distances.\n    pixel_pitches      : list\n                         A list of pixel_pitches.\n    resolution         : list\n                         Resolution of the light transport kernel.\n    resolution_factor  : int\n                         If `Impulse Response Fresnel` propagation is used, this resolution factor could be set larger than one leading to higher resolution light transport kernels than the provided native `resolution`. For more, see odak.learn.wave.get_impulse_response_kernel().\n    samples            : list\n                         If `Impulse Response Fresnel` propagation is used, these sample counts will be used to calculate the light transport kernel. For more, see odak.learn.wave.get_impulse_response_kernel().\n    propagation_type   : str\n                         Propagation type. For more, see odak.learn.wave.propagate_beam().\n    kernel_type        : str\n                         If set to `spatial`, light transport kernels will be provided in space. But if set to `fourier`, these kernels will be provided in the Fourier domain.\n    device             : torch.device\n                         Device used for computation (i.e., cpu, cuda).\n\n    Returns\n    -------\n    light_kernels_amplitude : torch.tensor\n                              Amplitudes of the light kernels generated [w x d x p x m x n].\n    light_kernels_phase     : torch.tensor\n                              Phases of the light kernels generated [w x d x p x m x n].\n    light_kernels_complex   : torch.tensor\n                              Complex light kernels generated [w x d x p x m x n].\n    light_parameters        : torch.tensor\n                              Parameters of each pixel in light_kernels* [w x d x p x m x n x 5].  
Last dimension contains, wavelengths, distances, pixel pitches, X and Y locations in order.\n    \"\"\"\n    if propagation_type != 'Impulse Response Fresnel':\n        resolution_factor = 1\n    light_kernels_complex = torch.zeros(            \n                                        len(wavelengths),\n                                        len(distances),\n                                        len(pixel_pitches),\n                                        resolution[0] * resolution_factor,\n                                        resolution[1] * resolution_factor,\n                                        dtype = torch.complex64,\n                                        device = device\n                                       )\n    light_parameters = torch.zeros(\n                                   len(wavelengths),\n                                   len(distances),\n                                   len(pixel_pitches),\n                                   resolution[0] * resolution_factor,\n                                   resolution[1] * resolution_factor,\n                                   5,\n                                   dtype = torch.float32,\n                                   device = device\n                                  )\n    for wavelength_id, distance_id, pixel_pitch_id in itertools.product(\n                                                                        range(len(wavelengths)),\n                                                                        range(len(distances)),\n                                                                        range(len(pixel_pitches)),\n                                                                       ):\n        pixel_pitch = pixel_pitches[pixel_pitch_id]\n        wavelength = wavelengths[wavelength_id]\n        distance = distances[distance_id]\n        kernel_fourier = get_propagation_kernel(\n                                                nu = resolution[0],\n                                                nv = resolution[1],\n                                                dx = pixel_pitch,\n                                                wavelength = wavelength,\n                                                distance = distance,\n                                                device = device,\n                                                propagation_type = propagation_type,\n                                                scale = resolution_factor,\n                                                samples = samples\n                                               )\n        if kernel_type == 'spatial':\n            kernel = torch.fft.ifftshift(torch.fft.ifft2(torch.fft.ifftshift(kernel_fourier)))\n        elif kernel_type == 'fourier':\n            kernel = kernel_fourier\n        else:\n            logging.warning('Unknown kernel type requested.')\n            raise ValueError('Unknown kernel type requested.')\n        kernel_amplitude = calculate_amplitude(kernel)\n        kernel_phase = calculate_phase(kernel) % (2 * torch.pi)\n        light_kernels_complex[wavelength_id, distance_id, pixel_pitch_id] = kernel\n        light_parameters[wavelength_id, distance_id, pixel_pitch_id, :, :, 0] = wavelength\n        light_parameters[wavelength_id, distance_id, pixel_pitch_id, :, :, 1] = distance\n        light_parameters[wavelength_id, distance_id, pixel_pitch_id, :, :, 2] = pixel_pitch\n        x = torch.linspace(-1., 1., resolution[0] * resolution_factor, device = device) * pixel_pitch / 2. 
* resolution[0]\n        y = torch.linspace(-1., 1., resolution[1] * resolution_factor, device = device) * pixel_pitch / 2. * resolution[1]\n        X, Y = torch.meshgrid(x, y, indexing = 'ij')\n        light_parameters[wavelength_id, distance_id, pixel_pitch_id, :, :, 3] = X\n        light_parameters[wavelength_id, distance_id, pixel_pitch_id, :, :, 4] = Y\n    light_kernels_amplitude = calculate_amplitude(light_kernels_complex)\n    light_kernels_phase = calculate_phase(light_kernels_complex) % (2. * torch.pi)\n    return light_kernels_amplitude, light_kernels_phase, light_kernels_complex, light_parameters\n
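A usage sketch, assuming `get_light_kernels` is exposed as `odak.learn.wave.get_light_kernels`; the wavelengths, distances and pixel pitch below are arbitrary illustrative choices.

```python
import torch
import odak

# Spatial kernels for three wavelengths, two distances and one pixel pitch.
amplitudes, phases, kernels, parameters = odak.learn.wave.get_light_kernels(
                                                                             wavelengths = [473e-9, 515e-9, 639e-9],
                                                                             distances = [1e-3, 5e-3],
                                                                             pixel_pitches = [8e-6],
                                                                             resolution = [256, 256],
                                                                             propagation_type = 'Bandlimited Angular Spectrum',
                                                                             kernel_type = 'spatial',
                                                                             device = torch.device('cpu')
                                                                            )
# Following the docstring, kernels has shape [3 x 2 x 1 x 256 x 256] here.
```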
"},{"location":"odak/learn_wave/#odak.learn.wave.classical.get_point_wise_impulse_response_fresnel_kernel","title":"get_point_wise_impulse_response_fresnel_kernel(aperture_points, aperture_field, target_points, resolution, resolution_factor=1, wavelength=5.15e-07, distance=0.0, randomization=False, device=torch.device('cpu'))","text":"

This function is a freeform point spread function calculation routine for an aperture defined with a complex field, aperture_field, and locations in space, aperture_points. The point spread function is calculated over provided points, target_points. The final result is reshaped to follow the provided resolution.

Parameters:

  • aperture_points \u2013
                       Points representing an aperture in Euler space (XYZ) [m x 3].\n
  • aperture_field \u2013
                       Complex field for each point provided by `aperture_points` [1 x m].\n
  • target_points \u2013
                       Target points where the propagated field will be calculated [n x 3].\n
  • resolution \u2013
                       Final resolution that the propagated field will be reshaped [X x Y].\n
  • resolution_factor \u2013
                       Scale with respect to `resolution` (e.g., scale = 2 leads to `2 x resolution` for the final complex field).\n
  • wavelength \u2013
                       Wavelength in meters.\n
  • randomization \u2013
                       If set `True`, this will help generate a noisy response roughly approximating a real life case, where imperfections occur.\n
  • distance \u2013
                       Distance in meters.\n

Returns:

  • h ( float ) \u2013

    Complex field in spatial domain.

Source code in odak/learn/wave/classical.py
def get_point_wise_impulse_response_fresnel_kernel(\n                                                   aperture_points,\n                                                   aperture_field,\n                                                   target_points,\n                                                   resolution,\n                                                   resolution_factor = 1,\n                                                   wavelength = 515e-9,\n                                                   distance = 0.,\n                                                   randomization = False,\n                                                   device = torch.device('cpu')\n                                                  ):\n    \"\"\"\n    This function is a freeform point spread function calculation routine for an aperture defined with a complex field, `aperture_field`, and locations in space, `aperture_points`.\n    The point spread function is calculated over provided points, `target_points`.\n    The final result is reshaped to follow the provided `resolution`.\n\n    Parameters\n    ----------\n    aperture_points          : torch.tensor\n                               Points representing an aperture in Euler space (XYZ) [m x 3].\n    aperture_field           : torch.tensor\n                               Complex field for each point provided by `aperture_points` [1 x m].\n    target_points            : torch.tensor\n                               Target points where the propagated field will be calculated [n x 1].\n    resolution               : list\n                               Final resolution that the propagated field will be reshaped [X x Y].\n    resolution_factor        : int\n                               Scale with respect to `resolution` (e.g., scale = 2 leads to `2 x resolution` for the final complex field.\n    wavelength               : float\n                               Wavelength in meters.\n    randomization            : bool\n                               If set `True`, this will help generate a noisy response roughly approximating a real life case, where imperfections occur.\n    distance                 : float\n                               Distance in meters.\n\n    Returns\n    -------\n    h                        : float\n                               Complex field in spatial domain.\n    \"\"\"\n    device = aperture_field.device\n    k = wavenumber(wavelength)\n    if randomization:\n        pp = [\n              aperture_points[:, 0].max() - aperture_points[:, 0].min(),\n              aperture_points[:, 1].max() - aperture_points[:, 1].min()\n             ]\n        target_points[:, 0] = target_points[:, 0] - torch.randn(target_points[:, 0].shape) * pp[0]\n        target_points[:, 1] = target_points[:, 1] - torch.randn(target_points[:, 1].shape) * pp[1]\n    deltaX = aperture_points[:, 0].unsqueeze(0) - target_points[:, 0].unsqueeze(-1)\n    deltaY = aperture_points[:, 1].unsqueeze(0) - target_points[:, 1].unsqueeze(-1)\n    r = deltaX ** 2 + deltaY ** 2\n    h = torch.exp(1j * k / (2 * distance) * r) * aperture_field\n    h = torch.sum(h, dim = 1).reshape(resolution[0] * resolution_factor, resolution[1] * resolution_factor)\n    h = 1. / (1j * wavelength * distance) * h\n    return h\n
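A sketch of the freeform point spread function routine, assuming it is exposed as `odak.learn.wave.get_point_wise_impulse_response_fresnel_kernel`; the aperture grid, target grid and optical values are arbitrary illustrations chosen so that the number of target points matches the requested resolution.

```python
import torch
import odak

wavelength = 515e-9
distance = 5e-3
dx = 8e-6
resolution = [32, 32]

# Aperture sampled as a grid of points at z = 0 with unit amplitude.
xs = torch.linspace(-16 * dx, 16 * dx, 33)
AX, AY = torch.meshgrid(xs, xs, indexing = 'ij')
aperture_points = torch.stack([AX.flatten(), AY.flatten(), torch.zeros_like(AX).flatten()], dim = -1)
aperture_field = torch.ones(1, aperture_points.shape[0], dtype = torch.complex64)

# Target points form an image plane grid matching the requested resolution.
ts = torch.linspace(-16 * dx, 16 * dx, resolution[0])
TX, TY = torch.meshgrid(ts, ts, indexing = 'ij')
target_points = torch.stack([TX.flatten(), TY.flatten(), torch.full_like(TX, distance).flatten()], dim = -1)

h = odak.learn.wave.get_point_wise_impulse_response_fresnel_kernel(
                                                                   aperture_points = aperture_points,
                                                                   aperture_field = aperture_field,
                                                                   target_points = target_points,
                                                                   resolution = resolution,
                                                                   wavelength = wavelength,
                                                                   distance = distance
                                                                  )
```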
"},{"location":"odak/learn_wave/#odak.learn.wave.classical.get_propagation_kernel","title":"get_propagation_kernel(nu, nv, dx=8e-06, wavelength=5.15e-07, distance=0.0, device=torch.device('cpu'), propagation_type='Bandlimited Angular Spectrum', scale=1, samples=[20, 20, 5, 5])","text":"

Get propagation kernel for the propagation type.

Parameters:

  • nu \u2013
                 Resolution at X axis in pixels.\n
  • nv \u2013
                 Resolution at Y axis in pixels.\n
  • dx \u2013
                 Pixel pitch in meters.\n
  • wavelength \u2013
                 Wavelength in meters.\n
  • distance \u2013
                 Distance in meters.\n
  • device \u2013
                 Device, for more see torch.device().\n
  • propagation_type \u2013
                  Propagation type.\n             The options are `Angular Spectrum`, `Bandlimited Angular Spectrum`, `Transfer Function Fresnel`, `Impulse Response Fresnel`, `Seperable Impulse Response Fresnel` and `Incoherent Angular Spectrum`.\n
  • scale \u2013
                 Scale factor for scaled beam propagation.\n
  • samples \u2013
                  When using `Impulse Response Fresnel` propagation, these sample counts along X and Y will be used to represent a rectangular aperture. The first two are for a hologram pixel and the last two are for an image plane pixel.\n

Returns:

  • kernel ( tensor ) \u2013

    Complex kernel for the given propagation type.

Source code in odak/learn/wave/classical.py
def get_propagation_kernel(\n                           nu, \n                           nv, \n                           dx = 8e-6, \n                           wavelength = 515e-9, \n                           distance = 0., \n                           device = torch.device('cpu'), \n                           propagation_type = 'Bandlimited Angular Spectrum', \n                           scale = 1,\n                           samples = [20, 20, 5, 5]\n                          ):\n    \"\"\"\n    Get propagation kernel for the propagation type.\n\n    Parameters\n    ----------\n    nu                 : int\n                         Resolution at X axis in pixels.\n    nv                 : int\n                         Resolution at Y axis in pixels.\n    dx                 : float\n                         Pixel pitch in meters.\n    wavelength         : float\n                         Wavelength in meters.\n    distance           : float\n                         Distance in meters.\n    device             : torch.device\n                         Device, for more see torch.device().\n    propagation_type   : str\n                         Propagation type.\n                         The options are `Angular Spectrum`, `Bandlimited Angular Spectrum` and `Transfer Function Fresnel`.\n    scale              : int\n                         Scale factor for scaled beam propagation.\n    samples            : list\n                         When using `Impulse Response Fresnel` propagation, these sample counts along X and Y will be used to represent a rectangular aperture. First two is for a hologram pixel and second two is for an image plane pixel.\n\n\n    Returns\n    -------\n    kernel             : torch.tensor\n                         Complex kernel for the given propagation type.\n    \"\"\"                                                      \n    logging.warning('Requested propagation kernel size for %s method with %s m distance, %s m pixel pitch, %s m wavelength, %s x %s resolutions, x%s scale and %s samples.'.format(propagation_type, distance, dx, nu, nv, scale, samples))\n    if propagation_type == 'Bandlimited Angular Spectrum':\n        kernel = get_band_limited_angular_spectrum_kernel(\n                                                          nu = nu,\n                                                          nv = nv,\n                                                          dx = dx,\n                                                          wavelength = wavelength,\n                                                          distance = distance,\n                                                          device = device\n                                                         )\n    elif propagation_type == 'Angular Spectrum':\n        kernel = get_angular_spectrum_kernel(\n                                             nu = nu,\n                                             nv = nv,\n                                             dx = dx,\n                                             wavelength = wavelength,\n                                             distance = distance,\n                                             device = device\n                                            )\n    elif propagation_type == 'Transfer Function Fresnel':\n        kernel = get_transfer_function_fresnel_kernel(\n                                                      nu = nu,\n                                                      nv = nv,\n                                                      dx = 
dx,\n                                                      wavelength = wavelength,\n                                                      distance = distance,\n                                                      device = device\n                                                     )\n    elif propagation_type == 'Impulse Response Fresnel':\n        kernel = get_impulse_response_fresnel_kernel(\n                                                     nu = nu, \n                                                     nv = nv, \n                                                     dx = dx, \n                                                     wavelength = wavelength,\n                                                     distance = distance,\n                                                     device =  device,\n                                                     scale = scale,\n                                                     aperture_samples = samples\n                                                    )\n    elif propagation_type == 'Incoherent Angular Spectrum':\n        kernel = get_incoherent_angular_spectrum_kernel(\n                                                        nu = nu,\n                                                        nv = nv, \n                                                        dx = dx, \n                                                        wavelength = wavelength, \n                                                        distance = distance,\n                                                        device = device\n                                                       )\n    elif propagation_type == 'Seperable Impulse Response Fresnel':\n        kernel, _, _, _ = get_seperable_impulse_response_fresnel_kernel(\n                                                                        nu = nu,\n                                                                        nv = nv,\n                                                                        dx = dx,\n                                                                        wavelength = wavelength,\n                                                                        distance = distance,\n                                                                        device = device,\n                                                                        scale = scale,\n                                                                        aperture_samples = samples\n                                                                       )\n    else:\n        logging.warning('Propagation type not recognized')\n        assert True == False\n    return kernel\n
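A brief sketch requesting kernels for two different propagation models with the same geometry, assuming `get_propagation_kernel` is exposed as `odak.learn.wave.get_propagation_kernel`; the resolution and optical values are arbitrary illustrations.

```python
import torch
import odak

settings = {
            'nu': 256,
            'nv': 256,
            'dx': 8e-6,
            'wavelength': 515e-9,
            'distance': 1e-2,
            'device': torch.device('cpu')
           }
H_blas = odak.learn.wave.get_propagation_kernel(propagation_type = 'Bandlimited Angular Spectrum', **settings)
H_tff = odak.learn.wave.get_propagation_kernel(propagation_type = 'Transfer Function Fresnel', **settings)
```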
"},{"location":"odak/learn_wave/#odak.learn.wave.classical.get_seperable_impulse_response_fresnel_kernel","title":"get_seperable_impulse_response_fresnel_kernel(nu, nv, dx=3.74e-06, wavelength=5.15e-07, distance=0.0, scale=1, aperture_samples=[50, 50, 5, 5], device=torch.device('cpu'))","text":"

Returns the impulse response Fresnel kernel in separable form.

Parameters:

  • nu \u2013
                 Resolution at X axis in pixels.\n
  • nv \u2013
                 Resolution at Y axis in pixels.\n
  • dx \u2013
                 Pixel pitch in meters.\n
  • wavelength \u2013
                 Wavelength in meters.\n
  • distance \u2013
                 Distance in meters.\n
  • device \u2013
                 Device, for more see torch.device().\n
  • scale \u2013
                 Scale with respect to nu and nv (e.g., scale = 2 leads to  2 x nu and 2 x nv resolution for H).\n
  • aperture_samples \u2013
                  Number of samples to represent a rectangular pixel. The first two are for the XY of hologram plane pixels, and the last two are for image plane pixels.\n

Returns:

  • H ( complex64 ) \u2013

    Complex kernel in Fourier domain.

  • h ( complex64 ) \u2013

    Complex kernel in spatial domain.

  • h_x ( complex64 ) \u2013

    1D complex kernel in spatial domain along X axis.

  • h_y ( complex64 ) \u2013

    1D complex kernel in spatial domain along Y axis.

Source code in odak/learn/wave/classical.py
def get_seperable_impulse_response_fresnel_kernel(\n                                                  nu,\n                                                  nv,\n                                                  dx = 3.74e-6,\n                                                  wavelength = 515e-9,\n                                                  distance = 0.,\n                                                  scale = 1,\n                                                  aperture_samples = [50, 50, 5, 5],\n                                                  device = torch.device('cpu')\n                                                 ):\n    \"\"\"\n    Returns impulse response fresnel kernel in separable form.\n\n    Parameters\n    ----------\n    nu                 : int\n                         Resolution at X axis in pixels.\n    nv                 : int\n                         Resolution at Y axis in pixels.\n    dx                 : float\n                         Pixel pitch in meters.\n    wavelength         : float\n                         Wavelength in meters.\n    distance           : float\n                         Distance in meters.\n    device             : torch.device\n                         Device, for more see torch.device().\n    scale              : int\n                         Scale with respect to nu and nv (e.g., scale = 2 leads to  2 x nu and 2 x nv resolution for H).\n    aperture_samples   : list\n                         Number of samples to represent a rectangular pixel. First two is for XY of hologram plane pixels, and second two is for image plane pixels.\n\n    Returns\n    -------\n    H                  : torch.complex64\n                         Complex kernel in Fourier domain.\n    h                  : torch.complex64\n                         Complex kernel in spatial domain.\n    h_x                : torch.complex64\n                         1D complex kernel in spatial domain along X axis.\n    h_y                : torch.complex64\n                         1D complex kernel in spatial domain along Y axis.\n    \"\"\"\n    k = wavenumber(wavelength)\n    distance = torch.as_tensor(distance, device = device)\n    length_x, length_y = (\n                          torch.tensor(dx * nu, device = device),\n                          torch.tensor(dx * nv, device = device)\n                         )\n    x = torch.linspace(- length_x / 2., length_x / 2., nu * scale, device = device)\n    y = torch.linspace(- length_y / 2., length_y / 2., nv * scale, device = device)\n    wxs = torch.linspace(- dx / 2., dx / 2., aperture_samples[0], device = device).unsqueeze(0).unsqueeze(0)\n    wys = torch.linspace(- dx / 2., dx / 2., aperture_samples[1], device = device).unsqueeze(0).unsqueeze(-1)\n    pxs = torch.linspace(- dx / 2., dx / 2., aperture_samples[2], device = device).unsqueeze(0).unsqueeze(-1)\n    pys = torch.linspace(- dx / 2., dx / 2., aperture_samples[3], device = device).unsqueeze(0).unsqueeze(0)\n    wxs = (wxs - pxs).reshape(1, -1).unsqueeze(-1)\n    wys = (wys - pys).reshape(1, -1).unsqueeze(1)\n\n    X = x.unsqueeze(-1).unsqueeze(-1)\n    Y = y[y.shape[0] // 2].unsqueeze(-1).unsqueeze(-1)\n    r_x = (X + wxs) ** 2\n    r_y = (Y + wys) ** 2\n    r = r_x + r_y\n    h_x = torch.exp(1j * k / (2 * distance) * r)\n    h_x = torch.sum(h_x, axis = (1, 2))\n\n    if nu != nv:\n        X = x[x.shape[0] // 2].unsqueeze(-1).unsqueeze(-1)\n        Y = y.unsqueeze(-1).unsqueeze(-1)\n        r_x = (X + wxs) ** 2\n        r_y = (Y + wys) ** 2\n        r = r_x 
+ r_y\n        h_y = torch.exp(1j * k * r / (2 * distance))\n        h_y = torch.sum(h_y, axis = (1, 2))\n    else:\n        h_y = h_x.detach().clone()\n    h = torch.exp(1j * k * distance) / (1j * wavelength * distance) * h_x.unsqueeze(1) * h_y.unsqueeze(0)\n    H = torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(h))) * dx ** 2 / aperture_samples[0] / aperture_samples[1] / aperture_samples[2] / aperture_samples[3]\n    return H, h, h_x, h_y\n
"},{"location":"odak/learn_wave/#odak.learn.wave.classical.get_transfer_function_fresnel_kernel","title":"get_transfer_function_fresnel_kernel(nu, nv, dx=8e-06, wavelength=5.15e-07, distance=0.0, device=torch.device('cpu'))","text":"

Helper function for odak.learn.wave.transfer_function_fresnel.

Parameters:

  • nu \u2013
                 Resolution at X axis in pixels.\n
  • nv \u2013
                 Resolution at Y axis in pixels.\n
  • dx \u2013
                 Pixel pitch in meters.\n
  • wavelength \u2013
                 Wavelength in meters.\n
  • distance \u2013
                 Distance in meters.\n
  • device \u2013
                 Device, for more see torch.device().\n

Returns:

  • H ( complex64 ) \u2013

    Complex kernel in Fourier domain.

Source code in odak/learn/wave/classical.py
def get_transfer_function_fresnel_kernel(nu, nv, dx = 8e-6, wavelength = 515e-9, distance = 0., device = torch.device('cpu')):\n    \"\"\"\n    Helper function for odak.learn.wave.transfer_function_fresnel.\n\n    Parameters\n    ----------\n    nu                 : int\n                         Resolution at X axis in pixels.\n    nv                 : int\n                         Resolution at Y axis in pixels.\n    dx                 : float\n                         Pixel pitch in meters.\n    wavelength         : float\n                         Wavelength in meters.\n    distance           : float\n                         Distance in meters.\n    device             : torch.device\n                         Device, for more see torch.device().\n\n\n    Returns\n    -------\n    H                  : torch.complex64\n                         Complex kernel in Fourier domain.\n    \"\"\"\n    distance = torch.tensor([distance]).to(device)\n    fx = torch.linspace(-1. / 2. /dx, 1. / 2. /dx, nu, dtype = torch.float32, device = device)\n    fy = torch.linspace(-1. / 2. /dx, 1. / 2. /dx, nv, dtype = torch.float32, device = device)\n    FY, FX = torch.meshgrid(fx, fy, indexing = 'ij')\n    k = wavenumber(wavelength)\n    H = torch.exp(-1j * distance * (k - torch.pi * wavelength * (FX ** 2 + FY ** 2)))\n    return H\n
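
A minimal usage sketch follows; it assumes get_transfer_function_fresnel_kernel, custom, and generate_complex_field are importable from odak.learn.wave (as the docstrings suggest) and uses illustrative resolution, pitch, and distance values. Precomputing the kernel this way lets it be reused for repeated propagations at the same distance.

import torch
from odak.learn.wave import get_transfer_function_fresnel_kernel, custom, generate_complex_field

wavelength = 515e-9                # wavelength in meters
dx = 8e-6                          # pixel pitch in meters
distance = 5e-3                    # propagation distance in meters
nu, nv = 512, 512                  # field resolution in pixels

# Precompute the Fourier domain Fresnel kernel once.
H = get_transfer_function_fresnel_kernel(nu, nv, dx = dx, wavelength = wavelength, distance = distance)

# Propagate an example phase-only field with the cached kernel.
phase = torch.rand(nu, nv) * 2 * torch.pi
field = generate_complex_field(torch.ones(nu, nv), phase)
result = custom(field, H, zero_padding = False, aperture = 1.)
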
"},{"location":"odak/learn_wave/#odak.learn.wave.classical.impulse_response_fresnel","title":"impulse_response_fresnel(field, k, distance, dx, wavelength, zero_padding=False, aperture=1.0, scale=1, samples=[20, 20, 5, 5])","text":"

A definition to calculate convolution-based Fresnel approximation for beam propagation.

Parameters:

  • field \u2013
               Complex field (MxN).\n
  • k \u2013
               Wave number of a wave, see odak.wave.wavenumber for more.\n
  • distance \u2013
               Propagation distance.\n
  • dx \u2013
               Size of one single pixel in the field grid (in meters).\n
  • wavelength \u2013
               Wavelength of the electric field.\n
  • zero_padding \u2013
               Zero pad in Fourier domain.\n
  • aperture \u2013
               Fourier domain aperture (e.g., pinhole in a typical holographic display).\n           The default is one, but an aperture could be as large as the input field [m x n].\n
  • scale \u2013
               Resolution factor to scale generated kernel.\n
  • samples \u2013
               When using `Impulse Response Fresnel` propagation, these sample counts along X and Y will be used to represent a rectangular aperture. The first two are for a hologram plane pixel and the last two are for an image plane pixel.\n

Returns:

  • result ( complex ) \u2013

    Final complex field (MxN).

Source code in odak/learn/wave/classical.py
def impulse_response_fresnel(field, k, distance, dx, wavelength, zero_padding = False, aperture = 1., scale = 1, samples = [20, 20, 5, 5]):\n    \"\"\"\n    A definition to calculate convolution based Fresnel approximation for beam propagation.\n\n    Parameters\n    ----------\n    field            : torch.complex\n                       Complex field (MxN).\n    k                : odak.wave.wavenumber\n                       Wave number of a wave, see odak.wave.wavenumber for more.\n    distance         : float\n                       Propagation distance.\n    dx               : float\n                       Size of one single pixel in the field grid (in meters).\n    wavelength       : float\n                       Wavelength of the electric field.\n    zero_padding     : bool\n                       Zero pad in Fourier domain.\n    aperture         : torch.tensor\n                       Fourier domain aperture (e.g., pinhole in a typical holographic display).\n                       The default is one, but an aperture could be as large as input field [m x n].\n    scale            : int\n                       Resolution factor to scale generated kernel.\n    samples          : list\n                       When using `Impulse Response Fresnel` propagation, these sample counts along X and Y will be used to represent a rectangular aperture. First two is for hologram plane pixel and the last two is for image plane pixel.\n\n    Returns\n    -------\n    result           : torch.complex\n                       Final complex field (MxN).\n\n    \"\"\"\n    H = get_propagation_kernel(\n                               nu = field.shape[-2], \n                               nv = field.shape[-1], \n                               dx = dx, \n                               wavelength = wavelength, \n                               distance = distance, \n                               propagation_type = 'Impulse Response Fresnel',\n                               device = field.device,\n                               scale = scale,\n                               samples = samples\n                              )\n    if scale > 1:\n        field_amplitude = calculate_amplitude(field)\n        field_phase = calculate_phase(field)\n        field_scale_amplitude = torch.zeros(field.shape[-2] * scale, field.shape[-1] * scale, device = field.device)\n        field_scale_phase = torch.zeros_like(field_scale_amplitude)\n        field_scale_amplitude[::scale, ::scale] = field_amplitude\n        field_scale_phase[::scale, ::scale] = field_phase\n        field_scale = generate_complex_field(field_scale_amplitude, field_scale_phase)\n    else:\n        field_scale = field\n    result = custom(field_scale, H, zero_padding = zero_padding, aperture = aperture)\n    return result\n
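
A hedged usage sketch (the import path odak.learn.wave and the numeric values below are assumptions): scale and samples control how densely the rectangular hologram and image plane pixels are sampled when the kernel is generated, and scale larger than one also returns the result on an upsampled grid.

import torch
from odak.learn.wave import impulse_response_fresnel, wavenumber, generate_complex_field, calculate_amplitude

wavelength = 515e-9
dx = 8e-6
distance = 1e-2
k = wavenumber(wavelength)

phase = torch.rand(256, 256) * 2 * torch.pi
field = generate_complex_field(torch.ones(256, 256), phase)

# scale = 2 builds the kernel at twice the field resolution and embeds the
# input field on the upsampled grid before the convolution.
result = impulse_response_fresnel(
                                  field,
                                  k,
                                  distance,
                                  dx,
                                  wavelength,
                                  zero_padding = False,
                                  scale = 2,
                                  samples = [20, 20, 5, 5]
                                 )
intensity = calculate_amplitude(result) ** 2
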
"},{"location":"odak/learn_wave/#odak.learn.wave.classical.incoherent_angular_spectrum","title":"incoherent_angular_spectrum(field, k, distance, dx, wavelength, zero_padding=False, aperture=1.0)","text":"

A definition to calculate incoherent beam propagation with Angular Spectrum method.

Parameters:

  • field \u2013
               Complex field [m x n].\n
  • k \u2013
               Wave number of a wave, see odak.wave.wavenumber for more.\n
  • distance \u2013
               Propagation distance.\n
  • dx \u2013
               Size of one single pixel in the field grid (in meters).\n
  • wavelength \u2013
               Wavelength of the electric field.\n
  • zero_padding \u2013
               Zero pad in Fourier domain.\n
  • aperture \u2013
               Fourier domain aperture (e.g., pinhole in a typical holographic display).\n           The default is one, but an aperture could be as large as the input field [m x n].\n

Returns:

  • result ( complex ) \u2013

    Final complex field [m x n].

Source code in odak/learn/wave/classical.py
def incoherent_angular_spectrum(field, k, distance, dx, wavelength, zero_padding = False, aperture = 1.):\n    \"\"\"\n    A definition to calculate incoherent beam propagation with Angular Spectrum method.\n\n    Parameters\n    ----------\n    field            : torch.complex\n                       Complex field [m x n].\n    k                : odak.wave.wavenumber\n                       Wave number of a wave, see odak.wave.wavenumber for more.\n    distance         : float\n                       Propagation distance.\n    dx               : float\n                       Size of one single pixel in the field grid (in meters).\n    wavelength       : float\n                       Wavelength of the electric field.\n    zero_padding     : bool\n                       Zero pad in Fourier domain.\n    aperture         : torch.tensor\n                       Fourier domain aperture (e.g., pinhole in a typical holographic display).\n                       The default is one, but an aperture could be as large as input field [m x n].\n\n\n    Returns\n    -------\n    result           : torch.complex\n                       Final complex field [m x n].\n    \"\"\"\n    H = get_propagation_kernel(\n                               nu = field.shape[-2], \n                               nv = field.shape[-1], \n                               dx = dx, \n                               wavelength = wavelength, \n                               distance = distance, \n                               propagation_type = 'Incoherent Angular Spectrum',\n                               device = field.device\n                              )\n    result = custom(field, H, zero_padding = zero_padding, aperture = aperture)\n    return result\n
"},{"location":"odak/learn_wave/#odak.learn.wave.classical.point_wise","title":"point_wise(target, wavelength, distance, dx, device, lens_size=401)","text":"

Naive point-wise hologram calculation method. For more information, refer to Maimone, Andrew, Andreas Georgiou, and Joel S. Kollin. \"Holographic near-eye displays for virtual and augmented reality.\" ACM Transactions on Graphics (TOG) 36.4 (2017): 1-16.

Parameters:

  • target \u2013
               Float input target to be converted into a hologram (target values should be in the range of zero to one).\n
  • wavelength \u2013
               Wavelength of the electric field.\n
  • distance \u2013
               Propagation distance.\n
  • dx \u2013
               Size of one single pixel in the field grid (in meters).\n
  • device \u2013
               Device type (cuda or cpu).\n
  • lens_size \u2013
               Size of lens for masking sub-holograms (in pixels).\n

Returns:

  • hologram ( cfloat ) \u2013

    Calculated complex hologram.

Source code in odak/learn/wave/classical.py
def point_wise(target, wavelength, distance, dx, device, lens_size=401):\n    \"\"\"\n    Naive point-wise hologram calculation method. For more information, refer to Maimone, Andrew, Andreas Georgiou, and Joel S. Kollin. \"Holographic near-eye displays for virtual and augmented reality.\" ACM Transactions on Graphics (TOG) 36.4 (2017): 1-16.\n\n    Parameters\n    ----------\n    target           : torch.float\n                       float input target to be converted into a hologram (Target should be in range of 0 and 1).\n    wavelength       : float\n                       Wavelength of the electric field.\n    distance         : float\n                       Propagation distance.\n    dx               : float\n                       Size of one single pixel in the field grid (in meters).\n    device           : torch.device\n                       Device type (cuda or cpu)`.\n    lens_size        : int\n                       Size of lens for masking sub holograms(in pixels).\n\n    Returns\n    -------\n    hologram         : torch.cfloat\n                       Calculated complex hologram.\n    \"\"\"\n    target = zero_pad(target)\n    nx, ny = target.shape\n    k = wavenumber(wavelength)\n    ones = torch.ones(target.shape, requires_grad=False).to(device)\n    x = torch.linspace(-nx/2, nx/2, nx).to(device)\n    y = torch.linspace(-ny/2, ny/2, ny).to(device)\n    X, Y = torch.meshgrid(x, y, indexing='ij')\n    Z = (X**2+Y**2)**0.5\n    mask = (torch.abs(Z) <= lens_size)\n    mask[mask > 1] = 1\n    fz = quadratic_phase_function(nx, ny, k, focal=-distance, dx=dx).to(device)\n    A = torch.nan_to_num(target**0.5, nan=0.0)\n    fz = mask*fz\n    FA = torch.fft.fft2(torch.fft.fftshift(A))\n    FFZ = torch.fft.fft2(torch.fft.fftshift(fz))\n    H = torch.mul(FA, FFZ)\n    hologram = torch.fft.ifftshift(torch.fft.ifft2(H))\n    hologram = crop_center(hologram)\n    return hologram\n
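
For illustration only (the import path and the numeric values are assumptions, and the target resolution and lens_size are arbitrary here), a target image can be converted into a complex hologram and then reconstructed with propagate_beam:

import torch
from odak.learn.wave import point_wise, propagate_beam, wavenumber, calculate_amplitude

device = torch.device('cpu')
wavelength = 515e-9
dx = 8e-6
distance = 5e-3

# Target image; values are expected to stay within [0, 1].
target = torch.rand(512, 512, device = device)

hologram = point_wise(target, wavelength, distance, dx, device, lens_size = 401)

# Reconstruct the hologram to inspect the result.
reconstruction = propagate_beam(hologram, wavenumber(wavelength), distance, dx, wavelength)
reconstruction_intensity = calculate_amplitude(reconstruction) ** 2
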
"},{"location":"odak/learn_wave/#odak.learn.wave.classical.propagate_beam","title":"propagate_beam(field, k, distance, dx, wavelength, propagation_type='Bandlimited Angular Spectrum', kernel=None, zero_padding=[True, False, True], aperture=1.0, scale=1, samples=[20, 20, 5, 5])","text":"

Definitions for various beam propagation methods, mostly in accordance with \"Computational Fourier Optics\" by David Voelz.

Parameters:

  • field \u2013
               Complex field [m x n].\n
  • k \u2013
               Wave number of a wave, see odak.wave.wavenumber for more.\n
  • distance \u2013
               Propagation distance.\n
  • dx \u2013
               Size of one single pixel in the field grid (in meters).\n
  • wavelength \u2013
               Wavelength of the electric field.\n
  • propagation_type (str, default: 'Bandlimited Angular Spectrum' ) \u2013
               Type of the propagation.\n           The options are Impulse Response Fresnel, Seperable Impulse Response Fresnel, Transfer Function Fresnel, Angular Spectrum, Bandlimited Angular Spectrum, Incoherent Angular Spectrum, Fraunhofer and custom.\n
  • kernel \u2013
               Custom complex kernel.\n
  • zero_padding \u2013
               Zero pads the input field if the first item in the list is set to True.\n           Zero pads in the Fourier domain if the second item in the list is set to True.\n           Crops the result back to half of the padded resolution if the third item in the list is set to True.\n           Note that in Fraunhofer propagation, setting the second item to True or False has no effect.\n
  • aperture \u2013
               Aperture in the Fourier domain, sized [2m x 2n] by default; the actual size depends on `zero_padding`.\n           If provided as the floating point value 1, no aperture is applied in the Fourier domain.\n
  • scale \u2013
               Resolution factor to scale generated kernel.\n
  • samples \u2013
               When using `Impulse Response Fresnel` propagation, these sample counts along X and Y will be used to represent a rectangular aperture. The first two are for a hologram plane pixel and the last two are for an image plane pixel.\n

Returns:

  • result ( complex ) \u2013

    Final complex field [m x n].

Source code in odak/learn/wave/classical.py
def propagate_beam(\n                   field,\n                   k,\n                   distance,\n                   dx,\n                   wavelength,\n                   propagation_type='Bandlimited Angular Spectrum',\n                   kernel = None,\n                   zero_padding = [True, False, True],\n                   aperture = 1.,\n                   scale = 1,\n                   samples = [20, 20, 5, 5]\n                  ):\n    \"\"\"\n    Definitions for various beam propagation methods mostly in accordence with \"Computational Fourier Optics\" by David Vuelz.\n\n    Parameters\n    ----------\n    field            : torch.complex\n                       Complex field [m x n].\n    k                : odak.wave.wavenumber\n                       Wave number of a wave, see odak.wave.wavenumber for more.\n    distance         : float\n                       Propagation distance.\n    dx               : float\n                       Size of one single pixel in the field grid (in meters).\n    wavelength       : float\n                       Wavelength of the electric field.\n    propagation_type : str\n                       Type of the propagation.\n                       The options are Impulse Response Fresnel, Transfer Function Fresnel, Angular Spectrum, Bandlimited Angular Spectrum, Fraunhofer.\n    kernel           : torch.complex\n                       Custom complex kernel.\n    zero_padding     : list\n                       Zero padding the input field if the first item in the list set True.\n                       Zero padding in the Fourier domain if the second item in the list set to True.\n                       Cropping the result with half resolution if the third item in the list is set to true.\n                       Note that in Fraunhofer propagation, setting the second item True or False will have no effect.\n    aperture         : torch.tensor\n                       Aperture at Fourier domain default:[2m x 2n], otherwise depends on `zero_padding`.\n                       If provided as a floating point 1, there will be no aperture in Fourier domain.\n    scale            : int\n                       Resolution factor to scale generated kernel.\n    samples          : list\n                       When using `Impulse Response Fresnel` propagation, these sample counts along X and Y will be used to represent a rectangular aperture. 
First two is for a hologram pixel and second two is for an image plane pixel.\n\n    Returns\n    -------\n    result           : torch.complex\n                       Final complex field [m x n].\n    \"\"\"\n    if zero_padding[0]:\n        field = zero_pad(field)\n    if propagation_type == 'Angular Spectrum':\n        result = angular_spectrum(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture)\n    elif propagation_type == 'Bandlimited Angular Spectrum':\n        result = band_limited_angular_spectrum(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture)\n    elif propagation_type == 'Impulse Response Fresnel':\n        result = impulse_response_fresnel(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture, scale = scale, samples = samples)\n    elif propagation_type == 'Seperable Impulse Response Fresnel':\n        result = seperable_impulse_response_fresnel(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture, scale = scale, samples = samples)\n    elif propagation_type == 'Transfer Function Fresnel':\n        result = transfer_function_fresnel(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture)\n    elif propagation_type == 'custom':\n        result = custom(field, kernel, zero_padding[1], aperture = aperture)\n    elif propagation_type == 'Fraunhofer':\n        result = fraunhofer(field, k, distance, dx, wavelength)\n    elif propagation_type == 'Incoherent Angular Spectrum':\n        result = incoherent_angular_spectrum(field, k, distance, dx, wavelength, zero_padding[1], aperture = aperture)\n    else:\n        logging.warning('Propagation type not recognized')\n        assert True == False\n    if zero_padding[2]:\n        result = crop_center(result)\n    return result\n
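
The sketch below shows the most common call pattern. It is a minimal example, assuming propagate_beam and the helpers wavenumber, generate_complex_field, and calculate_amplitude are exposed under odak.learn.wave; the wavelength, pixel pitch, distance, and field size are illustrative.

import torch
from odak.learn.wave import propagate_beam, wavenumber, generate_complex_field, calculate_amplitude

wavelength = 515e-9                  # meters
pixel_pitch = 8e-6                   # meters
distance = 1e-2                      # meters
k = wavenumber(wavelength)

# A unit-amplitude, random-phase input field.
phase = torch.rand(512, 512) * 2 * torch.pi
field = generate_complex_field(torch.ones(512, 512), phase)

# Default behaviour: zero pad the input, propagate with the Bandlimited
# Angular Spectrum method, then crop the result back to the input resolution.
result = propagate_beam(
                        field,
                        k,
                        distance,
                        pixel_pitch,
                        wavelength,
                        propagation_type = 'Bandlimited Angular Spectrum',
                        zero_padding = [True, False, True]
                       )
image_intensity = calculate_amplitude(result) ** 2
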
"},{"location":"odak/learn_wave/#odak.learn.wave.classical.seperable_impulse_response_fresnel","title":"seperable_impulse_response_fresnel(field, k, distance, dx, wavelength, zero_padding=False, aperture=1.0, scale=1, samples=[20, 20, 5, 5])","text":"

A definition to calculate convolution-based Fresnel approximation for beam propagation for a rectangular aperture using the separable property.

Parameters:

  • field \u2013
               Complex field (MxN).\n
  • k \u2013
               Wave number of a wave, see odak.wave.wavenumber for more.\n
  • distance \u2013
               Propagation distance.\n
  • dx \u2013
               Size of one single pixel in the field grid (in meters).\n
  • wavelength \u2013
               Wavelength of the electric field.\n
  • zero_padding \u2013
               Zero pad in Fourier domain.\n
  • aperture \u2013
               Fourier domain aperture (e.g., pinhole in a typical holographic display).\n           The default is one, but an aperture could be as large as the input field [m x n].\n
  • scale \u2013
               Resolution factor to scale generated kernel.\n
  • samples \u2013
               When using `Impulse Response Fresnel` propagation, these sample counts along X and Y will be used to represent a rectangular aperture. The first two are for a hologram plane pixel and the last two are for an image plane pixel.\n

Returns:

  • result ( complex ) \u2013

    Final complex field (MxN).

Source code in odak/learn/wave/classical.py
def seperable_impulse_response_fresnel(field, k, distance, dx, wavelength, zero_padding = False, aperture = 1., scale = 1, samples = [20, 20, 5, 5]):\n    \"\"\"\n    A definition to calculate convolution based Fresnel approximation for beam propagation for a rectangular aperture using the seperable property.\n\n    Parameters\n    ----------\n    field            : torch.complex\n                       Complex field (MxN).\n    k                : odak.wave.wavenumber\n                       Wave number of a wave, see odak.wave.wavenumber for more.\n    distance         : float\n                       Propagation distance.\n    dx               : float\n                       Size of one single pixel in the field grid (in meters).\n    wavelength       : float\n                       Wavelength of the electric field.\n    zero_padding     : bool\n                       Zero pad in Fourier domain.\n    aperture         : torch.tensor\n                       Fourier domain aperture (e.g., pinhole in a typical holographic display).\n                       The default is one, but an aperture could be as large as input field [m x n].\n    scale            : int\n                       Resolution factor to scale generated kernel.\n    samples          : list\n                       When using `Impulse Response Fresnel` propagation, these sample counts along X and Y will be used to represent a rectangular aperture. First two is for hologram plane pixel and the last two is for image plane pixel.\n\n    Returns\n    -------\n    result           : torch.complex\n                       Final complex field (MxN).\n\n    \"\"\"\n    H = get_propagation_kernel(\n                               nu = field.shape[-2], \n                               nv = field.shape[-1], \n                               dx = dx, \n                               wavelength = wavelength, \n                               distance = distance, \n                               propagation_type = 'Seperable Impulse Response Fresnel',\n                               device = field.device,\n                               scale = scale,\n                               samples = samples\n                              )\n    if scale > 1:\n        field_amplitude = calculate_amplitude(field)\n        field_phase = calculate_phase(field)\n        field_scale_amplitude = torch.zeros(field.shape[-2] * scale, field.shape[-1] * scale, device = field.device)\n        field_scale_phase = torch.zeros_like(field_scale_amplitude)\n        field_scale_amplitude[::scale, ::scale] = field_amplitude\n        field_scale_phase[::scale, ::scale] = field_phase\n        field_scale = generate_complex_field(field_scale_amplitude, field_scale_phase)\n    else:\n        field_scale = field\n    result = custom(field_scale, H, zero_padding = zero_padding, aperture = aperture)\n    return result\n
"},{"location":"odak/learn_wave/#odak.learn.wave.classical.shift_w_double_phase","title":"shift_w_double_phase(phase, depth_shift, pixel_pitch, wavelength, propagation_type='Transfer Function Fresnel', kernel_length=4, sigma=0.5, amplitude=None)","text":"

Shift a phase-only hologram by propagating the complex hologram and applying the double phase principle. Coded following the implementation linked in the source code below and Shi, L., Li, B., Kim, C., Kellnhofer, P., & Matusik, W. (2021). Towards real-time photorealistic 3D holography with deep neural networks. Nature, 591(7849), 234-239.

Parameters:

  • phase \u2013
               Phase value of a phase-only hologram.\n
  • depth_shift \u2013
               Distance in meters.\n
  • pixel_pitch \u2013
               Pixel pitch size in meters.\n
  • wavelength \u2013
               Wavelength of light.\n
  • propagation_type (str, default: 'Transfer Function Fresnel' ) \u2013
               Beam propagation type. For more see odak.learn.wave.propagate_beam().\n
  • kernel_length \u2013
               Kernel length for the Gaussian blur kernel.\n
  • sigma \u2013
               Standard deviation for the Gaussian blur kernel.\n
  • amplitude \u2013
               Amplitude value of a complex hologram.\n
Source code in odak/learn/wave/classical.py
def shift_w_double_phase(phase, depth_shift, pixel_pitch, wavelength, propagation_type='Transfer Function Fresnel', kernel_length=4, sigma=0.5, amplitude=None):\n    \"\"\"\n    Shift a phase-only hologram by propagating the complex hologram and double phase principle. Coded following in [here](https://github.com/liangs111/tensor_holography/blob/6fdb26561a4e554136c579fa57788bb5fc3cac62/optics.py#L131-L207) and Shi, L., Li, B., Kim, C., Kellnhofer, P., & Matusik, W. (2021). Towards real-time photorealistic 3D holography with deep neural networks. Nature, 591(7849), 234-239.\n\n    Parameters\n    ----------\n    phase            : torch.tensor\n                       Phase value of a phase-only hologram.\n    depth_shift      : float\n                       Distance in meters.\n    pixel_pitch      : float\n                       Pixel pitch size in meters.\n    wavelength       : float\n                       Wavelength of light.\n    propagation_type : str\n                       Beam propagation type. For more see odak.learn.wave.propagate_beam().\n    kernel_length    : int\n                       Kernel length for the Gaussian blur kernel.\n    sigma            : float\n                       Standard deviation for the Gaussian blur kernel.\n    amplitude        : torch.tensor\n                       Amplitude value of a complex hologram.\n    \"\"\"\n    if type(amplitude) == type(None):\n        amplitude = torch.ones_like(phase)\n    hologram = generate_complex_field(amplitude, phase)\n    k = wavenumber(wavelength)\n    hologram_padded = zero_pad(hologram)\n    shifted_field_padded = propagate_beam(\n                                          hologram_padded,\n                                          k,\n                                          depth_shift,\n                                          pixel_pitch,\n                                          wavelength,\n                                          propagation_type\n                                         )\n    shifted_field = crop_center(shifted_field_padded)\n    phase_shift = torch.exp(torch.tensor([-2 * torch.pi * depth_shift / wavelength]).to(phase.device))\n    shift = torch.cos(phase_shift) + 1j * torch.sin(phase_shift)\n    shifted_complex_hologram = shifted_field * shift\n\n    if kernel_length > 0 and sigma >0:\n        blur_kernel = generate_2d_gaussian(\n                                           [kernel_length, kernel_length],\n                                           [sigma, sigma]\n                                          ).to(phase.device)\n        blur_kernel = blur_kernel.unsqueeze(0)\n        blur_kernel = blur_kernel.unsqueeze(0)\n        field_imag = torch.imag(shifted_complex_hologram)\n        field_real = torch.real(shifted_complex_hologram)\n        field_imag = field_imag.unsqueeze(0)\n        field_imag = field_imag.unsqueeze(0)\n        field_real = field_real.unsqueeze(0)\n        field_real = field_real.unsqueeze(0)\n        field_imag = torch.nn.functional.conv2d(field_imag, blur_kernel, padding='same')\n        field_real = torch.nn.functional.conv2d(field_real, blur_kernel, padding='same')\n        shifted_complex_hologram = torch.complex(field_real, field_imag)\n        shifted_complex_hologram = shifted_complex_hologram.squeeze(0)\n        shifted_complex_hologram = shifted_complex_hologram.squeeze(0)\n\n    shifted_amplitude = calculate_amplitude(shifted_complex_hologram)\n    shifted_amplitude = shifted_amplitude / torch.amax(shifted_amplitude, [0,1])\n\n    shifted_phase = 
calculate_phase(shifted_complex_hologram)\n    phase_zero_mean = shifted_phase - torch.mean(shifted_phase)\n\n    phase_offset = torch.arccos(shifted_amplitude)\n    phase_low = phase_zero_mean - phase_offset\n    phase_high = phase_zero_mean + phase_offset\n\n    phase_only = torch.zeros_like(phase)\n    phase_only[0::2, 0::2] = phase_low[0::2, 0::2]\n    phase_only[0::2, 1::2] = phase_high[0::2, 1::2]\n    phase_only[1::2, 0::2] = phase_high[1::2, 0::2]\n    phase_only[1::2, 1::2] = phase_low[1::2, 1::2]\n    return phase_only\n
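
A minimal sketch of shifting a phase-only hologram by one millimeter; the import path and the hologram used here (random phase) are assumptions for illustration.

import torch
from odak.learn.wave import shift_w_double_phase

pixel_pitch = 8e-6
wavelength = 515e-9
depth_shift = 1e-3                   # shift the plane by 1 mm

# In practice the phase would come from an optimized phase-only hologram.
phase = torch.rand(512, 512) * 2 * torch.pi

shifted_phase = shift_w_double_phase(
                                     phase,
                                     depth_shift,
                                     pixel_pitch,
                                     wavelength,
                                     propagation_type = 'Transfer Function Fresnel',
                                     kernel_length = 4,
                                     sigma = 0.5
                                    )
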
"},{"location":"odak/learn_wave/#odak.learn.wave.classical.stochastic_gradient_descent","title":"stochastic_gradient_descent(target, wavelength, distance, pixel_pitch, propagation_type='Bandlimited Angular Spectrum', n_iteration=100, loss_function=None, learning_rate=0.1)","text":"

Definition to generate phase and reconstruction from target image via stochastic gradient descent.

Parameters:

  • target \u2013
                        Target field amplitude [m x n].\n                    Keep the target values between zero and one.\n
  • wavelength \u2013
                         Wavelength of the electric field.\n
  • distance \u2013
                         Hologram plane distance with respect to the SLM plane.\n
  • pixel_pitch \u2013
                        SLM pixel pitch in meters.\n
  • propagation_type \u2013
                        Type of the propagation (see odak.learn.wave.propagate_beam()).\n
  • n_iteration \u2013
                         Number of iterations.\n
  • loss_function \u2013
                         If None, it is set to L2 (mean squared error) loss.\n
  • learning_rate \u2013
                        Learning rate.\n

Returns:

  • hologram ( Tensor ) \u2013

    Phase-only hologram as a torch array.

  • reconstruction_intensity ( Tensor ) \u2013

    Reconstruction as a torch array.

Source code in odak/learn/wave/classical.py
def stochastic_gradient_descent(target, wavelength, distance, pixel_pitch, propagation_type = 'Bandlimited Angular Spectrum', n_iteration = 100, loss_function = None, learning_rate = 0.1):\n    \"\"\"\n    Definition to generate phase and reconstruction from target image via stochastic gradient descent.\n\n    Parameters\n    ----------\n    target                    : torch.Tensor\n                                Target field amplitude [m x n].\n                                Keep the target values between zero and one.\n    wavelength                : double\n                                Set if the converted array requires gradient.\n    distance                  : double\n                                Hologram plane distance wrt SLM plane.\n    pixel_pitch               : float\n                                SLM pixel pitch in meters.\n    propagation_type          : str\n                                Type of the propagation (see odak.learn.wave.propagate_beam()).\n    n_iteration:              : int\n                                Number of iteration.\n    loss_function:            : function\n                                If none it is set to be l2 loss.\n    learning_rate             : float\n                                Learning rate.\n\n    Returns\n    -------\n    hologram                  : torch.Tensor\n                                Phase only hologram as torch array\n\n    reconstruction_intensity  : torch.Tensor\n                                Reconstruction as torch array\n\n    \"\"\"\n    phase = torch.randn_like(target, requires_grad = True)\n    k = wavenumber(wavelength)\n    optimizer = torch.optim.Adam([phase], lr = learning_rate)\n    if type(loss_function) == type(None):\n        loss_function = torch.nn.MSELoss()\n    t = tqdm(range(n_iteration), leave = False, dynamic_ncols = True)\n    for i in t:\n        optimizer.zero_grad()\n        hologram = generate_complex_field(1., phase)\n        reconstruction = propagate_beam(\n                                        hologram, \n                                        k, \n                                        distance, \n                                        pixel_pitch, \n                                        wavelength, \n                                        propagation_type, \n                                        zero_padding = [True, False, True]\n                                       )\n        reconstruction_intensity = calculate_amplitude(reconstruction) ** 2\n        loss = loss_function(reconstruction_intensity, target)\n        description = \"Loss:{:.4f}\".format(loss.item())\n        loss.backward(retain_graph = True)\n        optimizer.step()\n        t.set_description(description)\n    logging.warning(description)\n    torch.no_grad()\n    hologram = generate_complex_field(1., phase)\n    reconstruction = propagate_beam(\n                                    hologram, \n                                    k, \n                                    distance, \n                                    pixel_pitch, \n                                    wavelength, \n                                    propagation_type, \n                                    zero_padding = [True, False, True]\n                                   )\n    return hologram, reconstruction\n
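
A minimal optimization sketch; the import path, the target (random noise here), the iteration count, and the physical parameters are assumptions chosen only to show the call signature.

import torch
from odak.learn.wave import stochastic_gradient_descent, calculate_amplitude

wavelength = 515e-9
distance = 5e-3
pixel_pitch = 8e-6

# Target amplitude, kept within [0, 1].
target = torch.rand(512, 512)

hologram, reconstruction = stochastic_gradient_descent(
                                                       target,
                                                       wavelength,
                                                       distance,
                                                       pixel_pitch,
                                                       propagation_type = 'Bandlimited Angular Spectrum',
                                                       n_iteration = 50,
                                                       learning_rate = 0.1
                                                      )
reconstruction_intensity = calculate_amplitude(reconstruction) ** 2
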
"},{"location":"odak/learn_wave/#odak.learn.wave.classical.transfer_function_fresnel","title":"transfer_function_fresnel(field, k, distance, dx, wavelength, zero_padding=False, aperture=1.0)","text":"

A definition to calculate convolution-based Fresnel approximation for beam propagation.

Parameters:

  • field \u2013
               Complex field (MxN).\n
  • k \u2013
               Wave number of a wave, see odak.wave.wavenumber for more.\n
  • distance \u2013
               Propagation distance.\n
  • dx \u2013
               Size of one single pixel in the field grid (in meters).\n
  • wavelength \u2013
               Wavelength of the electric field.\n
  • zero_padding \u2013
               Zero pad in Fourier domain.\n
  • aperture \u2013
               Fourier domain aperture (e.g., pinhole in a typical holographic display).\n           The default is one, but an aperture could be as large as the input field [m x n].\n

Returns:

  • result ( complex ) \u2013

    Final complex field (MxN).

Source code in odak/learn/wave/classical.py
def transfer_function_fresnel(field, k, distance, dx, wavelength, zero_padding = False, aperture = 1.):\n    \"\"\"\n    A definition to calculate convolution based Fresnel approximation for beam propagation.\n\n    Parameters\n    ----------\n    field            : torch.complex\n                       Complex field (MxN).\n    k                : odak.wave.wavenumber\n                       Wave number of a wave, see odak.wave.wavenumber for more.\n    distance         : float\n                       Propagation distance.\n    dx               : float\n                       Size of one single pixel in the field grid (in meters).\n    wavelength       : float\n                       Wavelength of the electric field.\n    zero_padding     : bool\n                       Zero pad in Fourier domain.\n    aperture         : torch.tensor\n                       Fourier domain aperture (e.g., pinhole in a typical holographic display).\n                       The default is one, but an aperture could be as large as input field [m x n].\n\n\n    Returns\n    -------\n    result           : torch.complex\n                       Final complex field (MxN).\n\n    \"\"\"\n    H = get_propagation_kernel(\n                               nu = field.shape[-2], \n                               nv = field.shape[-1], \n                               dx = dx, \n                               wavelength = wavelength, \n                               distance = distance, \n                               propagation_type = 'Transfer Function Fresnel',\n                               device = field.device\n                              )\n    result = custom(field, H, zero_padding = zero_padding, aperture = aperture)\n    return result\n
"},{"location":"odak/learn_wave/#odak.learn.wave.lens.blazed_grating","title":"blazed_grating(nx, ny, levels=2, axis='x')","text":"

A definition to generate a blazed grating (also known as a ramp grating). For more consult de Blas, Mario Garc\u00eda, et al. \"High resolution 2D beam steerer made from cascaded 1D liquid crystal phase gratings.\" Scientific Reports 12.1 (2022): 5145 and Igasaki, Yasunori, et al. \"High efficiency electrically-addressable phase-only spatial light modulator.\" Optical Review 6 (1999): 339-344.

Parameters:

  • nx \u2013
           Size of the output along X.\n
  • ny \u2013
           Size of the output along Y.\n
  • levels \u2013
           Number of grating levels; the pattern repeats every `levels` pixels.\n
  • axis \u2013
           Axis of the blazed grating. It could be `x` or `y`.\n
Source code in odak/learn/wave/lens.py
def blazed_grating(nx, ny, levels = 2, axis = 'x'):\n    \"\"\"\n    A defininition to generate a blazed grating (also known as ramp grating). For more consult de Blas, Mario Garc\u00eda, et al. \"High resolution 2D beam steerer made from cascaded 1D liquid crystal phase gratings.\" Scientific Reports 12.1 (2022): 5145 and Igasaki, Yasunori, et al. \"High efficiency electrically-addressable phase-only spatial light modulator.\" optical review 6 (1999): 339-344.\n\n\n    Parameters\n    ----------\n    nx           : int\n                   Size of the output along X.\n    ny           : int\n                   Size of the output along Y.\n    levels       : int\n                   Number of pixels.\n    axis         : str\n                   Axis of glazed grating. It could be `x` or `y`.\n\n    \"\"\"\n    if levels < 2:\n        levels = 2\n    x = (torch.abs(torch.arange(-nx, 0)) % levels) / levels * (2 * np.pi)\n    y = (torch.abs(torch.arange(-ny, 0)) % levels) / levels * (2 * np.pi)\n    X, Y = torch.meshgrid(x, y, indexing='ij')\n    if axis == 'x':\n        blazed_grating = torch.exp(1j * X)\n    elif axis == 'y':\n        blazed_grating = torch.exp(1j * Y)\n    return blazed_grating\n
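
For example (a sketch; the grating period of 16 pixels and the random hologram are assumptions), the returned complex field can modulate a hologram to steer light away from the zeroth diffraction order:

import torch
from odak.learn.wave import blazed_grating

# A 512 x 512 ramp grating that repeats every 16 pixels along the X axis.
grating = blazed_grating(512, 512, levels = 16, axis = 'x')

# Multiply with a hologram to add the steering phase ramp.
hologram = torch.exp(1j * torch.rand(512, 512) * 2 * torch.pi)
steered_hologram = hologram * grating
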
"},{"location":"odak/learn_wave/#odak.learn.wave.lens.linear_grating","title":"linear_grating(nx, ny, every=2, add=None, axis='x')","text":"

A definition to generate a linear grating. This could also be interpreted as a two-level blazed grating. For more on blazed gratings, see the odak.learn.wave.blazed_grating() function.

Parameters:

  • nx \u2013
         Size of the output along X.\n
  • ny \u2013
         Size of the output along Y.\n
  • every \u2013
          Add the `add` value at every `every` pixels along the chosen axis.\n
  • add \u2013
         Angle to be added.\n
  • axis \u2013
          Axis, either `x`, `y` or both (`xy`).\n

Returns:

  • field ( tensor ) \u2013

    Linear grating term.

Source code in odak/learn/wave/lens.py
def linear_grating(nx, ny, every = 2, add = None, axis = 'x'):\n    \"\"\"\n    A definition to generate a linear grating. This could also be interpreted as two levels blazed grating. For more on blazed gratings see odak.learn.wave.blazed_grating() function.\n\n    Parameters\n    ----------\n    nx         : int\n                 Size of the output along X.\n    ny         : int\n                 Size of the output along Y.\n    every      : int\n                 Add the add value at every given number.\n    add        : float\n                 Angle to be added.\n    axis       : string\n                 Axis eiter X,Y or both.\n\n    Returns\n    ----------\n    field      : torch.tensor\n                 Linear grating term.\n    \"\"\"\n    if isinstance(add, type(None)):\n        add = np.pi\n    grating = torch.zeros((nx, ny), dtype=torch.complex64)\n    if axis == 'x':\n        grating[::every, :] = torch.exp(torch.tensor(1j*add))\n    if axis == 'y':\n        grating[:, ::every] = torch.exp(torch.tensor(1j*add))\n    if axis == 'xy':\n        checker = np.indices((nx, ny)).sum(axis=0) % every\n        checker = torch.from_numpy(checker)\n        checker += 1\n        checker = checker % 2\n        grating = torch.exp(1j*checker*add)\n    return grating\n
"},{"location":"odak/learn_wave/#odak.learn.wave.lens.prism_grating","title":"prism_grating(nx, ny, k, angle, dx=0.001, axis='x', phase_offset=0.0)","text":"

A definition to generate a 2D phase function that represents a prism. See Goodman's Introduction to Fourier Optics book or Engstr\u00f6m, David, et al. \"Improved beam steering accuracy of a single beam with a 1D phase-only spatial light modulator.\" Optics express 16.22 (2008): 18275-18287 for more.

Parameters:

  • nx \u2013
           Size of the output along X.\n
  • ny \u2013
           Size of the output along Y.\n
  • k \u2013
           See odak.wave.wavenumber for more.\n
  • angle \u2013
           Tilt angle of the prism in degrees.\n
  • dx \u2013
           Pixel pitch.\n
  • axis \u2013
           Axis of the prism.\n
  • phase_offset (float, default: 0.0 ) \u2013
           Phase offset in degrees. Default is zero.\n

Returns:

  • prism ( tensor ) \u2013

    Generated phase function for a prism.

Source code in odak/learn/wave/lens.py
def prism_grating(nx, ny, k, angle, dx = 0.001, axis = 'x', phase_offset = 0.):\n    \"\"\"\n    A definition to generate 2D phase function that represents a prism. See Goodman's Introduction to Fourier Optics book or Engstr\u00f6m, David, et al. \"Improved beam steering accuracy of a single beam with a 1D phase-only spatial light modulator.\" Optics express 16.22 (2008): 18275-18287. for more.\n\n    Parameters\n    ----------\n    nx           : int\n                   Size of the output along X.\n    ny           : int\n                   Size of the output along Y.\n    k            : odak.wave.wavenumber\n                   See odak.wave.wavenumber for more.\n    angle        : float\n                   Tilt angle of the prism in degrees.\n    dx           : float\n                   Pixel pitch.\n    axis         : str\n                   Axis of the prism.\n    phase_offset : float\n                   Phase offset in angles. Default is zero.\n\n    Returns\n    ----------\n    prism        : torch.tensor\n                   Generated phase function for a prism.\n    \"\"\"\n    angle = torch.deg2rad(torch.tensor([angle]))\n    phase_offset = torch.deg2rad(torch.tensor([phase_offset]))\n    x = torch.arange(0, nx) * dx\n    y = torch.arange(0, ny) * dx\n    X, Y = torch.meshgrid(x, y, indexing='ij')\n    if axis == 'y':\n        phase = k * torch.sin(angle) * Y + phase_offset\n        prism = torch.exp(-1j * phase)\n    elif axis == 'x':\n        phase = k * torch.sin(angle) * X + phase_offset\n        prism = torch.exp(-1j * phase)\n    return prism\n
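
A minimal sketch, assuming wavenumber is available from odak.learn.wave and using an arbitrary tilt angle and pixel pitch:

import torch
from odak.learn.wave import prism_grating, wavenumber

wavelength = 515e-9
k = wavenumber(wavelength)

# Phase ramp steering the beam by 0.1 degrees along the X axis on a grid
# with an 8 micrometer pixel pitch.
prism = prism_grating(512, 512, k, angle = 0.1, dx = 8e-6, axis = 'x')

# Multiply with a hologram or field to apply the tilt.
field = torch.ones(512, 512, dtype = torch.complex64)
tilted_field = field * prism
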
"},{"location":"odak/learn_wave/#odak.learn.wave.lens.quadratic_phase_function","title":"quadratic_phase_function(nx, ny, k, focal=0.4, dx=0.001, offset=[0, 0])","text":"

A definition to generate a 2D quadratic phase function, which is typically used to represent lenses.

Parameters:

  • nx \u2013
         Size of the output along X.\n
  • ny \u2013
         Size of the output along Y.\n
  • k \u2013
         See odak.wave.wavenumber for more.\n
  • focal \u2013
         Focal length of the quadratic phase function.\n
  • dx \u2013
         Pixel pitch.\n
  • offset \u2013
         Deviation from the center along X and Y axes.\n

Returns:

  • function ( tensor ) \u2013

    Generated quadratic phase function.

Source code in odak/learn/wave/lens.py
def quadratic_phase_function(nx, ny, k, focal=0.4, dx=0.001, offset=[0, 0]):\n    \"\"\" \n    A definition to generate 2D quadratic phase function, which is typically use to represent lenses.\n\n    Parameters\n    ----------\n    nx         : int\n                 Size of the output along X.\n    ny         : int\n                 Size of the output along Y.\n    k          : odak.wave.wavenumber\n                 See odak.wave.wavenumber for more.\n    focal      : float\n                 Focal length of the quadratic phase function.\n    dx         : float\n                 Pixel pitch.\n    offset     : list\n                 Deviation from the center along X and Y axes.\n\n    Returns\n    -------\n    function   : torch.tensor\n                 Generated quadratic phase function.\n    \"\"\"\n    size = [nx, ny]\n    x = torch.linspace(-size[0] * dx / 2, size[0] * dx / 2, size[0]) - offset[1] * dx\n    y = torch.linspace(-size[1] * dx / 2, size[1] * dx / 2, size[1]) - offset[0] * dx\n    X, Y = torch.meshgrid(x, y, indexing='ij')\n    Z = X**2 + Y**2\n    qwf = torch.exp(-0.5j * k / focal * Z)\n    return qwf\n
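
As an illustration (a sketch with an assumed focal length, pitch, and resolution), the quadratic phase term acts like a thin lens when multiplied with a field:

import torch
from odak.learn.wave import quadratic_phase_function, wavenumber

wavelength = 515e-9
k = wavenumber(wavelength)

# Thin lens phase with a 10 cm focal length on a 512 x 512 grid with an
# 8 micrometer pixel pitch, centered on the grid.
lens = quadratic_phase_function(512, 512, k, focal = 0.1, dx = 8e-6, offset = [0, 0])

# Multiplying a field with this term focuses it towards the focal distance.
field = torch.ones(512, 512, dtype = torch.complex64)
field_after_lens = field * lens
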
"},{"location":"odak/learn_wave/#odak.learn.wave.loss.multiplane_loss","title":"multiplane_loss","text":"

Loss function for computing loss in multiplanar images. Unlike previous methods, this loss function accounts for defocused parts of an image.

Source code in odak/learn/wave/loss.py
class multiplane_loss():\n    \"\"\"\n    Loss function for computing loss in multiplanar images. Unlike, previous methods, this loss function accounts for defocused parts of an image.\n    \"\"\"\n\n    def __init__(self, target_image, target_depth, blur_ratio = 0.25, \n                 target_blur_size = 10, number_of_planes = 4, weights = [1., 2.1, 0.6], \n                 multiplier = 1., scheme = 'defocus', reduction = 'mean', device = torch.device('cpu')):\n        \"\"\"\n        Parameters\n        ----------\n        target_image      : torch.tensor\n                            Color target image [3 x m x n].\n        target_depth      : torch.tensor\n                            Monochrome target depth, same resolution as target_image.\n        target_blur_size  : int\n                            Maximum target blur size.\n        blur_ratio        : float\n                            Blur ratio, a value between zero and one.\n        number_of_planes  : int\n                            Number of planes.\n        weights           : list\n                            Weights of the loss function.\n        multiplier        : float\n                            Multiplier to multipy with targets.\n        scheme            : str\n                            The type of the loss, `naive` without defocus or `defocus` with defocus.\n        reduction         : str\n                            Reduction can either be 'mean', 'none' or 'sum'. For more see: https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss\n        device            : torch.device\n                            Device to be used (e.g., cuda, cpu, opencl).\n        \"\"\"\n        self.device = device\n        self.target_image     = target_image.float().to(self.device)\n        self.target_depth     = target_depth.float().to(self.device)\n        self.target_blur_size = target_blur_size\n        if self.target_blur_size % 2 == 0:\n            self.target_blur_size += 1\n        self.number_of_planes = number_of_planes\n        self.multiplier       = multiplier\n        self.weights          = weights\n        self.reduction        = reduction\n        self.blur_ratio       = blur_ratio\n        self.set_targets()\n        if scheme == 'defocus':\n            self.add_defocus_blur()\n        self.loss_function = torch.nn.MSELoss(reduction = self.reduction)\n\n    def get_targets(self):\n        \"\"\"\n        Returns\n        -------\n        targets           : torch.tensor\n                            Returns a copy of the targets.\n        target_depth      : torch.tensor\n                            Returns a copy of the normalized quantized depth map.\n\n        \"\"\"\n        divider = self.number_of_planes - 1\n        if divider == 0:\n            divider = 1\n        return self.targets.detach().clone(), self.focus_target.detach().clone(), self.target_depth.detach().clone() / divider\n\n\n    def set_targets(self):\n        \"\"\"\n        Internal function for slicing the depth into planes without considering defocus. 
Users can query the results with get_targets() within the same class.\n        \"\"\"\n        self.target_depth = self.target_depth * (self.number_of_planes - 1)\n        self.target_depth = torch.round(self.target_depth, decimals = 0)\n        self.targets      = torch.zeros(\n                                        self.number_of_planes,\n                                        self.target_image.shape[0],\n                                        self.target_image.shape[1],\n                                        self.target_image.shape[2],\n                                        requires_grad = False,\n                                        device = self.device\n                                       )\n        self.focus_target = torch.zeros_like(self.target_image, requires_grad = False)\n        self.masks        = torch.zeros_like(self.targets)\n        for i in range(self.number_of_planes):\n            for ch in range(self.target_image.shape[0]):\n                mask_zeros = torch.zeros_like(self.target_image[ch], dtype = torch.int)\n                mask_ones = torch.ones_like(self.target_image[ch], dtype = torch.int)\n                mask = torch.where(self.target_depth == i, mask_ones, mask_zeros)\n                new_target = self.target_image[ch] * mask\n                self.focus_target = self.focus_target + new_target.squeeze(0).squeeze(0).detach().clone()\n                self.targets[i, ch] = new_target.squeeze(0).squeeze(0)\n                self.masks[i, ch] = mask.detach().clone() \n\n\n    def add_defocus_blur(self):\n        \"\"\"\n        Internal function for adding defocus blur to the multiplane targets. Users can query the results with get_targets() within the same class.\n        \"\"\"\n        kernel_length = [self.target_blur_size, self.target_blur_size ]\n        for ch in range(self.target_image.shape[0]):\n            targets_cache = self.targets[:, ch].detach().clone()\n            target = torch.sum(targets_cache, axis = 0)\n            for i in range(self.number_of_planes):\n                defocus = torch.zeros_like(targets_cache[i])\n                for j in range(self.number_of_planes):\n                    nsigma = [int(abs(i - j) * self.blur_ratio), int(abs(i -j) * self.blur_ratio)]\n                    if torch.sum(targets_cache[j]) > 0:\n                        if i == j:\n                            nsigma = [0., 0.]\n                        kernel = generate_2d_gaussian(kernel_length, nsigma).to(self.device)\n                        kernel = kernel / torch.sum(kernel)\n                        kernel = kernel.unsqueeze(0).unsqueeze(0)\n                        target_current = target.detach().clone().unsqueeze(0).unsqueeze(0)\n                        defocus_plane = torch.nn.functional.conv2d(target_current, kernel, padding = 'same')\n                        defocus_plane = defocus_plane.view(defocus_plane.shape[-2], defocus_plane.shape[-1])\n                        defocus = defocus + defocus_plane * torch.abs(self.masks[j, ch])\n                self.targets[i, ch] = defocus\n        self.targets = self.targets.detach().clone() * self.multiplier\n\n\n    def __call__(self, image, target, plane_id = None):\n        \"\"\"\n        Calculates the multiplane loss against a given target.\n\n        Parameters\n        ----------\n        image         : torch.tensor\n                        Image to compare with a target [3 x m x n].\n        target        : torch.tensor\n                        Target image for comparison [3 x m x n].\n        
plane_id      : int\n                        Number of the plane under test.\n\n        Returns\n        -------\n        loss          : torch.tensor\n                        Computed loss.\n        \"\"\"\n        l2 = self.weights[0] * self.loss_function(image, target)\n        if isinstance(plane_id, type(None)):\n            mask = self.masks\n        else:\n            mask= self.masks[plane_id, :]\n        l2_mask = self.weights[1] * self.loss_function(image * mask, target * mask)\n        l2_cor = self.weights[2] * self.loss_function(image * target, target * target)\n        loss = l2 + l2_mask + l2_cor\n        return loss\n
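
A minimal sketch of how this class could be used during hologram optimization; the import path odak.learn.wave, the random target image and depth map, and the plane index are assumptions for illustration.

import torch
from odak.learn.wave import multiplane_loss

device = torch.device('cpu')

# Color target image [3 x m x n] and a monochrome depth map in [0, 1].
target_image = torch.rand(3, 256, 256)
target_depth = torch.rand(256, 256)

loss_function = multiplane_loss(
                                target_image,
                                target_depth,
                                number_of_planes = 4,
                                scheme = 'defocus',
                                device = device
                               )
targets, focus_target, depth = loss_function.get_targets()

# Compare a reconstruction of plane zero against its defocus-aware target.
reconstruction = torch.rand(3, 256, 256, device = device)
loss = loss_function(reconstruction, targets[0], plane_id = 0)
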
"},{"location":"odak/learn_wave/#odak.learn.wave.loss.multiplane_loss.__call__","title":"__call__(image, target, plane_id=None)","text":"

Calculates the multiplane loss against a given target.

Parameters:

  • image \u2013
            Image to compare with a target [3 x m x n].\n
  • target \u2013
            Target image for comparison [3 x m x n].\n
  • plane_id \u2013
            Number of the plane under test.\n

Returns:

  • loss ( tensor ) \u2013

    Computed loss.

Source code in odak/learn/wave/loss.py
def __call__(self, image, target, plane_id = None):\n    \"\"\"\n    Calculates the multiplane loss against a given target.\n\n    Parameters\n    ----------\n    image         : torch.tensor\n                    Image to compare with a target [3 x m x n].\n    target        : torch.tensor\n                    Target image for comparison [3 x m x n].\n    plane_id      : int\n                    Number of the plane under test.\n\n    Returns\n    -------\n    loss          : torch.tensor\n                    Computed loss.\n    \"\"\"\n    l2 = self.weights[0] * self.loss_function(image, target)\n    if isinstance(plane_id, type(None)):\n        mask = self.masks\n    else:\n        mask= self.masks[plane_id, :]\n    l2_mask = self.weights[1] * self.loss_function(image * mask, target * mask)\n    l2_cor = self.weights[2] * self.loss_function(image * target, target * target)\n    loss = l2 + l2_mask + l2_cor\n    return loss\n
"},{"location":"odak/learn_wave/#odak.learn.wave.loss.multiplane_loss.__init__","title":"__init__(target_image, target_depth, blur_ratio=0.25, target_blur_size=10, number_of_planes=4, weights=[1.0, 2.1, 0.6], multiplier=1.0, scheme='defocus', reduction='mean', device=torch.device('cpu'))","text":"

Parameters:

  • target_image \u2013
                Color target image [3 x m x n].\n
  • target_depth \u2013
                Monochrome target depth, same resolution as target_image.\n
  • target_blur_size \u2013
                Maximum target blur size.\n
  • blur_ratio \u2013
                Blur ratio, a value between zero and one.\n
  • number_of_planes \u2013
                Number of planes.\n
  • weights \u2013
                Weights of the loss function.\n
  • multiplier \u2013
                 Multiplier to multiply the targets with.\n
  • scheme \u2013
                The type of the loss, `naive` without defocus or `defocus` with defocus.\n
  • reduction \u2013
                Reduction can either be 'mean', 'none' or 'sum'. For more see: https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss\n
  • device \u2013
                Device to be used (e.g., cuda, cpu, opencl).\n
Source code in odak/learn/wave/loss.py
def __init__(self, target_image, target_depth, blur_ratio = 0.25, \n             target_blur_size = 10, number_of_planes = 4, weights = [1., 2.1, 0.6], \n             multiplier = 1., scheme = 'defocus', reduction = 'mean', device = torch.device('cpu')):\n    \"\"\"\n    Parameters\n    ----------\n    target_image      : torch.tensor\n                        Color target image [3 x m x n].\n    target_depth      : torch.tensor\n                        Monochrome target depth, same resolution as target_image.\n    target_blur_size  : int\n                        Maximum target blur size.\n    blur_ratio        : float\n                        Blur ratio, a value between zero and one.\n    number_of_planes  : int\n                        Number of planes.\n    weights           : list\n                        Weights of the loss function.\n    multiplier        : float\n                        Multiplier to multipy with targets.\n    scheme            : str\n                        The type of the loss, `naive` without defocus or `defocus` with defocus.\n    reduction         : str\n                        Reduction can either be 'mean', 'none' or 'sum'. For more see: https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss\n    device            : torch.device\n                        Device to be used (e.g., cuda, cpu, opencl).\n    \"\"\"\n    self.device = device\n    self.target_image     = target_image.float().to(self.device)\n    self.target_depth     = target_depth.float().to(self.device)\n    self.target_blur_size = target_blur_size\n    if self.target_blur_size % 2 == 0:\n        self.target_blur_size += 1\n    self.number_of_planes = number_of_planes\n    self.multiplier       = multiplier\n    self.weights          = weights\n    self.reduction        = reduction\n    self.blur_ratio       = blur_ratio\n    self.set_targets()\n    if scheme == 'defocus':\n        self.add_defocus_blur()\n    self.loss_function = torch.nn.MSELoss(reduction = self.reduction)\n
"},{"location":"odak/learn_wave/#odak.learn.wave.loss.multiplane_loss.add_defocus_blur","title":"add_defocus_blur()","text":"

Internal function for adding defocus blur to the multiplane targets. Users can query the results with get_targets() within the same class.

Source code in odak/learn/wave/loss.py
def add_defocus_blur(self):\n    \"\"\"\n    Internal function for adding defocus blur to the multiplane targets. Users can query the results with get_targets() within the same class.\n    \"\"\"\n    kernel_length = [self.target_blur_size, self.target_blur_size ]\n    for ch in range(self.target_image.shape[0]):\n        targets_cache = self.targets[:, ch].detach().clone()\n        target = torch.sum(targets_cache, axis = 0)\n        for i in range(self.number_of_planes):\n            defocus = torch.zeros_like(targets_cache[i])\n            for j in range(self.number_of_planes):\n                nsigma = [int(abs(i - j) * self.blur_ratio), int(abs(i -j) * self.blur_ratio)]\n                if torch.sum(targets_cache[j]) > 0:\n                    if i == j:\n                        nsigma = [0., 0.]\n                    kernel = generate_2d_gaussian(kernel_length, nsigma).to(self.device)\n                    kernel = kernel / torch.sum(kernel)\n                    kernel = kernel.unsqueeze(0).unsqueeze(0)\n                    target_current = target.detach().clone().unsqueeze(0).unsqueeze(0)\n                    defocus_plane = torch.nn.functional.conv2d(target_current, kernel, padding = 'same')\n                    defocus_plane = defocus_plane.view(defocus_plane.shape[-2], defocus_plane.shape[-1])\n                    defocus = defocus + defocus_plane * torch.abs(self.masks[j, ch])\n            self.targets[i, ch] = defocus\n    self.targets = self.targets.detach().clone() * self.multiplier\n
"},{"location":"odak/learn_wave/#odak.learn.wave.loss.multiplane_loss.get_targets","title":"get_targets()","text":"

Returns:

  • targets ( tensor ) \u2013

    Returns a copy of the targets.

  • focus_target ( tensor ) \u2013

    Returns a copy of the focus target.

  • target_depth ( tensor ) \u2013

    Returns a copy of the normalized quantized depth map.

Source code in odak/learn/wave/loss.py
def get_targets(self):\n    \"\"\"\n    Returns\n    -------\n    targets           : torch.tensor\n                        Returns a copy of the targets.\n    target_depth      : torch.tensor\n                        Returns a copy of the normalized quantized depth map.\n\n    \"\"\"\n    divider = self.number_of_planes - 1\n    if divider == 0:\n        divider = 1\n    return self.targets.detach().clone(), self.focus_target.detach().clone(), self.target_depth.detach().clone() / divider\n
"},{"location":"odak/learn_wave/#odak.learn.wave.loss.multiplane_loss.set_targets","title":"set_targets()","text":"

Internal function for slicing the depth into planes without considering defocus. Users can query the results with get_targets() within the same class.

Source code in odak/learn/wave/loss.py
def set_targets(self):\n    \"\"\"\n    Internal function for slicing the depth into planes without considering defocus. Users can query the results with get_targets() within the same class.\n    \"\"\"\n    self.target_depth = self.target_depth * (self.number_of_planes - 1)\n    self.target_depth = torch.round(self.target_depth, decimals = 0)\n    self.targets      = torch.zeros(\n                                    self.number_of_planes,\n                                    self.target_image.shape[0],\n                                    self.target_image.shape[1],\n                                    self.target_image.shape[2],\n                                    requires_grad = False,\n                                    device = self.device\n                                   )\n    self.focus_target = torch.zeros_like(self.target_image, requires_grad = False)\n    self.masks        = torch.zeros_like(self.targets)\n    for i in range(self.number_of_planes):\n        for ch in range(self.target_image.shape[0]):\n            mask_zeros = torch.zeros_like(self.target_image[ch], dtype = torch.int)\n            mask_ones = torch.ones_like(self.target_image[ch], dtype = torch.int)\n            mask = torch.where(self.target_depth == i, mask_ones, mask_zeros)\n            new_target = self.target_image[ch] * mask\n            self.focus_target = self.focus_target + new_target.squeeze(0).squeeze(0).detach().clone()\n            self.targets[i, ch] = new_target.squeeze(0).squeeze(0)\n            self.masks[i, ch] = mask.detach().clone() \n
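A minimal usage sketch for multiplane_loss, not taken from the Odak documentation: it assumes Odak and PyTorch are installed and substitutes random tensors for a real RGB target and depth map; the constructor arguments and the three tensors returned by get_targets() follow the signature and source shown above.

import torch
from odak.learn.wave.loss import multiplane_loss

# Random stand-ins for a real [3 x m x n] color target and its [m x n] depth map.
target_image = torch.rand(3, 256, 256)
target_depth = torch.rand(256, 256)

loss_function = multiplane_loss(
                                target_image,
                                target_depth,
                                number_of_planes = 4,
                                scheme = 'defocus',
                                device = torch.device('cpu')
                               )
# get_targets() returns the per-plane targets, the focus target and the
# normalized quantized depth map.
targets, focus_target, depth = loss_function.get_targets()
print(targets.shape)  # torch.Size([4, 3, 256, 256])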
"},{"location":"odak/learn_wave/#odak.learn.wave.loss.perceptual_multiplane_loss","title":"perceptual_multiplane_loss","text":"

Perceptual loss function for computing loss in multiplanar images. Unlike previous methods, this loss function accounts for defocused parts of an image.

Source code in odak/learn/wave/loss.py
class perceptual_multiplane_loss():\n    \"\"\"\n    Perceptual loss function for computing loss in multiplanar images. Unlike, previous methods, this loss function accounts for defocused parts of an image.\n    \"\"\"\n\n    def __init__(self, target_image, target_depth, blur_ratio = 0.25, \n                 target_blur_size = 10, number_of_planes = 4, multiplier = 1., scheme = 'defocus', \n                 base_loss_weights = {'base_l2_loss': 1., 'loss_l2_mask': 1., 'loss_l2_cor': 1., 'base_l1_loss': 1., 'loss_l1_mask': 1., 'loss_l1_cor': 1.},\n                 additional_loss_weights = {'cvvdp': 1.}, reduction = 'mean', return_components = False, device = torch.device('cpu')):\n        \"\"\"\n        Parameters\n        ----------\n        target_image            : torch.tensor\n                                    Color target image [3 x m x n].\n        target_depth            : torch.tensor\n                                    Monochrome target depth, same resolution as target_image.\n        target_blur_size        : int\n                                    Maximum target blur size.\n        blur_ratio              : float\n                                    Blur ratio, a value between zero and one.\n        number_of_planes        : int\n                                    Number of planes.\n        multiplier              : float\n                                    Multiplier to multipy with targets.\n        scheme                  : str\n                                    The type of the loss, `naive` without defocus or `defocus` with defocus.\n        base_loss_weights       : list\n                                    Weights of the base loss functions. Default is {'base_l2_loss': 1., 'loss_l2_mask': 1., 'loss_l2_cor': 1., 'base_l1_loss': 1., 'loss_l1_mask': 1., 'loss_l1_cor': 1.}.\n        additional_loss_weights : dict\n                                    Additional loss terms and their weights (e.g., {'cvvdp': 1.}). Supported loss terms are 'cvvdp', 'fvvdp', 'lpips', 'psnr', 'ssim', 'msssim'.\n        reduction               : str\n                                    Reduction can either be 'mean', 'none' or 'sum'. For more see: https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss\n        return_components       : bool\n                                    If True (False by default), returns the components of the loss as a dict.\n        device                  : torch.device\n                                    Device to be used (e.g., cuda, cpu, opencl).\n        \"\"\"\n        self.device = device\n        self.target_image     = target_image.float().to(self.device)\n        self.target_depth     = target_depth.float().to(self.device)\n        self.target_blur_size = target_blur_size\n        if self.target_blur_size % 2 == 0:\n            self.target_blur_size += 1\n        self.number_of_planes = number_of_planes\n        self.multiplier       = multiplier\n        self.reduction        = reduction\n        if self.reduction == 'none' and len(list(additional_loss_weights.keys())) > 0:\n            logging.warning(\"Reduction cannot be 'none' for additional loss functions. 
Changing reduction to 'mean'.\")\n            self.reduction = 'mean'\n        self.blur_ratio       = blur_ratio\n        self.set_targets()\n        if scheme == 'defocus':\n            self.add_defocus_blur()\n        self.base_loss_weights = base_loss_weights\n        self.additional_loss_weights = additional_loss_weights\n        self.return_components = return_components\n        self.l1_loss_fn = torch.nn.L1Loss(reduction = self.reduction)\n        self.l2_loss_fn = torch.nn.MSELoss(reduction = self.reduction)\n        for key in self.additional_loss_weights.keys():\n            if key == 'cvvdp':\n                self.cvvdp = CVVDP()\n            if key == 'fvvdp':\n                self.fvvdp = FVVDP()\n            if key == 'lpips':\n                self.lpips = LPIPS()\n            if key == 'psnr':\n                self.psnr = PSNR()\n            if key == 'ssim':\n                self.ssim = SSIM()\n            if key == 'msssim':\n                self.msssim = MSSSIM()\n\n    def get_targets(self):\n        \"\"\"\n        Returns\n        -------\n        targets           : torch.tensor\n                            Returns a copy of the targets.\n        target_depth      : torch.tensor\n                            Returns a copy of the normalized quantized depth map.\n\n        \"\"\"\n        divider = self.number_of_planes - 1\n        if divider == 0:\n            divider = 1\n        return self.targets.detach().clone(), self.focus_target.detach().clone(), self.target_depth.detach().clone() / divider\n\n\n    def set_targets(self):\n        \"\"\"\n        Internal function for slicing the depth into planes without considering defocus. Users can query the results with get_targets() within the same class.\n        \"\"\"\n        self.target_depth = self.target_depth * (self.number_of_planes - 1)\n        self.target_depth = torch.round(self.target_depth, decimals = 0)\n        self.targets      = torch.zeros(\n                                        self.number_of_planes,\n                                        self.target_image.shape[0],\n                                        self.target_image.shape[1],\n                                        self.target_image.shape[2],\n                                        requires_grad = False,\n                                        device = self.device\n                                       )\n        self.focus_target = torch.zeros_like(self.target_image, requires_grad = False)\n        self.masks        = torch.zeros_like(self.targets)\n        for i in range(self.number_of_planes):\n            for ch in range(self.target_image.shape[0]):\n                mask_zeros = torch.zeros_like(self.target_image[ch], dtype = torch.int)\n                mask_ones = torch.ones_like(self.target_image[ch], dtype = torch.int)\n                mask = torch.where(self.target_depth == i, mask_ones, mask_zeros)\n                new_target = self.target_image[ch] * mask\n                self.focus_target = self.focus_target + new_target.squeeze(0).squeeze(0).detach().clone()\n                self.targets[i, ch] = new_target.squeeze(0).squeeze(0)\n                self.masks[i, ch] = mask.detach().clone() \n\n\n    def add_defocus_blur(self):\n        \"\"\"\n        Internal function for adding defocus blur to the multiplane targets. 
Users can query the results with get_targets() within the same class.\n        \"\"\"\n        kernel_length = [self.target_blur_size, self.target_blur_size ]\n        for ch in range(self.target_image.shape[0]):\n            targets_cache = self.targets[:, ch].detach().clone()\n            target = torch.sum(targets_cache, axis = 0)\n            for i in range(self.number_of_planes):\n                defocus = torch.zeros_like(targets_cache[i])\n                for j in range(self.number_of_planes):\n                    nsigma = [int(abs(i - j) * self.blur_ratio), int(abs(i -j) * self.blur_ratio)]\n                    if torch.sum(targets_cache[j]) > 0:\n                        if i == j:\n                            nsigma = [0., 0.]\n                        kernel = generate_2d_gaussian(kernel_length, nsigma).to(self.device)\n                        kernel = kernel / torch.sum(kernel)\n                        kernel = kernel.unsqueeze(0).unsqueeze(0)\n                        target_current = target.detach().clone().unsqueeze(0).unsqueeze(0)\n                        defocus_plane = torch.nn.functional.conv2d(target_current, kernel, padding = 'same')\n                        defocus_plane = defocus_plane.view(defocus_plane.shape[-2], defocus_plane.shape[-1])\n                        defocus = defocus + defocus_plane * torch.abs(self.masks[j, ch])\n                self.targets[i, ch] = defocus\n        self.targets = self.targets.detach().clone() * self.multiplier\n\n\n    def __call__(self, image, target, plane_id = None):\n        \"\"\"\n        Calculates the multiplane loss against a given target.\n\n        Parameters\n        ----------\n        image         : torch.tensor\n                        Image to compare with a target [3 x m x n].\n        target        : torch.tensor\n                        Target image for comparison [3 x m x n].\n        plane_id      : int\n                        Number of the plane under test.\n\n        Returns\n        -------\n        loss          : torch.tensor\n                        Computed loss.\n        \"\"\"\n        loss_components = {}\n        if isinstance(plane_id, type(None)):\n            mask = self.masks\n        else:\n            mask= self.masks[plane_id, :]\n        l2 = self.base_loss_weights['base_l2_loss'] * self.l2_loss_fn(image, target)\n        l2_mask = self.base_loss_weights['loss_l2_mask'] * self.l2_loss_fn(image * mask, target * mask)\n        l2_cor = self.base_loss_weights['loss_l2_cor'] * self.l2_loss_fn(image * target, target * target)\n        loss_components['l2'] = l2\n        loss_components['l2_mask'] = l2_mask\n        loss_components['l2_cor'] = l2_cor\n\n        l1 = self.base_loss_weights['base_l1_loss'] * self.l1_loss_fn(image, target)\n        l1_mask = self.base_loss_weights['loss_l1_mask'] * self.l1_loss_fn(image * mask, target * mask)\n        l1_cor = self.base_loss_weights['loss_l1_cor'] * self.l1_loss_fn(image * target, target * target)\n        loss_components['l1'] = l1\n        loss_components['l1_mask'] = l1_mask\n        loss_components['l1_cor'] = l1_cor\n\n        for key in self.additional_loss_weights.keys():\n            if key == 'cvvdp':\n                loss_cvvdp = self.additional_loss_weights['cvvdp'] * self.cvvdp(image, target)\n                loss_components['cvvdp'] = loss_cvvdp\n            if key == 'fvvdp':\n                loss_fvvdp = self.additional_loss_weights['fvvdp'] * self.fvvdp(image, target)\n                loss_components['fvvdp'] = loss_fvvdp\n            if 
key == 'lpips':\n                loss_lpips = self.additional_loss_weights['lpips'] * self.lpips(image, target)\n                loss_components['lpips'] = loss_lpips\n            if key == 'psnr':\n                loss_psnr = self.additional_loss_weights['psnr'] * self.psnr(image, target)\n                loss_components['psnr'] = loss_psnr\n            if key == 'ssim':\n                loss_ssim = self.additional_loss_weights['ssim'] * self.ssim(image, target)\n                loss_components['ssim'] = loss_ssim\n            if key == 'msssim':\n                loss_msssim = self.additional_loss_weights['msssim'] * self.msssim(image, target)\n                loss_components['msssim'] = loss_msssim\n\n        loss = torch.sum(torch.stack(list(loss_components.values())), dim = 0)\n\n        if self.return_components:\n            return loss, loss_components\n        return loss\n
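A usage sketch for perceptual_multiplane_loss under the same assumptions as above (random tensors in place of real data, Odak and PyTorch installed); additional_loss_weights is left empty here so that none of the optional perceptual metrics (CVVDP, FVVDP, LPIPS, PSNR, SSIM, MSSSIM) and their extra dependencies are needed.

import torch
from odak.learn.wave.loss import perceptual_multiplane_loss

target_image = torch.rand(3, 256, 256)
target_depth = torch.rand(256, 256)

loss_function = perceptual_multiplane_loss(
                                           target_image,
                                           target_depth,
                                           number_of_planes = 4,
                                           additional_loss_weights = {},
                                           return_components = True
                                          )
targets, focus_target, depth = loss_function.get_targets()
reconstruction = torch.rand_like(targets[0])  # stand-in for a simulated reconstruction
loss, components = loss_function(reconstruction, targets[0], plane_id = 0)
print(components.keys())  # l2, l2_mask, l2_cor, l1, l1_mask, l1_cor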
"},{"location":"odak/learn_wave/#odak.learn.wave.loss.perceptual_multiplane_loss.__call__","title":"__call__(image, target, plane_id=None)","text":"

Calculates the multiplane loss against a given target.

Parameters:

  • image \u2013
            Image to compare with a target [3 x m x n].\n
  • target \u2013
            Target image for comparison [3 x m x n].\n
  • plane_id \u2013
            Number of the plane under test.\n

Returns:

  • loss ( tensor ) \u2013

    Computed loss.

Source code in odak/learn/wave/loss.py
def __call__(self, image, target, plane_id = None):\n    \"\"\"\n    Calculates the multiplane loss against a given target.\n\n    Parameters\n    ----------\n    image         : torch.tensor\n                    Image to compare with a target [3 x m x n].\n    target        : torch.tensor\n                    Target image for comparison [3 x m x n].\n    plane_id      : int\n                    Number of the plane under test.\n\n    Returns\n    -------\n    loss          : torch.tensor\n                    Computed loss.\n    \"\"\"\n    loss_components = {}\n    if isinstance(plane_id, type(None)):\n        mask = self.masks\n    else:\n        mask= self.masks[plane_id, :]\n    l2 = self.base_loss_weights['base_l2_loss'] * self.l2_loss_fn(image, target)\n    l2_mask = self.base_loss_weights['loss_l2_mask'] * self.l2_loss_fn(image * mask, target * mask)\n    l2_cor = self.base_loss_weights['loss_l2_cor'] * self.l2_loss_fn(image * target, target * target)\n    loss_components['l2'] = l2\n    loss_components['l2_mask'] = l2_mask\n    loss_components['l2_cor'] = l2_cor\n\n    l1 = self.base_loss_weights['base_l1_loss'] * self.l1_loss_fn(image, target)\n    l1_mask = self.base_loss_weights['loss_l1_mask'] * self.l1_loss_fn(image * mask, target * mask)\n    l1_cor = self.base_loss_weights['loss_l1_cor'] * self.l1_loss_fn(image * target, target * target)\n    loss_components['l1'] = l1\n    loss_components['l1_mask'] = l1_mask\n    loss_components['l1_cor'] = l1_cor\n\n    for key in self.additional_loss_weights.keys():\n        if key == 'cvvdp':\n            loss_cvvdp = self.additional_loss_weights['cvvdp'] * self.cvvdp(image, target)\n            loss_components['cvvdp'] = loss_cvvdp\n        if key == 'fvvdp':\n            loss_fvvdp = self.additional_loss_weights['fvvdp'] * self.fvvdp(image, target)\n            loss_components['fvvdp'] = loss_fvvdp\n        if key == 'lpips':\n            loss_lpips = self.additional_loss_weights['lpips'] * self.lpips(image, target)\n            loss_components['lpips'] = loss_lpips\n        if key == 'psnr':\n            loss_psnr = self.additional_loss_weights['psnr'] * self.psnr(image, target)\n            loss_components['psnr'] = loss_psnr\n        if key == 'ssim':\n            loss_ssim = self.additional_loss_weights['ssim'] * self.ssim(image, target)\n            loss_components['ssim'] = loss_ssim\n        if key == 'msssim':\n            loss_msssim = self.additional_loss_weights['msssim'] * self.msssim(image, target)\n            loss_components['msssim'] = loss_msssim\n\n    loss = torch.sum(torch.stack(list(loss_components.values())), dim = 0)\n\n    if self.return_components:\n        return loss, loss_components\n    return loss\n
"},{"location":"odak/learn_wave/#odak.learn.wave.loss.perceptual_multiplane_loss.__init__","title":"__init__(target_image, target_depth, blur_ratio=0.25, target_blur_size=10, number_of_planes=4, multiplier=1.0, scheme='defocus', base_loss_weights={'base_l2_loss': 1.0, 'loss_l2_mask': 1.0, 'loss_l2_cor': 1.0, 'base_l1_loss': 1.0, 'loss_l1_mask': 1.0, 'loss_l1_cor': 1.0}, additional_loss_weights={'cvvdp': 1.0}, reduction='mean', return_components=False, device=torch.device('cpu'))","text":"

Parameters:

  • target_image \u2013
                        Color target image [3 x m x n].\n
  • target_depth \u2013
                        Monochrome target depth, same resolution as target_image.\n
  • target_blur_size \u2013
                        Maximum target blur size.\n
  • blur_ratio \u2013
                        Blur ratio, a value between zero and one.\n
  • number_of_planes \u2013
                        Number of planes.\n
  • multiplier \u2013
                        Multiplier to multiply with targets.\n
  • scheme \u2013
                        The type of the loss, `naive` without defocus or `defocus` with defocus.\n
  • base_loss_weights \u2013
                        Weights of the base loss functions. Default is {'base_l2_loss': 1., 'loss_l2_mask': 1., 'loss_l2_cor': 1., 'base_l1_loss': 1., 'loss_l1_mask': 1., 'loss_l1_cor': 1.}.\n
  • additional_loss_weights (dict, default: {'cvvdp': 1.0} ) \u2013
                        Additional loss terms and their weights (e.g., {'cvvdp': 1.}). Supported loss terms are 'cvvdp', 'fvvdp', 'lpips', 'psnr', 'ssim', 'msssim'.\n
  • reduction \u2013
                        Reduction can either be 'mean', 'none' or 'sum'. For more see: https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss\n
  • return_components \u2013
                        If True (False by default), returns the components of the loss as a dict.\n
  • device \u2013
                        Device to be used (e.g., cuda, cpu, opencl).\n
Source code in odak/learn/wave/loss.py
def __init__(self, target_image, target_depth, blur_ratio = 0.25, \n             target_blur_size = 10, number_of_planes = 4, multiplier = 1., scheme = 'defocus', \n             base_loss_weights = {'base_l2_loss': 1., 'loss_l2_mask': 1., 'loss_l2_cor': 1., 'base_l1_loss': 1., 'loss_l1_mask': 1., 'loss_l1_cor': 1.},\n             additional_loss_weights = {'cvvdp': 1.}, reduction = 'mean', return_components = False, device = torch.device('cpu')):\n    \"\"\"\n    Parameters\n    ----------\n    target_image            : torch.tensor\n                                Color target image [3 x m x n].\n    target_depth            : torch.tensor\n                                Monochrome target depth, same resolution as target_image.\n    target_blur_size        : int\n                                Maximum target blur size.\n    blur_ratio              : float\n                                Blur ratio, a value between zero and one.\n    number_of_planes        : int\n                                Number of planes.\n    multiplier              : float\n                                Multiplier to multipy with targets.\n    scheme                  : str\n                                The type of the loss, `naive` without defocus or `defocus` with defocus.\n    base_loss_weights       : list\n                                Weights of the base loss functions. Default is {'base_l2_loss': 1., 'loss_l2_mask': 1., 'loss_l2_cor': 1., 'base_l1_loss': 1., 'loss_l1_mask': 1., 'loss_l1_cor': 1.}.\n    additional_loss_weights : dict\n                                Additional loss terms and their weights (e.g., {'cvvdp': 1.}). Supported loss terms are 'cvvdp', 'fvvdp', 'lpips', 'psnr', 'ssim', 'msssim'.\n    reduction               : str\n                                Reduction can either be 'mean', 'none' or 'sum'. For more see: https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss\n    return_components       : bool\n                                If True (False by default), returns the components of the loss as a dict.\n    device                  : torch.device\n                                Device to be used (e.g., cuda, cpu, opencl).\n    \"\"\"\n    self.device = device\n    self.target_image     = target_image.float().to(self.device)\n    self.target_depth     = target_depth.float().to(self.device)\n    self.target_blur_size = target_blur_size\n    if self.target_blur_size % 2 == 0:\n        self.target_blur_size += 1\n    self.number_of_planes = number_of_planes\n    self.multiplier       = multiplier\n    self.reduction        = reduction\n    if self.reduction == 'none' and len(list(additional_loss_weights.keys())) > 0:\n        logging.warning(\"Reduction cannot be 'none' for additional loss functions. 
Changing reduction to 'mean'.\")\n        self.reduction = 'mean'\n    self.blur_ratio       = blur_ratio\n    self.set_targets()\n    if scheme == 'defocus':\n        self.add_defocus_blur()\n    self.base_loss_weights = base_loss_weights\n    self.additional_loss_weights = additional_loss_weights\n    self.return_components = return_components\n    self.l1_loss_fn = torch.nn.L1Loss(reduction = self.reduction)\n    self.l2_loss_fn = torch.nn.MSELoss(reduction = self.reduction)\n    for key in self.additional_loss_weights.keys():\n        if key == 'cvvdp':\n            self.cvvdp = CVVDP()\n        if key == 'fvvdp':\n            self.fvvdp = FVVDP()\n        if key == 'lpips':\n            self.lpips = LPIPS()\n        if key == 'psnr':\n            self.psnr = PSNR()\n        if key == 'ssim':\n            self.ssim = SSIM()\n        if key == 'msssim':\n            self.msssim = MSSSIM()\n
"},{"location":"odak/learn_wave/#odak.learn.wave.loss.perceptual_multiplane_loss.add_defocus_blur","title":"add_defocus_blur()","text":"

Internal function for adding defocus blur to the multiplane targets. Users can query the results with get_targets() within the same class.

Source code in odak/learn/wave/loss.py
def add_defocus_blur(self):\n    \"\"\"\n    Internal function for adding defocus blur to the multiplane targets. Users can query the results with get_targets() within the same class.\n    \"\"\"\n    kernel_length = [self.target_blur_size, self.target_blur_size ]\n    for ch in range(self.target_image.shape[0]):\n        targets_cache = self.targets[:, ch].detach().clone()\n        target = torch.sum(targets_cache, axis = 0)\n        for i in range(self.number_of_planes):\n            defocus = torch.zeros_like(targets_cache[i])\n            for j in range(self.number_of_planes):\n                nsigma = [int(abs(i - j) * self.blur_ratio), int(abs(i -j) * self.blur_ratio)]\n                if torch.sum(targets_cache[j]) > 0:\n                    if i == j:\n                        nsigma = [0., 0.]\n                    kernel = generate_2d_gaussian(kernel_length, nsigma).to(self.device)\n                    kernel = kernel / torch.sum(kernel)\n                    kernel = kernel.unsqueeze(0).unsqueeze(0)\n                    target_current = target.detach().clone().unsqueeze(0).unsqueeze(0)\n                    defocus_plane = torch.nn.functional.conv2d(target_current, kernel, padding = 'same')\n                    defocus_plane = defocus_plane.view(defocus_plane.shape[-2], defocus_plane.shape[-1])\n                    defocus = defocus + defocus_plane * torch.abs(self.masks[j, ch])\n            self.targets[i, ch] = defocus\n    self.targets = self.targets.detach().clone() * self.multiplier\n
"},{"location":"odak/learn_wave/#odak.learn.wave.loss.perceptual_multiplane_loss.get_targets","title":"get_targets()","text":"

Returns:

  • targets ( tensor ) \u2013

    Returns a copy of the targets.

  • focus_target ( tensor ) \u2013

    Returns a copy of the focus target.

  • target_depth ( tensor ) \u2013

    Returns a copy of the normalized quantized depth map.

Source code in odak/learn/wave/loss.py
def get_targets(self):\n    \"\"\"\n    Returns\n    -------\n    targets           : torch.tensor\n                        Returns a copy of the targets.\n    target_depth      : torch.tensor\n                        Returns a copy of the normalized quantized depth map.\n\n    \"\"\"\n    divider = self.number_of_planes - 1\n    if divider == 0:\n        divider = 1\n    return self.targets.detach().clone(), self.focus_target.detach().clone(), self.target_depth.detach().clone() / divider\n
"},{"location":"odak/learn_wave/#odak.learn.wave.loss.perceptual_multiplane_loss.set_targets","title":"set_targets()","text":"

Internal function for slicing the depth into planes without considering defocus. Users can query the results with get_targets() within the same class.

Source code in odak/learn/wave/loss.py
def set_targets(self):\n    \"\"\"\n    Internal function for slicing the depth into planes without considering defocus. Users can query the results with get_targets() within the same class.\n    \"\"\"\n    self.target_depth = self.target_depth * (self.number_of_planes - 1)\n    self.target_depth = torch.round(self.target_depth, decimals = 0)\n    self.targets      = torch.zeros(\n                                    self.number_of_planes,\n                                    self.target_image.shape[0],\n                                    self.target_image.shape[1],\n                                    self.target_image.shape[2],\n                                    requires_grad = False,\n                                    device = self.device\n                                   )\n    self.focus_target = torch.zeros_like(self.target_image, requires_grad = False)\n    self.masks        = torch.zeros_like(self.targets)\n    for i in range(self.number_of_planes):\n        for ch in range(self.target_image.shape[0]):\n            mask_zeros = torch.zeros_like(self.target_image[ch], dtype = torch.int)\n            mask_ones = torch.ones_like(self.target_image[ch], dtype = torch.int)\n            mask = torch.where(self.target_depth == i, mask_ones, mask_zeros)\n            new_target = self.target_image[ch] * mask\n            self.focus_target = self.focus_target + new_target.squeeze(0).squeeze(0).detach().clone()\n            self.targets[i, ch] = new_target.squeeze(0).squeeze(0)\n            self.masks[i, ch] = mask.detach().clone() \n
"},{"location":"odak/learn_wave/#odak.learn.wave.loss.phase_gradient","title":"phase_gradient","text":"

Bases: Module

The class 'phase_gradient' provides a regularization function to measure the variation (gradient or Laplacian) of the phase of the complex amplitude.

This implements a convolution of the phase with a kernel.

The kernel is a simple 3 by 3 Laplacian kernel here, but you can also try other edge detection methods.

Source code in odak/learn/wave/loss.py
class phase_gradient(nn.Module):\n\n    \"\"\"\n    The class 'phase_gradient' provides a regularization function to measure the variation(Gradient or Laplace) of the phase of the complex amplitude. \n\n    This implements a convolution of the phase with a kernel.\n\n    The kernel is a simple 3 by 3 Laplacian kernel here, but you can also try other edge detection methods.\n    \"\"\"\n\n\n    def __init__(self, kernel = None, loss = nn.MSELoss(), device = torch.device(\"cpu\")):\n        \"\"\"\n        Parameters\n        ----------\n        kernel                  : torch.tensor\n                                    Convolution filter kernel, 3 by 3 Laplacian kernel by default.\n        loss                    : torch.nn.Module\n                                    loss function, L2 Loss by default.\n        \"\"\"\n        super(phase_gradient, self).__init__()\n        self.device = device\n        self.loss = loss\n        if kernel == None:\n            self.kernel = torch.tensor([[[[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]]], dtype=torch.float32) / 8\n        else:\n            if len(kernel.shape) == 4:\n                self.kernel = kernel\n            else:\n                self.kernel = kernel.reshape((1, 1, kernel.shape[0], kernel.shape[1]))\n        self.kernel = Variable(self.kernel.to(self.device))\n\n\n    def forward(self, phase):\n        \"\"\"\n        Calculates the phase gradient Loss.\n\n        Parameters\n        ----------\n        phase                  : torch.tensor\n                                    Phase of the complex amplitude.\n\n        Returns\n        -------\n\n        loss_value              : torch.tensor\n                                    The computed loss.\n        \"\"\"\n\n        if len(phase.shape) == 2:\n            phase = phase.reshape((1, 1, phase.shape[0], phase.shape[1]))\n        edge_detect = self.functional_conv2d(phase)\n        loss_value = self.loss(edge_detect, torch.zeros_like(edge_detect))\n        return loss_value\n\n\n    def functional_conv2d(self, phase):\n        \"\"\"\n        Calculates the gradient of the phase.\n\n        Parameters\n        ----------\n        phase                  : torch.tensor\n                                    Phase of the complex amplitude.\n\n        Returns\n        -------\n\n        edge_detect              : torch.tensor\n                                    The computed phase gradient.\n        \"\"\"\n        edge_detect = F.conv2d(phase, self.kernel, padding = self.kernel.shape[-1] // 2)\n        return edge_detect\n
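A short sketch of how phase_gradient can be used as a regularizer, assuming Odak and PyTorch are installed; the phase map below is random and purely illustrative.

import torch
from odak.learn.wave.loss import phase_gradient

phase = torch.rand(256, 256) * 2 * torch.pi  # hypothetical phase map in radians
regularizer = phase_gradient()               # defaults to the 3 x 3 Laplacian kernel
loss_value = regularizer(phase)              # penalizes variation in the phase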
"},{"location":"odak/learn_wave/#odak.learn.wave.loss.phase_gradient.__init__","title":"__init__(kernel=None, loss=nn.MSELoss(), device=torch.device('cpu'))","text":"

Parameters:

  • kernel \u2013
                        Convolution filter kernel, 3 by 3 Laplacian kernel by default.\n
  • loss \u2013
                        Loss function, L2 loss by default.\n
Source code in odak/learn/wave/loss.py
def __init__(self, kernel = None, loss = nn.MSELoss(), device = torch.device(\"cpu\")):\n    \"\"\"\n    Parameters\n    ----------\n    kernel                  : torch.tensor\n                                Convolution filter kernel, 3 by 3 Laplacian kernel by default.\n    loss                    : torch.nn.Module\n                                loss function, L2 Loss by default.\n    \"\"\"\n    super(phase_gradient, self).__init__()\n    self.device = device\n    self.loss = loss\n    if kernel == None:\n        self.kernel = torch.tensor([[[[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]]], dtype=torch.float32) / 8\n    else:\n        if len(kernel.shape) == 4:\n            self.kernel = kernel\n        else:\n            self.kernel = kernel.reshape((1, 1, kernel.shape[0], kernel.shape[1]))\n    self.kernel = Variable(self.kernel.to(self.device))\n
"},{"location":"odak/learn_wave/#odak.learn.wave.loss.phase_gradient.forward","title":"forward(phase)","text":"

Calculates the phase gradient Loss.

Parameters:

  • phase \u2013
                        Phase of the complex amplitude.\n

Returns:

  • loss_value ( tensor ) \u2013

    The computed loss.

Source code in odak/learn/wave/loss.py
def forward(self, phase):\n    \"\"\"\n    Calculates the phase gradient Loss.\n\n    Parameters\n    ----------\n    phase                  : torch.tensor\n                                Phase of the complex amplitude.\n\n    Returns\n    -------\n\n    loss_value              : torch.tensor\n                                The computed loss.\n    \"\"\"\n\n    if len(phase.shape) == 2:\n        phase = phase.reshape((1, 1, phase.shape[0], phase.shape[1]))\n    edge_detect = self.functional_conv2d(phase)\n    loss_value = self.loss(edge_detect, torch.zeros_like(edge_detect))\n    return loss_value\n
"},{"location":"odak/learn_wave/#odak.learn.wave.loss.phase_gradient.functional_conv2d","title":"functional_conv2d(phase)","text":"

Calculates the gradient of the phase.

Parameters:

  • phase \u2013
                        Phase of the complex amplitude.\n

Returns:

  • edge_detect ( tensor ) \u2013

    The computed phase gradient.

Source code in odak/learn/wave/loss.py
def functional_conv2d(self, phase):\n    \"\"\"\n    Calculates the gradient of the phase.\n\n    Parameters\n    ----------\n    phase                  : torch.tensor\n                                Phase of the complex amplitude.\n\n    Returns\n    -------\n\n    edge_detect              : torch.tensor\n                                The computed phase gradient.\n    \"\"\"\n    edge_detect = F.conv2d(phase, self.kernel, padding = self.kernel.shape[-1] // 2)\n    return edge_detect\n
"},{"location":"odak/learn_wave/#odak.learn.wave.loss.speckle_contrast","title":"speckle_contrast","text":"

Bases: Module

The class 'speckle_contrast' provides a regularization function to measure the speckle contrast of the intensity of the complex amplitude using C = sigma / mean, where C is the speckle contrast, and sigma and mean are the standard deviation and mean of the intensity.

We refer to the following paper:

Kim et al. (2020). Light source optimization for partially coherent holographic displays with consideration of speckle contrast, resolution, and depth of field. Scientific Reports, 10, 18832. https://doi.org/10.1038/s41598-020-75947-0

Source code in odak/learn/wave/loss.py
class speckle_contrast(nn.Module):\n\n    \"\"\"\n    The class 'speckle_contrast' provides a regularization function to measure the speckle contrast of the intensity of the complex amplitude using C=sigma/mean. Where C is the speckle contrast, mean and sigma are mean and standard deviation of the intensity.\n\n    We refer to the following paper:\n\n    Kim et al.(2020). Light source optimization for partially coherent holographic displays with consideration of speckle contrast, resolution, and depth of field. Scientific Reports. 10. 18832. 10.1038/s41598-020-75947-0. \n    \"\"\"\n\n\n    def __init__(self, kernel_size = 11, step_size = (1, 1), loss = nn.MSELoss(), device=torch.device(\"cpu\")):\n        \"\"\"\n        Parameters\n        ----------\n        kernel_size             : torch.tensor\n                                    Convolution filter kernel size, 11 by 11 average kernel by default.\n        step_size               : tuple\n                                    Convolution stride in height and width direction.\n        loss                    : torch.nn.Module\n                                    loss function, L2 Loss by default.\n        \"\"\"\n        super(speckle_contrast, self).__init__()\n        self.device = device\n        self.loss = loss\n        self.step_size = step_size\n        self.kernel_size = kernel_size\n        self.kernel = torch.ones((1, 1, self.kernel_size, self.kernel_size)) / (self.kernel_size ** 2)\n        self.kernel = Variable(self.kernel.type(torch.FloatTensor).to(self.device))\n\n\n    def forward(self, intensity):\n        \"\"\"\n        Calculates the speckle contrast Loss.\n\n        Parameters\n        ----------\n        intensity               : torch.tensor\n                                    intensity of the complex amplitude.\n\n        Returns\n        -------\n\n        loss_value              : torch.tensor\n                                    The computed loss.\n        \"\"\"\n\n        if len(intensity.shape) == 2:\n            intensity = intensity.reshape((1, 1, intensity.shape[0], intensity.shape[1]))\n        Speckle_C = self.functional_conv2d(intensity)\n        loss_value = self.loss(Speckle_C, torch.zeros_like(Speckle_C))\n        return loss_value\n\n\n    def functional_conv2d(self, intensity):\n        \"\"\"\n        Calculates the speckle contrast of the intensity.\n\n        Parameters\n        ----------\n        intensity                : torch.tensor\n                                    Intensity of the complex field.\n\n        Returns\n        -------\n\n        Speckle_C               : torch.tensor\n                                    The computed speckle contrast.\n        \"\"\"\n        mean = F.conv2d(intensity, self.kernel, stride = self.step_size)\n        var = torch.sqrt(F.conv2d(torch.pow(intensity, 2), self.kernel, stride = self.step_size) - torch.pow(mean, 2))\n        Speckle_C = var / mean\n        return Speckle_C\n
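A short sketch of the speckle_contrast regularizer under the same assumptions; a small offset is added to the random intensity so that the mean in C = sigma / mean never hits zero.

import torch
from odak.learn.wave.loss import speckle_contrast

intensity = torch.rand(256, 256) + 1e-3      # hypothetical intensity pattern
regularizer = speckle_contrast(kernel_size = 11, step_size = (1, 1))
loss_value = regularizer(intensity)          # drives the local sigma / mean towards zero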
"},{"location":"odak/learn_wave/#odak.learn.wave.loss.speckle_contrast.__init__","title":"__init__(kernel_size=11, step_size=(1, 1), loss=nn.MSELoss(), device=torch.device('cpu'))","text":"

Parameters:

  • kernel_size \u2013
                        Convolution filter kernel size, 11 by 11 average kernel by default.\n
  • step_size \u2013
                        Convolution stride in height and width direction.\n
  • loss \u2013
                        Loss function, L2 loss by default.\n
Source code in odak/learn/wave/loss.py
def __init__(self, kernel_size = 11, step_size = (1, 1), loss = nn.MSELoss(), device=torch.device(\"cpu\")):\n    \"\"\"\n    Parameters\n    ----------\n    kernel_size             : torch.tensor\n                                Convolution filter kernel size, 11 by 11 average kernel by default.\n    step_size               : tuple\n                                Convolution stride in height and width direction.\n    loss                    : torch.nn.Module\n                                loss function, L2 Loss by default.\n    \"\"\"\n    super(speckle_contrast, self).__init__()\n    self.device = device\n    self.loss = loss\n    self.step_size = step_size\n    self.kernel_size = kernel_size\n    self.kernel = torch.ones((1, 1, self.kernel_size, self.kernel_size)) / (self.kernel_size ** 2)\n    self.kernel = Variable(self.kernel.type(torch.FloatTensor).to(self.device))\n
"},{"location":"odak/learn_wave/#odak.learn.wave.loss.speckle_contrast.forward","title":"forward(intensity)","text":"

Calculates the speckle contrast Loss.

Parameters:

  • intensity \u2013
                        Intensity of the complex amplitude.\n

Returns:

  • loss_value ( tensor ) \u2013

    The computed loss.

Source code in odak/learn/wave/loss.py
def forward(self, intensity):\n    \"\"\"\n    Calculates the speckle contrast Loss.\n\n    Parameters\n    ----------\n    intensity               : torch.tensor\n                                intensity of the complex amplitude.\n\n    Returns\n    -------\n\n    loss_value              : torch.tensor\n                                The computed loss.\n    \"\"\"\n\n    if len(intensity.shape) == 2:\n        intensity = intensity.reshape((1, 1, intensity.shape[0], intensity.shape[1]))\n    Speckle_C = self.functional_conv2d(intensity)\n    loss_value = self.loss(Speckle_C, torch.zeros_like(Speckle_C))\n    return loss_value\n
"},{"location":"odak/learn_wave/#odak.learn.wave.loss.speckle_contrast.functional_conv2d","title":"functional_conv2d(intensity)","text":"

Calculates the speckle contrast of the intensity.

Parameters:

  • intensity \u2013
                        Intensity of the complex field.\n

Returns:

  • Speckle_C ( tensor ) \u2013

    The computed speckle contrast.

Source code in odak/learn/wave/loss.py
def functional_conv2d(self, intensity):\n    \"\"\"\n    Calculates the speckle contrast of the intensity.\n\n    Parameters\n    ----------\n    intensity                : torch.tensor\n                                Intensity of the complex field.\n\n    Returns\n    -------\n\n    Speckle_C               : torch.tensor\n                                The computed speckle contrast.\n    \"\"\"\n    mean = F.conv2d(intensity, self.kernel, stride = self.step_size)\n    var = torch.sqrt(F.conv2d(torch.pow(intensity, 2), self.kernel, stride = self.step_size) - torch.pow(mean, 2))\n    Speckle_C = var / mean\n    return Speckle_C\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.channel_gate","title":"channel_gate","text":"

Bases: Module

Channel attention module with various pooling strategies. This class is heavily inspired by https://github.com/Jongchan/attention-module/commit/e4ee180f1335c09db14d39a65d97c8ca3d1f7b16 (MIT License).

Source code in odak/learn/models/components.py
class channel_gate(torch.nn.Module):\n    \"\"\"\n    Channel attention module with various pooling strategies.\n    This class is heavily inspired https://github.com/Jongchan/attention-module/commit/e4ee180f1335c09db14d39a65d97c8ca3d1f7b16 (MIT License).\n    \"\"\"\n    def __init__(\n                 self, \n                 gate_channels, \n                 reduction_ratio = 16, \n                 pool_types = ['avg', 'max']\n                ):\n        \"\"\"\n        Initializes the channel gate module.\n\n        Parameters\n        ----------\n        gate_channels   : int\n                          Number of channels of the input feature map.\n        reduction_ratio : int\n                          Reduction ratio for the intermediate layer.\n        pool_types      : list\n                          List of pooling operations to apply.\n        \"\"\"\n        super().__init__()\n        self.gate_channels = gate_channels\n        hidden_channels = gate_channels // reduction_ratio\n        if hidden_channels == 0:\n            hidden_channels = 1\n        self.mlp = torch.nn.Sequential(\n                                       convolutional_block_attention.Flatten(),\n                                       torch.nn.Linear(gate_channels, hidden_channels),\n                                       torch.nn.ReLU(),\n                                       torch.nn.Linear(hidden_channels, gate_channels)\n                                      )\n        self.pool_types = pool_types\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward pass of the ChannelGate module.\n\n        Applies channel-wise attention to the input tensor.\n\n        Parameters\n        ----------\n        x            : torch.tensor\n                       Input tensor to the ChannelGate module.\n\n        Returns\n        -------\n        output       : torch.tensor\n                       Output tensor after applying channel attention.\n        \"\"\"\n        channel_att_sum = None\n        for pool_type in self.pool_types:\n            if pool_type == 'avg':\n                pool = torch.nn.functional.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))\n            elif pool_type == 'max':\n                pool = torch.nn.functional.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))\n            channel_att_raw = self.mlp(pool)\n            channel_att_sum = channel_att_raw if channel_att_sum is None else channel_att_sum + channel_att_raw\n        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)\n        output = x * scale\n        return output\n
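A minimal sketch of channel_gate applied to a random feature map, assuming the class is importable from odak.learn.models.components as the source path above suggests.

import torch
from odak.learn.models.components import channel_gate

x = torch.rand(1, 32, 64, 64)  # [batch x channels x height x width]
gate = channel_gate(gate_channels = 32, reduction_ratio = 16, pool_types = ['avg', 'max'])
y = gate(x)                    # same shape as x, rescaled per channel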
"},{"location":"odak/learn_wave/#odak.learn.wave.models.channel_gate.__init__","title":"__init__(gate_channels, reduction_ratio=16, pool_types=['avg', 'max'])","text":"

Initializes the channel gate module.

Parameters:

  • gate_channels \u2013
              Number of channels of the input feature map.\n
  • reduction_ratio (int, default: 16 ) \u2013
              Reduction ratio for the intermediate layer.\n
  • pool_types \u2013
              List of pooling operations to apply.\n
Source code in odak/learn/models/components.py
def __init__(\n             self, \n             gate_channels, \n             reduction_ratio = 16, \n             pool_types = ['avg', 'max']\n            ):\n    \"\"\"\n    Initializes the channel gate module.\n\n    Parameters\n    ----------\n    gate_channels   : int\n                      Number of channels of the input feature map.\n    reduction_ratio : int\n                      Reduction ratio for the intermediate layer.\n    pool_types      : list\n                      List of pooling operations to apply.\n    \"\"\"\n    super().__init__()\n    self.gate_channels = gate_channels\n    hidden_channels = gate_channels // reduction_ratio\n    if hidden_channels == 0:\n        hidden_channels = 1\n    self.mlp = torch.nn.Sequential(\n                                   convolutional_block_attention.Flatten(),\n                                   torch.nn.Linear(gate_channels, hidden_channels),\n                                   torch.nn.ReLU(),\n                                   torch.nn.Linear(hidden_channels, gate_channels)\n                                  )\n    self.pool_types = pool_types\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.channel_gate.forward","title":"forward(x)","text":"

Forward pass of the ChannelGate module.

Applies channel-wise attention to the input tensor.

Parameters:

  • x \u2013
           Input tensor to the ChannelGate module.\n

Returns:

  • output ( tensor ) \u2013

    Output tensor after applying channel attention.

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward pass of the ChannelGate module.\n\n    Applies channel-wise attention to the input tensor.\n\n    Parameters\n    ----------\n    x            : torch.tensor\n                   Input tensor to the ChannelGate module.\n\n    Returns\n    -------\n    output       : torch.tensor\n                   Output tensor after applying channel attention.\n    \"\"\"\n    channel_att_sum = None\n    for pool_type in self.pool_types:\n        if pool_type == 'avg':\n            pool = torch.nn.functional.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))\n        elif pool_type == 'max':\n            pool = torch.nn.functional.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))\n        channel_att_raw = self.mlp(pool)\n        channel_att_sum = channel_att_raw if channel_att_sum is None else channel_att_sum + channel_att_raw\n    scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)\n    output = x * scale\n    return output\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.convolution_layer","title":"convolution_layer","text":"

Bases: Module

A convolution layer.

Source code in odak/learn/models/components.py
class convolution_layer(torch.nn.Module):\n    \"\"\"\n    A convolution layer.\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels = 2,\n                 output_channels = 2,\n                 kernel_size = 3,\n                 bias = False,\n                 stride = 1,\n                 normalization = True,\n                 activation = torch.nn.ReLU()\n                ):\n        \"\"\"\n        A convolutional layer class.\n\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        output_channels : int\n                          Number of output channels.\n        kernel_size     : int\n                          Kernel size.\n        bias            : bool\n                          Set to True to let convolutional layers have bias term.\n        normalization   : bool\n                          If True, adds a Batch Normalization layer after the convolutional layer.\n        activation      : torch.nn\n                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n        \"\"\"\n        super().__init__()\n        layers = [\n            torch.nn.Conv2d(\n                            input_channels,\n                            output_channels,\n                            kernel_size = kernel_size,\n                            stride = stride,\n                            padding = kernel_size // 2,\n                            bias = bias\n                           )\n        ]\n        if normalization:\n            layers.append(torch.nn.BatchNorm2d(output_channels))\n        if activation:\n            layers.append(activation)\n        self.model = torch.nn.Sequential(*layers)\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x             : torch.tensor\n                        Input data.\n\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Estimated output.\n        \"\"\"\n        result = self.model(x)\n        return result\n
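A minimal sketch of convolution_layer, again assuming the import path follows the source file shown above; with stride 1 and kernel_size // 2 padding the spatial size is preserved.

import torch
from odak.learn.models.components import convolution_layer

layer = convolution_layer(
                          input_channels = 3,
                          output_channels = 16,
                          kernel_size = 3,
                          normalization = True,
                          activation = torch.nn.ReLU()
                         )
x = torch.rand(1, 3, 64, 64)
y = layer(x)  # torch.Size([1, 16, 64, 64])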
"},{"location":"odak/learn_wave/#odak.learn.wave.models.convolution_layer.__init__","title":"__init__(input_channels=2, output_channels=2, kernel_size=3, bias=False, stride=1, normalization=True, activation=torch.nn.ReLU())","text":"

A convolutional layer class.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • output_channels (int, default: 2 ) \u2013
              Number of output channels.\n
  • kernel_size \u2013
              Kernel size.\n
  • bias \u2013
              Set to True to let convolutional layers have bias term.\n
  • normalization \u2013
              If True, adds a Batch Normalization layer after the convolutional layer.\n
  • activation \u2013
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels = 2,\n             output_channels = 2,\n             kernel_size = 3,\n             bias = False,\n             stride = 1,\n             normalization = True,\n             activation = torch.nn.ReLU()\n            ):\n    \"\"\"\n    A convolutional layer class.\n\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    output_channels : int\n                      Number of output channels.\n    kernel_size     : int\n                      Kernel size.\n    bias            : bool\n                      Set to True to let convolutional layers have bias term.\n    normalization   : bool\n                      If True, adds a Batch Normalization layer after the convolutional layer.\n    activation      : torch.nn\n                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n    \"\"\"\n    super().__init__()\n    layers = [\n        torch.nn.Conv2d(\n                        input_channels,\n                        output_channels,\n                        kernel_size = kernel_size,\n                        stride = stride,\n                        padding = kernel_size // 2,\n                        bias = bias\n                       )\n    ]\n    if normalization:\n        layers.append(torch.nn.BatchNorm2d(output_channels))\n    if activation:\n        layers.append(activation)\n    self.model = torch.nn.Sequential(*layers)\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.convolution_layer.forward","title":"forward(x)","text":"

Forward model.

Parameters:

  • x \u2013
            Input data.\n

Returns:

  • result ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x             : torch.tensor\n                    Input data.\n\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Estimated output.\n    \"\"\"\n    result = self.model(x)\n    return result\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.convolutional_block_attention","title":"convolutional_block_attention","text":"

Bases: Module

Convolutional Block Attention Module (CBAM) class. This class is heavily inspired by https://github.com/Jongchan/attention-module/commit/e4ee180f1335c09db14d39a65d97c8ca3d1f7b16 (MIT License).

Source code in odak/learn/models/components.py
class convolutional_block_attention(torch.nn.Module):\n    \"\"\"\n    Convolutional Block Attention Module (CBAM) class. \n    This class is heavily inspired https://github.com/Jongchan/attention-module/commit/e4ee180f1335c09db14d39a65d97c8ca3d1f7b16 (MIT License).\n    \"\"\"\n    def __init__(\n                 self, \n                 gate_channels, \n                 reduction_ratio = 16, \n                 pool_types = ['avg', 'max'], \n                 no_spatial = False\n                ):\n        \"\"\"\n        Initializes the convolutional block attention module.\n\n        Parameters\n        ----------\n        gate_channels   : int\n                          Number of channels of the input feature map.\n        reduction_ratio : int\n                          Reduction ratio for the channel attention.\n        pool_types      : list\n                          List of pooling operations to apply for channel attention.\n        no_spatial      : bool\n                          If True, spatial attention is not applied.\n        \"\"\"\n        super(convolutional_block_attention, self).__init__()\n        self.channel_gate = channel_gate(gate_channels, reduction_ratio, pool_types)\n        self.no_spatial = no_spatial\n        if not no_spatial:\n            self.spatial_gate = spatial_gate()\n\n\n    class Flatten(torch.nn.Module):\n        \"\"\"\n        Flattens the input tensor to a 2D matrix.\n        \"\"\"\n        def forward(self, x):\n            return x.view(x.size(0), -1)\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward pass of the convolutional block attention module.\n\n        Parameters\n        ----------\n        x            : torch.tensor\n                       Input tensor to the CBAM module.\n\n        Returns\n        -------\n        x_out        : torch.tensor\n                       Output tensor after applying channel and spatial attention.\n        \"\"\"\n        x_out = self.channel_gate(x)\n        if not self.no_spatial:\n            x_out = self.spatial_gate(x_out)\n        return x_out\n
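A minimal sketch of convolutional_block_attention (CBAM) on a random feature map, under the same import-path assumption.

import torch
from odak.learn.models.components import convolutional_block_attention

cbam = convolutional_block_attention(gate_channels = 32, reduction_ratio = 16)
x = torch.rand(1, 32, 64, 64)
y = cbam(x)  # channel attention followed by spatial attention, same shape as x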
"},{"location":"odak/learn_wave/#odak.learn.wave.models.convolutional_block_attention.Flatten","title":"Flatten","text":"

Bases: Module

Flattens the input tensor to a 2D matrix.

Source code in odak/learn/models/components.py
class Flatten(torch.nn.Module):\n    \"\"\"\n    Flattens the input tensor to a 2D matrix.\n    \"\"\"\n    def forward(self, x):\n        return x.view(x.size(0), -1)\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.convolutional_block_attention.__init__","title":"__init__(gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False)","text":"

Initializes the convolutional block attention module.

Parameters:

  • gate_channels \u2013
              Number of channels of the input feature map.\n
  • reduction_ratio (int, default: 16 ) \u2013
              Reduction ratio for the channel attention.\n
  • pool_types \u2013
              List of pooling operations to apply for channel attention.\n
  • no_spatial \u2013
              If True, spatial attention is not applied.\n
Source code in odak/learn/models/components.py
def __init__(\n             self, \n             gate_channels, \n             reduction_ratio = 16, \n             pool_types = ['avg', 'max'], \n             no_spatial = False\n            ):\n    \"\"\"\n    Initializes the convolutional block attention module.\n\n    Parameters\n    ----------\n    gate_channels   : int\n                      Number of channels of the input feature map.\n    reduction_ratio : int\n                      Reduction ratio for the channel attention.\n    pool_types      : list\n                      List of pooling operations to apply for channel attention.\n    no_spatial      : bool\n                      If True, spatial attention is not applied.\n    \"\"\"\n    super(convolutional_block_attention, self).__init__()\n    self.channel_gate = channel_gate(gate_channels, reduction_ratio, pool_types)\n    self.no_spatial = no_spatial\n    if not no_spatial:\n        self.spatial_gate = spatial_gate()\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.convolutional_block_attention.forward","title":"forward(x)","text":"

Forward pass of the convolutional block attention module.

Parameters:

  • x \u2013
           Input tensor to the CBAM module.\n

Returns:

  • x_out ( tensor ) \u2013

    Output tensor after applying channel and spatial attention.

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward pass of the convolutional block attention module.\n\n    Parameters\n    ----------\n    x            : torch.tensor\n                   Input tensor to the CBAM module.\n\n    Returns\n    -------\n    x_out        : torch.tensor\n                   Output tensor after applying channel and spatial attention.\n    \"\"\"\n    x_out = self.channel_gate(x)\n    if not self.no_spatial:\n        x_out = self.spatial_gate(x_out)\n    return x_out\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.double_convolution","title":"double_convolution","text":"

Bases: Module

A double convolution layer.

Source code in odak/learn/models/components.py
class double_convolution(torch.nn.Module):\n    \"\"\"\n    A double convolution layer.\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels = 2,\n                 mid_channels = None,\n                 output_channels = 2,\n                 kernel_size = 3, \n                 bias = False,\n                 normalization = True,\n                 activation = torch.nn.ReLU()\n                ):\n        \"\"\"\n        Double convolution model.\n\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        mid_channels    : int\n                          Number of channels in the hidden layer between two convolutions.\n        output_channels : int\n                          Number of output channels.\n        kernel_size     : int\n                          Kernel size.\n        bias            : bool \n                          Set to True to let convolutional layers have bias term.\n        normalization   : bool\n                          If True, adds a Batch Normalization layer after the convolutional layer.\n        activation      : torch.nn\n                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n        \"\"\"\n        super().__init__()\n        if isinstance(mid_channels, type(None)):\n            mid_channels = output_channels\n        self.activation = activation\n        self.model = torch.nn.Sequential(\n                                         convolution_layer(\n                                                           input_channels = input_channels,\n                                                           output_channels = mid_channels,\n                                                           kernel_size = kernel_size,\n                                                           bias = bias,\n                                                           normalization = normalization,\n                                                           activation = self.activation\n                                                          ),\n                                         convolution_layer(\n                                                           input_channels = mid_channels,\n                                                           output_channels = output_channels,\n                                                           kernel_size = kernel_size,\n                                                           bias = bias,\n                                                           normalization = normalization,\n                                                           activation = self.activation\n                                                          )\n                                        )\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x             : torch.tensor\n                        Input data.\n\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Estimated output.      \n        \"\"\"\n        result = self.model(x)\n        return result\n
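
A hedged usage sketch for double_convolution; the import path is assumed from the source file location and the tensor shapes are illustrative only.

import torch
from odak.learn.models import double_convolution  # assumed import path

layer = double_convolution(
                           input_channels = 2,
                           mid_channels = 8,
                           output_channels = 4,
                           kernel_size = 3,
                           normalization = True
                          )
x = torch.randn(1, 2, 64, 64)
y = layer(x)    # two successive convolution_layer blocks; channels go 2 -> 8 -> 4
print(y.shape)  # expected: torch.Size([1, 4, 64, 64]) if the convolutions pad to keep spatial size
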
"},{"location":"odak/learn_wave/#odak.learn.wave.models.double_convolution.__init__","title":"__init__(input_channels=2, mid_channels=None, output_channels=2, kernel_size=3, bias=False, normalization=True, activation=torch.nn.ReLU())","text":"

Double convolution model.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • mid_channels \u2013
              Number of channels in the hidden layer between two convolutions.\n
  • output_channels (int, default: 2 ) \u2013
              Number of output channels.\n
  • kernel_size \u2013
              Kernel size.\n
  • bias \u2013
              Set to True to let convolutional layers have bias term.\n
  • normalization \u2013
              If True, adds a Batch Normalization layer after the convolutional layer.\n
  • activation \u2013
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels = 2,\n             mid_channels = None,\n             output_channels = 2,\n             kernel_size = 3, \n             bias = False,\n             normalization = True,\n             activation = torch.nn.ReLU()\n            ):\n    \"\"\"\n    Double convolution model.\n\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    mid_channels    : int\n                      Number of channels in the hidden layer between two convolutions.\n    output_channels : int\n                      Number of output channels.\n    kernel_size     : int\n                      Kernel size.\n    bias            : bool \n                      Set to True to let convolutional layers have bias term.\n    normalization   : bool\n                      If True, adds a Batch Normalization layer after the convolutional layer.\n    activation      : torch.nn\n                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n    \"\"\"\n    super().__init__()\n    if isinstance(mid_channels, type(None)):\n        mid_channels = output_channels\n    self.activation = activation\n    self.model = torch.nn.Sequential(\n                                     convolution_layer(\n                                                       input_channels = input_channels,\n                                                       output_channels = mid_channels,\n                                                       kernel_size = kernel_size,\n                                                       bias = bias,\n                                                       normalization = normalization,\n                                                       activation = self.activation\n                                                      ),\n                                     convolution_layer(\n                                                       input_channels = mid_channels,\n                                                       output_channels = output_channels,\n                                                       kernel_size = kernel_size,\n                                                       bias = bias,\n                                                       normalization = normalization,\n                                                       activation = self.activation\n                                                      )\n                                    )\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.double_convolution.forward","title":"forward(x)","text":"

Forward model.

Parameters:

  • x \u2013
            Input data.\n

Returns:

  • result ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x             : torch.tensor\n                    Input data.\n\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Estimated output.      \n    \"\"\"\n    result = self.model(x)\n    return result\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.downsample_layer","title":"downsample_layer","text":"

Bases: Module

A downscaling component followed by a double convolution.

Source code in odak/learn/models/components.py
class downsample_layer(torch.nn.Module):\n    \"\"\"\n    A downscaling component followed by a double convolution.\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels,\n                 output_channels,\n                 kernel_size = 3,\n                 bias = False,\n                 activation = torch.nn.ReLU()\n                ):\n        \"\"\"\n        A downscaling component with a double convolution.\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        output_channels : int\n                          Number of output channels.\n        kernel_size     : int\n                          Kernel size.\n        bias            : bool \n                          Set to True to let convolutional layers have bias term.\n        activation      : torch.nn\n                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n        \"\"\"\n        super().__init__()\n        self.maxpool_conv = torch.nn.Sequential(\n                                                torch.nn.MaxPool2d(2),\n                                                double_convolution(\n                                                                   input_channels = input_channels,\n                                                                   mid_channels = output_channels,\n                                                                   output_channels = output_channels,\n                                                                   kernel_size = kernel_size,\n                                                                   bias = bias,\n                                                                   activation = activation\n                                                                  )\n                                               )\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x              : torch.tensor\n                         First input data.\n\n\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Estimated output.      \n        \"\"\"\n        result = self.maxpool_conv(x)\n        return result\n
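
A hedged usage sketch for downsample_layer; the import path is assumed and the shapes are illustrative.

import torch
from odak.learn.models import downsample_layer  # assumed import path

down = downsample_layer(input_channels = 8, output_channels = 16)
x = torch.randn(1, 8, 64, 64)
y = down(x)     # MaxPool2d(2) halves the resolution before the double convolution
print(y.shape)  # expected: torch.Size([1, 16, 32, 32])
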
"},{"location":"odak/learn_wave/#odak.learn.wave.models.downsample_layer.__init__","title":"__init__(input_channels, output_channels, kernel_size=3, bias=False, activation=torch.nn.ReLU())","text":"

A downscaling component with a double convolution.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • output_channels (int) \u2013
              Number of output channels.\n
  • kernel_size \u2013
              Kernel size.\n
  • bias \u2013
              Set to True to let convolutional layers have bias term.\n
  • activation \u2013
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels,\n             output_channels,\n             kernel_size = 3,\n             bias = False,\n             activation = torch.nn.ReLU()\n            ):\n    \"\"\"\n    A downscaling component with a double convolution.\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    output_channels : int\n                      Number of output channels.\n    kernel_size     : int\n                      Kernel size.\n    bias            : bool \n                      Set to True to let convolutional layers have bias term.\n    activation      : torch.nn\n                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n    \"\"\"\n    super().__init__()\n    self.maxpool_conv = torch.nn.Sequential(\n                                            torch.nn.MaxPool2d(2),\n                                            double_convolution(\n                                                               input_channels = input_channels,\n                                                               mid_channels = output_channels,\n                                                               output_channels = output_channels,\n                                                               kernel_size = kernel_size,\n                                                               bias = bias,\n                                                               activation = activation\n                                                              )\n                                           )\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.downsample_layer.forward","title":"forward(x)","text":"

Forward model.

Parameters:

  • x \u2013
             First input data.\n

Returns:

  • result ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x              : torch.tensor\n                     First input data.\n\n\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Estimated output.      \n    \"\"\"\n    result = self.maxpool_conv(x)\n    return result\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.focal_surface_light_propagation","title":"focal_surface_light_propagation","text":"

Bases: Module

focal_surface_light_propagation model.

References

Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Ak\u015fit. "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions." SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24), December 2024.

Source code in odak/learn/wave/models.py
class focal_surface_light_propagation(torch.nn.Module):\n    \"\"\"\n    focal_surface_light_propagation model.\n\n    References\n    ----------\n\n    Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Ak\u015fit}. \"Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions.\" SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24),December,2024.\n    \"\"\"\n    def __init__(\n                 self,\n                 depth = 3,\n                 dimensions = 8,\n                 input_channels = 6,\n                 out_channels = 6,\n                 kernel_size = 3,\n                 bias = True,\n                 device = torch.device('cpu'),\n                 activation = torch.nn.LeakyReLU(0.2, inplace = True)\n                ):\n        \"\"\"\n        Initializes the focal surface light propagation model.\n\n        Parameters\n        ----------\n        depth             : int\n                            Number of downsampling and upsampling layers.\n        dimensions        : int\n                            Number of dimensions/features in the model.\n        input_channels    : int\n                            Number of input channels.\n        out_channels      : int\n                            Number of output channels.\n        kernel_size       : int\n                            Size of the convolution kernel.\n        bias              : bool\n                            If True, allows convolutional layers to learn a bias term.\n        device            : torch.device\n                            Default device is CPU.\n        activation        : torch.nn.Module\n                            Activation function (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).\n        \"\"\"\n        super().__init__()\n        self.depth = depth\n        self.device = device\n        self.sv_kernel_generation = spatially_varying_kernel_generation_model(\n            depth = depth,\n            dimensions = dimensions,\n            input_channels = input_channels + 1,  # +1 to account for an extra channel\n            kernel_size = kernel_size,\n            bias = bias,\n            activation = activation\n        )\n        self.light_propagation = spatially_adaptive_unet(\n            depth = depth,\n            dimensions = dimensions,\n            input_channels = input_channels,\n            out_channels = out_channels,\n            kernel_size = kernel_size,\n            bias = bias,\n            activation = activation\n        )\n\n\n    def forward(self, focal_surface, phase_only_hologram):\n        \"\"\"\n        Forward pass through the model.\n\n        Parameters\n        ----------\n        focal_surface         : torch.Tensor\n                                Input focal surface.\n        phase_only_hologram   : torch.Tensor\n                                Input phase-only hologram.\n\n        Returns\n        ----------\n        result                : torch.Tensor\n                                Output tensor after light propagation.\n        \"\"\"\n        input_field = self.generate_input_field(phase_only_hologram)\n        sv_kernel = self.sv_kernel_generation(focal_surface, input_field)\n        output_field = self.light_propagation(sv_kernel, input_field)\n        final = (output_field[:, 0:3, :, :] + 1j * output_field[:, 3:6, :, :])\n        result = calculate_amplitude(final) ** 2\n        return result\n\n\n    def generate_input_field(self, phase_only_hologram):\n        \"\"\"\n        
Generates an input field by combining the real and imaginary parts.\n\n        Parameters\n        ----------\n        phase_only_hologram   : torch.Tensor\n                                Input phase-only hologram.\n\n        Returns\n        ----------\n        input_field           : torch.Tensor\n                                Concatenated real and imaginary parts of the complex field.\n        \"\"\"\n        [b, c, h, w] = phase_only_hologram.size()\n        input_phase = phase_only_hologram * 2 * np.pi\n        hologram_amplitude = torch.ones(b, c, h, w, requires_grad = False).to(self.device)\n        field = generate_complex_field(hologram_amplitude, input_phase)\n        input_field = torch.cat((field.real, field.imag), dim = 1)\n        return input_field\n\n\n    def load_weights(self, weight_filename, key_mapping_filename):\n        \"\"\"\n        Function to load weights for this multi-layer perceptron from a file.\n\n        Parameters\n        ----------\n        weight_filename      : str\n                               Path to the old model's weight file.\n        key_mapping_filename : str\n                               Path to the JSON file containing the key mappings.\n        \"\"\"\n        # Load old model weights\n        old_model_weights = torch.load(weight_filename, map_location = self.device)\n\n        # Load key mappings from JSON file\n        with open(key_mapping_filename, 'r') as json_file:\n            key_mappings = json.load(json_file)\n\n        # Extract the key mappings for sv_kernel_generation and light_prop\n        sv_kernel_generation_key_mapping = key_mappings['sv_kernel_generation_key_mapping']\n        light_prop_key_mapping = key_mappings['light_prop_key_mapping']\n\n        # Initialize new state dicts\n        sv_kernel_generation_new_state_dict = {}\n        light_prop_new_state_dict = {}\n\n        # Map and load sv_kernel_generation_model weights\n        for old_key, value in old_model_weights.items():\n            if old_key in sv_kernel_generation_key_mapping:\n                # Map the old key to the new key\n                new_key = sv_kernel_generation_key_mapping[old_key]\n                sv_kernel_generation_new_state_dict[new_key] = value\n\n        self.sv_kernel_generation.to(self.device)\n        self.sv_kernel_generation.load_state_dict(sv_kernel_generation_new_state_dict)\n\n        # Map and load light_prop model weights\n        for old_key, value in old_model_weights.items():\n            if old_key in light_prop_key_mapping:\n                # Map the old key to the new key\n                new_key = light_prop_key_mapping[old_key]\n                light_prop_new_state_dict[new_key] = value\n        self.light_propagation.to(self.device)\n        self.light_propagation.load_state_dict(light_prop_new_state_dict)\n
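
A hedged usage sketch for focal_surface_light_propagation; the shapes are inferred from the default arguments (a three-channel phase-only hologram becomes a six-channel real/imaginary field, and the focal surface is assumed to be single-channel), and the import path is assumed.

import torch
from odak.learn.wave.models import focal_surface_light_propagation  # assumed import path

model = focal_surface_light_propagation(device = torch.device('cpu'))
phase_only_hologram = torch.rand(1, 3, 256, 256)  # phases in [0, 1]; scaled by 2 pi internally
focal_surface = torch.rand(1, 1, 256, 256)        # single-channel focal surface (assumed shape)
reconstruction = model(focal_surface, phase_only_hologram)
print(reconstruction.shape)                       # intensity image, expected torch.Size([1, 3, 256, 256])
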
"},{"location":"odak/learn_wave/#odak.learn.wave.models.focal_surface_light_propagation.__init__","title":"__init__(depth=3, dimensions=8, input_channels=6, out_channels=6, kernel_size=3, bias=True, device=torch.device('cpu'), activation=torch.nn.LeakyReLU(0.2, inplace=True))","text":"

Initializes the focal surface light propagation model.

Parameters:

  • depth \u2013
                Number of downsampling and upsampling layers.\n
  • dimensions \u2013
                Number of dimensions/features in the model.\n
  • input_channels \u2013
                Number of input channels.\n
  • out_channels \u2013
                Number of output channels.\n
  • kernel_size \u2013
                Size of the convolution kernel.\n
  • bias \u2013
                If True, allows convolutional layers to learn a bias term.\n
  • device \u2013
                Default device is CPU.\n
  • activation \u2013
                Activation function (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).\n
Source code in odak/learn/wave/models.py
def __init__(\n             self,\n             depth = 3,\n             dimensions = 8,\n             input_channels = 6,\n             out_channels = 6,\n             kernel_size = 3,\n             bias = True,\n             device = torch.device('cpu'),\n             activation = torch.nn.LeakyReLU(0.2, inplace = True)\n            ):\n    \"\"\"\n    Initializes the focal surface light propagation model.\n\n    Parameters\n    ----------\n    depth             : int\n                        Number of downsampling and upsampling layers.\n    dimensions        : int\n                        Number of dimensions/features in the model.\n    input_channels    : int\n                        Number of input channels.\n    out_channels      : int\n                        Number of output channels.\n    kernel_size       : int\n                        Size of the convolution kernel.\n    bias              : bool\n                        If True, allows convolutional layers to learn a bias term.\n    device            : torch.device\n                        Default device is CPU.\n    activation        : torch.nn.Module\n                        Activation function (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).\n    \"\"\"\n    super().__init__()\n    self.depth = depth\n    self.device = device\n    self.sv_kernel_generation = spatially_varying_kernel_generation_model(\n        depth = depth,\n        dimensions = dimensions,\n        input_channels = input_channels + 1,  # +1 to account for an extra channel\n        kernel_size = kernel_size,\n        bias = bias,\n        activation = activation\n    )\n    self.light_propagation = spatially_adaptive_unet(\n        depth = depth,\n        dimensions = dimensions,\n        input_channels = input_channels,\n        out_channels = out_channels,\n        kernel_size = kernel_size,\n        bias = bias,\n        activation = activation\n    )\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.focal_surface_light_propagation.forward","title":"forward(focal_surface, phase_only_hologram)","text":"

Forward pass through the model.

Parameters:

  • focal_surface \u2013
                    Input focal surface.\n
  • phase_only_hologram \u2013
                    Input phase-only hologram.\n

Returns:

  • result ( Tensor ) \u2013

    Output tensor after light propagation.

Source code in odak/learn/wave/models.py
def forward(self, focal_surface, phase_only_hologram):\n    \"\"\"\n    Forward pass through the model.\n\n    Parameters\n    ----------\n    focal_surface         : torch.Tensor\n                            Input focal surface.\n    phase_only_hologram   : torch.Tensor\n                            Input phase-only hologram.\n\n    Returns\n    ----------\n    result                : torch.Tensor\n                            Output tensor after light propagation.\n    \"\"\"\n    input_field = self.generate_input_field(phase_only_hologram)\n    sv_kernel = self.sv_kernel_generation(focal_surface, input_field)\n    output_field = self.light_propagation(sv_kernel, input_field)\n    final = (output_field[:, 0:3, :, :] + 1j * output_field[:, 3:6, :, :])\n    result = calculate_amplitude(final) ** 2\n    return result\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.focal_surface_light_propagation.generate_input_field","title":"generate_input_field(phase_only_hologram)","text":"

Generates an input field by combining the real and imaginary parts.

Parameters:

  • phase_only_hologram \u2013
                    Input phase-only hologram.\n

Returns:

  • input_field ( Tensor ) \u2013

    Concatenated real and imaginary parts of the complex field.

Source code in odak/learn/wave/models.py
def generate_input_field(self, phase_only_hologram):\n    \"\"\"\n    Generates an input field by combining the real and imaginary parts.\n\n    Parameters\n    ----------\n    phase_only_hologram   : torch.Tensor\n                            Input phase-only hologram.\n\n    Returns\n    ----------\n    input_field           : torch.Tensor\n                            Concatenated real and imaginary parts of the complex field.\n    \"\"\"\n    [b, c, h, w] = phase_only_hologram.size()\n    input_phase = phase_only_hologram * 2 * np.pi\n    hologram_amplitude = torch.ones(b, c, h, w, requires_grad = False).to(self.device)\n    field = generate_complex_field(hologram_amplitude, input_phase)\n    input_field = torch.cat((field.real, field.imag), dim = 1)\n    return input_field\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.focal_surface_light_propagation.load_weights","title":"load_weights(weight_filename, key_mapping_filename)","text":"

Function to load weights for this model from a file.

Parameters:

  • weight_filename \u2013
                   Path to the old model's weight file.\n
  • key_mapping_filename (str) \u2013
                   Path to the JSON file containing the key mappings.\n
Source code in odak/learn/wave/models.py
def load_weights(self, weight_filename, key_mapping_filename):\n    \"\"\"\n    Function to load weights for this multi-layer perceptron from a file.\n\n    Parameters\n    ----------\n    weight_filename      : str\n                           Path to the old model's weight file.\n    key_mapping_filename : str\n                           Path to the JSON file containing the key mappings.\n    \"\"\"\n    # Load old model weights\n    old_model_weights = torch.load(weight_filename, map_location = self.device)\n\n    # Load key mappings from JSON file\n    with open(key_mapping_filename, 'r') as json_file:\n        key_mappings = json.load(json_file)\n\n    # Extract the key mappings for sv_kernel_generation and light_prop\n    sv_kernel_generation_key_mapping = key_mappings['sv_kernel_generation_key_mapping']\n    light_prop_key_mapping = key_mappings['light_prop_key_mapping']\n\n    # Initialize new state dicts\n    sv_kernel_generation_new_state_dict = {}\n    light_prop_new_state_dict = {}\n\n    # Map and load sv_kernel_generation_model weights\n    for old_key, value in old_model_weights.items():\n        if old_key in sv_kernel_generation_key_mapping:\n            # Map the old key to the new key\n            new_key = sv_kernel_generation_key_mapping[old_key]\n            sv_kernel_generation_new_state_dict[new_key] = value\n\n    self.sv_kernel_generation.to(self.device)\n    self.sv_kernel_generation.load_state_dict(sv_kernel_generation_new_state_dict)\n\n    # Map and load light_prop model weights\n    for old_key, value in old_model_weights.items():\n        if old_key in light_prop_key_mapping:\n            # Map the old key to the new key\n            new_key = light_prop_key_mapping[old_key]\n            light_prop_new_state_dict[new_key] = value\n    self.light_propagation.to(self.device)\n    self.light_propagation.load_state_dict(light_prop_new_state_dict)\n
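
A hedged sketch of loading pre-trained weights; both filenames below are placeholders, and the JSON file is expected to hold the sv_kernel_generation_key_mapping and light_prop_key_mapping dictionaries described above.

import torch
from odak.learn.wave.models import focal_surface_light_propagation  # assumed import path

model = focal_surface_light_propagation(device = torch.device('cpu'))
model.load_weights(
                   weight_filename = 'model_weights.pt',       # placeholder path to the old weights
                   key_mapping_filename = 'key_mappings.json'  # placeholder path to the key mappings
                  )
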
"},{"location":"odak/learn_wave/#odak.learn.wave.models.global_feature_module","title":"global_feature_module","text":"

Bases: Module

A global feature layer that processes global features from input channels and applies them to another input tensor via learned transformations.

Source code in odak/learn/models/components.py
class global_feature_module(torch.nn.Module):\n    \"\"\"\n    A global feature layer that processes global features from input channels and\n    applies them to another input tensor via learned transformations.\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels,\n                 mid_channels,\n                 output_channels,\n                 kernel_size,\n                 bias = False,\n                 normalization = False,\n                 activation = torch.nn.ReLU()\n                ):\n        \"\"\"\n        A global feature layer.\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        mid_channels  : int\n                          Number of mid channels.\n        output_channels : int\n                          Number of output channels.\n        kernel_size     : int\n                          Kernel size.\n        bias            : bool\n                          Set to True to let convolutional layers have bias term.\n        normalization   : bool\n                          If True, adds a Batch Normalization layer after the convolutional layer.\n        activation      : torch.nn\n                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n        \"\"\"\n        super().__init__()\n        self.transformations_1 = global_transformations(input_channels, output_channels)\n        self.global_features_1 = double_convolution(\n                                                    input_channels = input_channels,\n                                                    mid_channels = mid_channels,\n                                                    output_channels = output_channels,\n                                                    kernel_size = kernel_size,\n                                                    bias = bias,\n                                                    normalization = normalization,\n                                                    activation = activation\n                                                   )\n        self.global_features_2 = double_convolution(\n                                                    input_channels = input_channels,\n                                                    mid_channels = mid_channels,\n                                                    output_channels = output_channels,\n                                                    kernel_size = kernel_size,\n                                                    bias = bias,\n                                                    normalization = normalization,\n                                                    activation = activation\n                                                   )\n        self.transformations_2 = global_transformations(input_channels, output_channels)\n\n\n    def forward(self, x1, x2):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x1             : torch.tensor\n                         First input data.\n        x2             : torch.tensor\n                         Second input data.\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Estimated output.\n        \"\"\"\n        global_tensor_1 = self.transformations_1(x1, x2)\n        y1 = self.global_features_1(global_tensor_1)\n        y2 = self.global_features_2(y1)\n        global_tensor_2 = self.transformations_2(y1, y2)\n        return 
global_tensor_2\n
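
A hedged usage sketch for global_feature_module; matching channel counts are used for both inputs so that the internal linear and convolutional shapes line up, and the import path is assumed.

import torch
from odak.learn.models import global_feature_module  # assumed import path

layer = global_feature_module(
                              input_channels = 8,
                              mid_channels = 8,
                              output_channels = 8,
                              kernel_size = 3
                             )
x1 = torch.randn(1, 8, 32, 32)  # tensor to be modulated
x2 = torch.randn(1, 8, 32, 32)  # tensor providing the global statistics
y = layer(x1, x2)
print(y.shape)                  # expected: torch.Size([1, 8, 32, 32])
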
"},{"location":"odak/learn_wave/#odak.learn.wave.models.global_feature_module.__init__","title":"__init__(input_channels, mid_channels, output_channels, kernel_size, bias=False, normalization=False, activation=torch.nn.ReLU())","text":"

A global feature layer.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • mid_channels \u2013
              Number of mid channels.\n
  • output_channels (int) \u2013
              Number of output channels.\n
  • kernel_size \u2013
              Kernel size.\n
  • bias \u2013
              Set to True to let convolutional layers have bias term.\n
  • normalization \u2013
              If True, adds a Batch Normalization layer after the convolutional layer.\n
  • activation \u2013
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels,\n             mid_channels,\n             output_channels,\n             kernel_size,\n             bias = False,\n             normalization = False,\n             activation = torch.nn.ReLU()\n            ):\n    \"\"\"\n    A global feature layer.\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    mid_channels  : int\n                      Number of mid channels.\n    output_channels : int\n                      Number of output channels.\n    kernel_size     : int\n                      Kernel size.\n    bias            : bool\n                      Set to True to let convolutional layers have bias term.\n    normalization   : bool\n                      If True, adds a Batch Normalization layer after the convolutional layer.\n    activation      : torch.nn\n                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n    \"\"\"\n    super().__init__()\n    self.transformations_1 = global_transformations(input_channels, output_channels)\n    self.global_features_1 = double_convolution(\n                                                input_channels = input_channels,\n                                                mid_channels = mid_channels,\n                                                output_channels = output_channels,\n                                                kernel_size = kernel_size,\n                                                bias = bias,\n                                                normalization = normalization,\n                                                activation = activation\n                                               )\n    self.global_features_2 = double_convolution(\n                                                input_channels = input_channels,\n                                                mid_channels = mid_channels,\n                                                output_channels = output_channels,\n                                                kernel_size = kernel_size,\n                                                bias = bias,\n                                                normalization = normalization,\n                                                activation = activation\n                                               )\n    self.transformations_2 = global_transformations(input_channels, output_channels)\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.global_feature_module.forward","title":"forward(x1, x2)","text":"

Forward model.

Parameters:

  • x1 \u2013
             First input data.\n
  • x2 \u2013
             Second input data.\n

Returns:

  • result ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x1, x2):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x1             : torch.tensor\n                     First input data.\n    x2             : torch.tensor\n                     Second input data.\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Estimated output.\n    \"\"\"\n    global_tensor_1 = self.transformations_1(x1, x2)\n    y1 = self.global_features_1(global_tensor_1)\n    y2 = self.global_features_2(y1)\n    global_tensor_2 = self.transformations_2(y1, y2)\n    return global_tensor_2\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.global_transformations","title":"global_transformations","text":"

Bases: Module

A global feature layer that processes global features from input channels and applies learned transformations to another input tensor.

This implementation is adapted from RSGUnet: https://github.com/MTLab/rsgunet_image_enhance.

Reference: J. Huang, P. Zhu, M. Geng et al. \"Range Scaling Global U-Net for Perceptual Image Enhancement on Mobile Devices.\"

Source code in odak/learn/models/components.py
class global_transformations(torch.nn.Module):\n    \"\"\"\n    A global feature layer that processes global features from input channels and\n    applies learned transformations to another input tensor.\n\n    This implementation is adapted from RSGUnet:\n    https://github.com/MTLab/rsgunet_image_enhance.\n\n    Reference:\n    J. Huang, P. Zhu, M. Geng et al. \"Range Scaling Global U-Net for Perceptual Image Enhancement on Mobile Devices.\"\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels,\n                 output_channels\n                ):\n        \"\"\"\n        A global feature layer.\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        output_channels : int\n                          Number of output channels.\n        \"\"\"\n        super().__init__()\n        self.global_feature_1 = torch.nn.Sequential(\n            torch.nn.Linear(input_channels, output_channels),\n            torch.nn.LeakyReLU(0.2, inplace = True),\n        )\n        self.global_feature_2 = torch.nn.Sequential(\n            torch.nn.Linear(output_channels, output_channels),\n            torch.nn.LeakyReLU(0.2, inplace = True)\n        )\n\n\n    def forward(self, x1, x2):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x1             : torch.tensor\n                         First input data.\n        x2             : torch.tensor\n                         Second input data.\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Estimated output.\n        \"\"\"\n        y = torch.mean(x2, dim = (2, 3))\n        y1 = self.global_feature_1(y)\n        y2 = self.global_feature_2(y1)\n        y1 = y1.unsqueeze(2).unsqueeze(3)\n        y2 = y2.unsqueeze(2).unsqueeze(3)\n        result = x1 * y1 + y2\n        return result\n
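
A hedged usage sketch for global_transformations; with equal input and output channel counts both tensors carry eight channels here, and the import path is assumed.

import torch
from odak.learn.models import global_transformations  # assumed import path

layer = global_transformations(input_channels = 8, output_channels = 8)
x1 = torch.randn(1, 8, 32, 32)  # tensor to be scaled and shifted
x2 = torch.randn(1, 8, 32, 32)  # tensor whose spatial mean drives the transformation
y = layer(x1, x2)               # x1 modulated by globally pooled features of x2
print(y.shape)                  # expected: torch.Size([1, 8, 32, 32])
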
"},{"location":"odak/learn_wave/#odak.learn.wave.models.global_transformations.__init__","title":"__init__(input_channels, output_channels)","text":"

A global feature layer.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • output_channels (int) \u2013
              Number of output channels.\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels,\n             output_channels\n            ):\n    \"\"\"\n    A global feature layer.\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    output_channels : int\n                      Number of output channels.\n    \"\"\"\n    super().__init__()\n    self.global_feature_1 = torch.nn.Sequential(\n        torch.nn.Linear(input_channels, output_channels),\n        torch.nn.LeakyReLU(0.2, inplace = True),\n    )\n    self.global_feature_2 = torch.nn.Sequential(\n        torch.nn.Linear(output_channels, output_channels),\n        torch.nn.LeakyReLU(0.2, inplace = True)\n    )\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.global_transformations.forward","title":"forward(x1, x2)","text":"

Forward model.

Parameters:

  • x1 \u2013
             First input data.\n
  • x2 \u2013
             Second input data.\n

Returns:

  • result ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x1, x2):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x1             : torch.tensor\n                     First input data.\n    x2             : torch.tensor\n                     Second input data.\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Estimated output.\n    \"\"\"\n    y = torch.mean(x2, dim = (2, 3))\n    y1 = self.global_feature_1(y)\n    y2 = self.global_feature_2(y1)\n    y1 = y1.unsqueeze(2).unsqueeze(3)\n    y2 = y2.unsqueeze(2).unsqueeze(3)\n    result = x1 * y1 + y2\n    return result\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.holobeam_multiholo","title":"holobeam_multiholo","text":"

Bases: Module

The learned holography model used in the paper, Ak\u015fit, Kaan, and Yuta Itoh. \"HoloBeam: Paper-Thin Near-Eye Displays.\" In 2023 IEEE Conference on Virtual Reality and 3D User Interfaces (VR), pp. 581-591. IEEE, 2023.

Parameters:

  • n_input \u2013
                Number of channels in the input.\n
  • n_hidden \u2013
                Number of channels in the hidden layers.\n
  • n_output \u2013
                Number of channels in the output layer.\n
  • device \u2013
                Default device is CPU.\n
  • reduction \u2013
                Reduction used for torch.nn.MSELoss and torch.nn.L1Loss. The default is 'sum'.\n
Source code in odak/learn/wave/models.py
class holobeam_multiholo(torch.nn.Module):\n    \"\"\"\n    The learned holography model used in the paper, Ak\u015fit, Kaan, and Yuta Itoh. \"HoloBeam: Paper-Thin Near-Eye Displays.\" In 2023 IEEE Conference Virtual Reality and 3D User Interfaces (VR), pp. 581-591. IEEE, 2023.\n\n\n    Parameters\n    ----------\n    n_input           : int\n                        Number of channels in the input.\n    n_hidden          : int\n                        Number of channels in the hidden layers.\n    n_output          : int\n                        Number of channels in the output layer.\n    device            : torch.device\n                        Default device is CPU.\n    reduction         : str\n                        Reduction used for torch.nn.MSELoss and torch.nn.L1Loss. The default is 'sum'.\n    \"\"\"\n    def __init__(\n                 self,\n                 n_input = 1,\n                 n_hidden = 16,\n                 n_output = 2,\n                 device = torch.device('cpu'),\n                 reduction = 'sum'\n                ):\n        super(holobeam_multiholo, self).__init__()\n        torch.random.seed()\n        self.device = device\n        self.reduction = reduction\n        self.l2 = torch.nn.MSELoss(reduction = self.reduction)\n        self.l1 = torch.nn.L1Loss(reduction = self.reduction)\n        self.n_input = n_input\n        self.n_hidden = n_hidden\n        self.n_output = n_output\n        self.network = unet(\n                            dimensions = self.n_hidden,\n                            input_channels = self.n_input,\n                            output_channels = self.n_output\n                           ).to(self.device)\n\n\n    def forward(self, x, test = False):\n        \"\"\"\n        Internal function representing the forward model.\n        \"\"\"\n        if test:\n            torch.no_grad()\n        y = self.network.forward(x) \n        phase_low = y[:, 0].unsqueeze(1)\n        phase_high = y[:, 1].unsqueeze(1)\n        phase_only = torch.zeros_like(phase_low)\n        phase_only[:, :, 0::2, 0::2] = phase_low[:, :,  0::2, 0::2]\n        phase_only[:, :, 1::2, 1::2] = phase_low[:, :, 1::2, 1::2]\n        phase_only[:, :, 0::2, 1::2] = phase_high[:, :, 0::2, 1::2]\n        phase_only[:, :, 1::2, 0::2] = phase_high[:, :, 1::2, 0::2]\n        return phase_only\n\n\n    def evaluate(self, input_data, ground_truth, weights = [1., 0.1]):\n        \"\"\"\n        Internal function for evaluating.\n        \"\"\"\n        loss = weights[0] * self.l2(input_data, ground_truth) + weights[1] * self.l1(input_data, ground_truth)\n        return loss\n\n\n    def fit(self, dataloader, number_of_epochs = 100, learning_rate = 1e-5, directory = './output', save_at_every = 100):\n        \"\"\"\n        Function to train the weights of the multi layer perceptron.\n\n        Parameters\n        ----------\n        dataloader       : torch.utils.data.DataLoader\n                           Data loader.\n        number_of_epochs : int\n                           Number of epochs.\n        learning_rate    : float\n                           Learning rate of the optimizer.\n        directory        : str\n                           Output directory.\n        save_at_every    : int\n                           Save the model at every given epoch count.\n        \"\"\"\n        t_epoch = tqdm(range(number_of_epochs), leave=False, dynamic_ncols = True)\n        self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)\n        for i in t_epoch:\n       
     epoch_loss = 0.\n            t_data = tqdm(dataloader, leave=False, dynamic_ncols = True)\n            for j, data in enumerate(t_data):\n                self.optimizer.zero_grad()\n                images, holograms = data\n                estimates = self.forward(images)\n                loss = self.evaluate(estimates, holograms)\n                loss.backward(retain_graph=True)\n                self.optimizer.step()\n                description = 'Loss:{:.4f}'.format(loss.item())\n                t_data.set_description(description)\n                epoch_loss += float(loss.item()) / dataloader.__len__()\n            description = 'Epoch Loss:{:.4f}'.format(epoch_loss)\n            t_epoch.set_description(description)\n            if i % save_at_every == 0:\n                self.save_weights(filename='{}/weights_{:04d}.pt'.format(directory, i))\n        self.save_weights(filename='{}/weights.pt'.format(directory))\n        print(description)\n\n\n    def save_weights(self, filename = './weights.pt'):\n        \"\"\"\n        Function to save the current weights of the multi layer perceptron to a file.\n        Parameters\n        ----------\n        filename        : str\n                          Filename.\n        \"\"\"\n        torch.save(self.network.state_dict(), os.path.expanduser(filename))\n\n\n    def load_weights(self, filename = './weights.pt'):\n        \"\"\"\n        Function to load weights for this multi layer perceptron from a file.\n        Parameters\n        ----------\n        filename        : str\n                          Filename.\n        \"\"\"\n        self.network.load_state_dict(torch.load(os.path.expanduser(filename)))\n        self.network.eval()\n
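
A hedged inference sketch for holobeam_multiholo; the import path is assumed, and the 256 x 256 single-channel input is arbitrary (with untrained weights the output is meaningless until fit or load_weights is called).

import torch
from odak.learn.wave.models import holobeam_multiholo  # assumed import path

model = holobeam_multiholo(n_input = 1, n_hidden = 16, n_output = 2)
images = torch.randn(1, 1, 256, 256)  # single-channel target image
phase_only = model(images)            # low/high phase estimates interleaved in a checkerboard
print(phase_only.shape)               # expected: torch.Size([1, 1, 256, 256])
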
"},{"location":"odak/learn_wave/#odak.learn.wave.models.holobeam_multiholo.evaluate","title":"evaluate(input_data, ground_truth, weights=[1.0, 0.1])","text":"

Internal function for evaluating.

Source code in odak/learn/wave/models.py
def evaluate(self, input_data, ground_truth, weights = [1., 0.1]):\n    \"\"\"\n    Internal function for evaluating.\n    \"\"\"\n    loss = weights[0] * self.l2(input_data, ground_truth) + weights[1] * self.l1(input_data, ground_truth)\n    return loss\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.holobeam_multiholo.fit","title":"fit(dataloader, number_of_epochs=100, learning_rate=1e-05, directory='./output', save_at_every=100)","text":"

Function to train the weights of the model.

Parameters:

  • dataloader \u2013
               Data loader.\n
  • number_of_epochs (int, default: 100 ) \u2013
               Number of epochs.\n
  • learning_rate \u2013
               Learning rate of the optimizer.\n
  • directory \u2013
               Output directory.\n
  • save_at_every \u2013
               Save the model at every given epoch count.\n
Source code in odak/learn/wave/models.py
def fit(self, dataloader, number_of_epochs = 100, learning_rate = 1e-5, directory = './output', save_at_every = 100):\n    \"\"\"\n    Function to train the weights of the multi layer perceptron.\n\n    Parameters\n    ----------\n    dataloader       : torch.utils.data.DataLoader\n                       Data loader.\n    number_of_epochs : int\n                       Number of epochs.\n    learning_rate    : float\n                       Learning rate of the optimizer.\n    directory        : str\n                       Output directory.\n    save_at_every    : int\n                       Save the model at every given epoch count.\n    \"\"\"\n    t_epoch = tqdm(range(number_of_epochs), leave=False, dynamic_ncols = True)\n    self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)\n    for i in t_epoch:\n        epoch_loss = 0.\n        t_data = tqdm(dataloader, leave=False, dynamic_ncols = True)\n        for j, data in enumerate(t_data):\n            self.optimizer.zero_grad()\n            images, holograms = data\n            estimates = self.forward(images)\n            loss = self.evaluate(estimates, holograms)\n            loss.backward(retain_graph=True)\n            self.optimizer.step()\n            description = 'Loss:{:.4f}'.format(loss.item())\n            t_data.set_description(description)\n            epoch_loss += float(loss.item()) / dataloader.__len__()\n        description = 'Epoch Loss:{:.4f}'.format(epoch_loss)\n        t_epoch.set_description(description)\n        if i % save_at_every == 0:\n            self.save_weights(filename='{}/weights_{:04d}.pt'.format(directory, i))\n    self.save_weights(filename='{}/weights.pt'.format(directory))\n    print(description)\n
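
A hedged training sketch for fit; the toy tensors and the './output' directory are placeholders, and the directory is created beforehand because fit saves checkpoints into it.

import os
import torch
from torch.utils.data import DataLoader, TensorDataset
from odak.learn.wave.models import holobeam_multiholo  # assumed import path

os.makedirs('./output', exist_ok = True)
model = holobeam_multiholo()
images = torch.randn(4, 1, 64, 64)     # toy target images
holograms = torch.randn(4, 1, 64, 64)  # toy ground-truth phase-only holograms
loader = DataLoader(TensorDataset(images, holograms), batch_size = 2)
model.fit(loader, number_of_epochs = 1, learning_rate = 1e-5, directory = './output')
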
"},{"location":"odak/learn_wave/#odak.learn.wave.models.holobeam_multiholo.forward","title":"forward(x, test=False)","text":"

Internal function representing the forward model.

Source code in odak/learn/wave/models.py
def forward(self, x, test = False):\n    \"\"\"\n    Internal function representing the forward model.\n    \"\"\"\n    if test:\n        torch.no_grad()\n    y = self.network.forward(x) \n    phase_low = y[:, 0].unsqueeze(1)\n    phase_high = y[:, 1].unsqueeze(1)\n    phase_only = torch.zeros_like(phase_low)\n    phase_only[:, :, 0::2, 0::2] = phase_low[:, :,  0::2, 0::2]\n    phase_only[:, :, 1::2, 1::2] = phase_low[:, :, 1::2, 1::2]\n    phase_only[:, :, 0::2, 1::2] = phase_high[:, :, 0::2, 1::2]\n    phase_only[:, :, 1::2, 0::2] = phase_high[:, :, 1::2, 0::2]\n    return phase_only\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.holobeam_multiholo.load_weights","title":"load_weights(filename='./weights.pt')","text":"

Function to load weights for this model from a file.

Parameters:

  • filename \u2013
              Filename.\n
Source code in odak/learn/wave/models.py
def load_weights(self, filename = './weights.pt'):\n    \"\"\"\n    Function to load weights for this multi layer perceptron from a file.\n    Parameters\n    ----------\n    filename        : str\n                      Filename.\n    \"\"\"\n    self.network.load_state_dict(torch.load(os.path.expanduser(filename)))\n    self.network.eval()\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.holobeam_multiholo.save_weights","title":"save_weights(filename='./weights.pt')","text":"

Function to save the current weights of the model to a file.

Parameters:

  • filename \u2013
              Filename.\n
Source code in odak/learn/wave/models.py
def save_weights(self, filename = './weights.pt'):\n    \"\"\"\n    Function to save the current weights of the multi layer perceptron to a file.\n    Parameters\n    ----------\n    filename        : str\n                      Filename.\n    \"\"\"\n    torch.save(self.network.state_dict(), os.path.expanduser(filename))\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.multi_layer_perceptron","title":"multi_layer_perceptron","text":"

Bases: Module

A multi-layer perceptron model.

Source code in odak/learn/models/models.py
class multi_layer_perceptron(torch.nn.Module):\n    \"\"\"\n    A multi-layer perceptron model.\n    \"\"\"\n\n    def __init__(self,\n                 dimensions,\n                 activation = torch.nn.ReLU(),\n                 bias = False,\n                 model_type = 'conventional',\n                 siren_multiplier = 1.,\n                 input_multiplier = None\n                ):\n        \"\"\"\n        Parameters\n        ----------\n        dimensions        : list\n                            List of integers representing the dimensions of each layer (e.g., [2, 10, 1], where the first layer has two channels and last one has one channel.).\n        activation        : torch.nn\n                            Nonlinear activation function.\n                            Default is `torch.nn.ReLU()`.\n        bias              : bool\n                            If set to True, linear layers will include biases.\n        siren_multiplier  : float\n                            When using `SIREN` model type, this parameter functions as a hyperparameter.\n                            The original SIREN work uses 30.\n                            You can bypass this parameter by providing input that are not normalized and larger then one.\n        input_multiplier  : float\n                            Initial value of the input multiplier before the very first layer.\n        model_type        : str\n                            Model type: `conventional`, `swish`, `SIREN`, `FILM SIREN`, `Gaussian`.\n                            `conventional` refers to a standard multi layer perceptron.\n                            For `SIREN,` see: Sitzmann, Vincent, et al. \"Implicit neural representations with periodic activation functions.\" Advances in neural information processing systems 33 (2020): 7462-7473.\n                            For `Swish,` see: Ramachandran, Prajit, Barret Zoph, and Quoc V. Le. \"Searching for activation functions.\" arXiv preprint arXiv:1710.05941 (2017). \n                            For `FILM SIREN,` see: Chan, Eric R., et al. \"pi-gan: Periodic implicit generative adversarial networks for 3d-aware image synthesis.\" Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2021.\n                            For `Gaussian,` see: Ramasinghe, Sameera, and Simon Lucey. \"Beyond periodicity: Towards a unifying framework for activations in coordinate-mlps.\" In European Conference on Computer Vision, pp. 142-158. 
Cham: Springer Nature Switzerland, 2022.\n        \"\"\"\n        super(multi_layer_perceptron, self).__init__()\n        self.activation = activation\n        self.bias = bias\n        self.model_type = model_type\n        self.layers = torch.nn.ModuleList()\n        self.siren_multiplier = siren_multiplier\n        self.dimensions = dimensions\n        for i in range(len(self.dimensions) - 1):\n            self.layers.append(torch.nn.Linear(self.dimensions[i], self.dimensions[i + 1], bias = self.bias))\n        if not isinstance(input_multiplier, type(None)):\n            self.input_multiplier = torch.nn.ParameterList()\n            self.input_multiplier.append(torch.nn.Parameter(torch.ones(1, self.dimensions[0]) * input_multiplier))\n        if self.model_type == 'FILM SIREN':\n            self.alpha = torch.nn.ParameterList()\n            for j in self.dimensions[1:-1]:\n                self.alpha.append(torch.nn.Parameter(torch.randn(2, 1, j)))\n        if self.model_type == 'Gaussian':\n            self.alpha = torch.nn.ParameterList()\n            for j in self.dimensions[1:-1]:\n                self.alpha.append(torch.nn.Parameter(torch.randn(1, 1, j)))\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x             : torch.tensor\n                        Input data.\n\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Estimated output.      \n        \"\"\"\n        if hasattr(self, 'input_multiplier'):\n            result = x * self.input_multiplier[0]\n        else:\n            result = x\n        for layer_id, layer in enumerate(self.layers[:-1]):\n            result = layer(result)\n            if self.model_type == 'conventional':\n                result = self.activation(result)\n            elif self.model_type == 'swish':\n                resutl = swish(result)\n            elif self.model_type == 'SIREN':\n                result = torch.sin(result * self.siren_multiplier)\n            elif self.model_type == 'FILM SIREN':\n                result = torch.sin(self.alpha[layer_id][0] * result + self.alpha[layer_id][1])\n            elif self.model_type == 'Gaussian': \n                result = gaussian(result, self.alpha[layer_id][0])\n        result = self.layers[-1](result)\n        return result\n
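
A hedged usage sketch for multi_layer_perceptron as a SIREN-style coordinate network; the import path is assumed, and the layer dimensions and coordinates are illustrative.

import torch
from odak.learn.models import multi_layer_perceptron  # assumed import path

model = multi_layer_perceptron(
                               dimensions = [2, 64, 64, 1],
                               model_type = 'SIREN',
                               siren_multiplier = 30.
                              )
coordinates = torch.rand(128, 2)  # e.g., normalized (x, y) sample locations
values = model(coordinates)       # one scalar estimate per coordinate
print(values.shape)               # expected: torch.Size([128, 1])
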
"},{"location":"odak/learn_wave/#odak.learn.wave.models.multi_layer_perceptron.__init__","title":"__init__(dimensions, activation=torch.nn.ReLU(), bias=False, model_type='conventional', siren_multiplier=1.0, input_multiplier=None)","text":"

Parameters:

  • dimensions \u2013
                List of integers representing the dimensions of each layer (e.g., [2, 10, 1], where the first layer has two channels and the last one has one channel).\n
  • activation \u2013
                Nonlinear activation function.\n            Default is `torch.nn.ReLU()`.\n
  • bias \u2013
                If set to True, linear layers will include biases.\n
  • siren_multiplier \u2013
                When using `SIREN` model type, this parameter functions as a hyperparameter.\n            The original SIREN work uses 30.\n            You can bypass this parameter by providing inputs that are not normalized and larger than one.\n
  • input_multiplier \u2013
                Initial value of the input multiplier before the very first layer.\n
  • model_type \u2013
                Model type: `conventional`, `swish`, `SIREN`, `FILM SIREN`, `Gaussian`.\n            `conventional` refers to a standard multi layer perceptron.\n            For `SIREN,` see: Sitzmann, Vincent, et al. \"Implicit neural representations with periodic activation functions.\" Advances in neural information processing systems 33 (2020): 7462-7473.\n            For `Swish,` see: Ramachandran, Prajit, Barret Zoph, and Quoc V. Le. \"Searching for activation functions.\" arXiv preprint arXiv:1710.05941 (2017). \n            For `FILM SIREN,` see: Chan, Eric R., et al. \"pi-gan: Periodic implicit generative adversarial networks for 3d-aware image synthesis.\" Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2021.\n            For `Gaussian,` see: Ramasinghe, Sameera, and Simon Lucey. \"Beyond periodicity: Towards a unifying framework for activations in coordinate-mlps.\" In European Conference on Computer Vision, pp. 142-158. Cham: Springer Nature Switzerland, 2022.\n
Source code in odak/learn/models/models.py
def __init__(self,\n             dimensions,\n             activation = torch.nn.ReLU(),\n             bias = False,\n             model_type = 'conventional',\n             siren_multiplier = 1.,\n             input_multiplier = None\n            ):\n    \"\"\"\n    Parameters\n    ----------\n    dimensions        : list\n                        List of integers representing the dimensions of each layer (e.g., [2, 10, 1], where the first layer has two channels and last one has one channel.).\n    activation        : torch.nn\n                        Nonlinear activation function.\n                        Default is `torch.nn.ReLU()`.\n    bias              : bool\n                        If set to True, linear layers will include biases.\n    siren_multiplier  : float\n                        When using `SIREN` model type, this parameter functions as a hyperparameter.\n                        The original SIREN work uses 30.\n                        You can bypass this parameter by providing input that are not normalized and larger then one.\n    input_multiplier  : float\n                        Initial value of the input multiplier before the very first layer.\n    model_type        : str\n                        Model type: `conventional`, `swish`, `SIREN`, `FILM SIREN`, `Gaussian`.\n                        `conventional` refers to a standard multi layer perceptron.\n                        For `SIREN,` see: Sitzmann, Vincent, et al. \"Implicit neural representations with periodic activation functions.\" Advances in neural information processing systems 33 (2020): 7462-7473.\n                        For `Swish,` see: Ramachandran, Prajit, Barret Zoph, and Quoc V. Le. \"Searching for activation functions.\" arXiv preprint arXiv:1710.05941 (2017). \n                        For `FILM SIREN,` see: Chan, Eric R., et al. \"pi-gan: Periodic implicit generative adversarial networks for 3d-aware image synthesis.\" Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2021.\n                        For `Gaussian,` see: Ramasinghe, Sameera, and Simon Lucey. \"Beyond periodicity: Towards a unifying framework for activations in coordinate-mlps.\" In European Conference on Computer Vision, pp. 142-158. Cham: Springer Nature Switzerland, 2022.\n    \"\"\"\n    super(multi_layer_perceptron, self).__init__()\n    self.activation = activation\n    self.bias = bias\n    self.model_type = model_type\n    self.layers = torch.nn.ModuleList()\n    self.siren_multiplier = siren_multiplier\n    self.dimensions = dimensions\n    for i in range(len(self.dimensions) - 1):\n        self.layers.append(torch.nn.Linear(self.dimensions[i], self.dimensions[i + 1], bias = self.bias))\n    if not isinstance(input_multiplier, type(None)):\n        self.input_multiplier = torch.nn.ParameterList()\n        self.input_multiplier.append(torch.nn.Parameter(torch.ones(1, self.dimensions[0]) * input_multiplier))\n    if self.model_type == 'FILM SIREN':\n        self.alpha = torch.nn.ParameterList()\n        for j in self.dimensions[1:-1]:\n            self.alpha.append(torch.nn.Parameter(torch.randn(2, 1, j)))\n    if self.model_type == 'Gaussian':\n        self.alpha = torch.nn.ParameterList()\n        for j in self.dimensions[1:-1]:\n            self.alpha.append(torch.nn.Parameter(torch.randn(1, 1, j)))\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.multi_layer_perceptron.forward","title":"forward(x)","text":"

Forward model.

Parameters:

  • x \u2013
            Input data.\n

Returns:

  • result ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/models.py
def forward(self, x):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x             : torch.tensor\n                    Input data.\n\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Estimated output.      \n    \"\"\"\n    if hasattr(self, 'input_multiplier'):\n        result = x * self.input_multiplier[0]\n    else:\n        result = x\n    for layer_id, layer in enumerate(self.layers[:-1]):\n        result = layer(result)\n        if self.model_type == 'conventional':\n            result = self.activation(result)\n        elif self.model_type == 'swish':\n            result = swish(result)\n        elif self.model_type == 'SIREN':\n            result = torch.sin(result * self.siren_multiplier)\n        elif self.model_type == 'FILM SIREN':\n            result = torch.sin(self.alpha[layer_id][0] * result + self.alpha[layer_id][1])\n        elif self.model_type == 'Gaussian': \n            result = gaussian(result, self.alpha[layer_id][0])\n    result = self.layers[-1](result)\n    return result\n
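Example use: a minimal sketch, assuming `odak` is installed and `multi_layer_perceptron` can be imported from `odak.learn.models` (the module given in the source path above); the layer dimensions and SIREN settings below are illustrative only.
import torch\nfrom odak.learn.models import multi_layer_perceptron # Assumed import path.\n\nmodel = multi_layer_perceptron(\n                               dimensions = [2, 64, 64, 1],\n                               model_type = 'SIREN',\n                               siren_multiplier = 30.\n                              )\nx = torch.rand(100, 2) # Normalized 2D input coordinates.\ny = model(x) # Estimated output with shape [100, 1].\n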
"},{"location":"odak/learn_wave/#odak.learn.wave.models.non_local_layer","title":"non_local_layer","text":"

Bases: Module

Self-attention layer (z_i = W_z y_i + x_i), i.e., a non-local block; see https://arxiv.org/abs/1711.07971.

Source code in odak/learn/models/components.py
class non_local_layer(torch.nn.Module):\n    \"\"\"\n    Self-Attention Layer [zi = Wzyi + xi] (non-local block : ref https://arxiv.org/abs/1711.07971)\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels = 1024,\n                 bottleneck_channels = 512,\n                 kernel_size = 1,\n                 bias = False,\n                ):\n        \"\"\"\n\n        Parameters\n        ----------\n        input_channels      : int\n                              Number of input channels.\n        bottleneck_channels : int\n                              Number of middle channels.\n        kernel_size         : int\n                              Kernel size.\n        bias                : bool \n                              Set to True to let convolutional layers have bias term.\n        \"\"\"\n        super(non_local_layer, self).__init__()\n        self.input_channels = input_channels\n        self.bottleneck_channels = bottleneck_channels\n        self.g = torch.nn.Conv2d(\n                                 self.input_channels, \n                                 self.bottleneck_channels,\n                                 kernel_size = kernel_size,\n                                 padding = kernel_size // 2,\n                                 bias = bias\n                                )\n        self.W_z = torch.nn.Sequential(\n                                       torch.nn.Conv2d(\n                                                       self.bottleneck_channels,\n                                                       self.input_channels, \n                                                       kernel_size = kernel_size,\n                                                       bias = bias,\n                                                       padding = kernel_size // 2\n                                                      ),\n                                       torch.nn.BatchNorm2d(self.input_channels)\n                                      )\n        torch.nn.init.constant_(self.W_z[1].weight, 0)   \n        torch.nn.init.constant_(self.W_z[1].bias, 0)\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward model [zi = Wzyi + xi]\n\n        Parameters\n        ----------\n        x               : torch.tensor\n                          First input data.                       \n\n\n        Returns\n        ----------\n        z               : torch.tensor\n                          Estimated output.\n        \"\"\"\n        batch_size, channels, height, width = x.size()\n        theta = x.view(batch_size, channels, -1).permute(0, 2, 1)\n        phi = x.view(batch_size, channels, -1).permute(0, 2, 1)\n        g = self.g(x).view(batch_size, self.bottleneck_channels, -1).permute(0, 2, 1)\n        attn = torch.bmm(theta, phi.transpose(1, 2)) / (height * width)\n        attn = torch.nn.functional.softmax(attn, dim=-1)\n        y = torch.bmm(attn, g).permute(0, 2, 1).contiguous().view(batch_size, self.bottleneck_channels, height, width)\n        W_y = self.W_z(y)\n        z = W_y + x\n        return z\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.non_local_layer.__init__","title":"__init__(input_channels=1024, bottleneck_channels=512, kernel_size=1, bias=False)","text":"

Parameters:

  • input_channels \u2013
                  Number of input channels.\n
  • bottleneck_channels (int, default: 512 ) \u2013
                  Number of middle channels.\n
  • kernel_size \u2013
                  Kernel size.\n
  • bias \u2013
                  Set to True to let convolutional layers have bias term.\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels = 1024,\n             bottleneck_channels = 512,\n             kernel_size = 1,\n             bias = False,\n            ):\n    \"\"\"\n\n    Parameters\n    ----------\n    input_channels      : int\n                          Number of input channels.\n    bottleneck_channels : int\n                          Number of middle channels.\n    kernel_size         : int\n                          Kernel size.\n    bias                : bool \n                          Set to True to let convolutional layers have bias term.\n    \"\"\"\n    super(non_local_layer, self).__init__()\n    self.input_channels = input_channels\n    self.bottleneck_channels = bottleneck_channels\n    self.g = torch.nn.Conv2d(\n                             self.input_channels, \n                             self.bottleneck_channels,\n                             kernel_size = kernel_size,\n                             padding = kernel_size // 2,\n                             bias = bias\n                            )\n    self.W_z = torch.nn.Sequential(\n                                   torch.nn.Conv2d(\n                                                   self.bottleneck_channels,\n                                                   self.input_channels, \n                                                   kernel_size = kernel_size,\n                                                   bias = bias,\n                                                   padding = kernel_size // 2\n                                                  ),\n                                   torch.nn.BatchNorm2d(self.input_channels)\n                                  )\n    torch.nn.init.constant_(self.W_z[1].weight, 0)   \n    torch.nn.init.constant_(self.W_z[1].bias, 0)\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.non_local_layer.forward","title":"forward(x)","text":"

Forward model (z_i = W_z y_i + x_i).

Parameters:

  • x \u2013
              First input data.\n

Returns:

  • z ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward model [zi = Wzyi + xi]\n\n    Parameters\n    ----------\n    x               : torch.tensor\n                      First input data.                       \n\n\n    Returns\n    ----------\n    z               : torch.tensor\n                      Estimated output.\n    \"\"\"\n    batch_size, channels, height, width = x.size()\n    theta = x.view(batch_size, channels, -1).permute(0, 2, 1)\n    phi = x.view(batch_size, channels, -1).permute(0, 2, 1)\n    g = self.g(x).view(batch_size, self.bottleneck_channels, -1).permute(0, 2, 1)\n    attn = torch.bmm(theta, phi.transpose(1, 2)) / (height * width)\n    attn = torch.nn.functional.softmax(attn, dim=-1)\n    y = torch.bmm(attn, g).permute(0, 2, 1).contiguous().view(batch_size, self.bottleneck_channels, height, width)\n    W_y = self.W_z(y)\n    z = W_y + x\n    return z\n
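Example use: a minimal sketch, assuming `non_local_layer` can be imported from `odak.learn.models`; the channel counts below are smaller than the defaults purely for illustration.
import torch\nfrom odak.learn.models import non_local_layer # Assumed import path.\n\nlayer = non_local_layer(input_channels = 64, bottleneck_channels = 32)\nx = torch.randn(1, 64, 32, 32)\nz = layer(x) # Residual self-attention output, same shape as x.\n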
"},{"location":"odak/learn_wave/#odak.learn.wave.models.normalization","title":"normalization","text":"

Bases: Module

A normalization layer.

Source code in odak/learn/models/components.py
class normalization(torch.nn.Module):\n    \"\"\"\n    A normalization layer.\n    \"\"\"\n    def __init__(\n                 self,\n                 dim = 1,\n                ):\n        \"\"\"\n        Normalization layer.\n\n\n        Parameters\n        ----------\n        dim             : int\n                          Dimension (axis) to normalize.\n        \"\"\"\n        super().__init__()\n        self.k = torch.nn.Parameter(torch.ones(1, dim, 1, 1))\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x             : torch.tensor\n                        Input data.\n\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Estimated output.      \n        \"\"\"\n        eps = 1e-5 if x.dtype == torch.float32 else 1e-3\n        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)\n        mean = torch.mean(x, dim = 1, keepdim = True)\n        result =  (x - mean) * (var + eps).rsqrt() * self.k\n        return result \n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.normalization.__init__","title":"__init__(dim=1)","text":"

Normalization layer.

Parameters:

  • dim \u2013
              Number of channels in the input; the layer normalizes over the channel axis and learns one scale per channel.\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             dim = 1,\n            ):\n    \"\"\"\n    Normalization layer.\n\n\n    Parameters\n    ----------\n    dim             : int\n                      Dimension (axis) to normalize.\n    \"\"\"\n    super().__init__()\n    self.k = torch.nn.Parameter(torch.ones(1, dim, 1, 1))\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.normalization.forward","title":"forward(x)","text":"

Forward model.

Parameters:

  • x \u2013
            Input data.\n

Returns:

  • result ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x             : torch.tensor\n                    Input data.\n\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Estimated output.      \n    \"\"\"\n    eps = 1e-5 if x.dtype == torch.float32 else 1e-3\n    var = torch.var(x, dim = 1, unbiased = False, keepdim = True)\n    mean = torch.mean(x, dim = 1, keepdim = True)\n    result =  (x - mean) * (var + eps).rsqrt() * self.k\n    return result \n
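Example use: a minimal sketch, assuming `normalization` can be imported from `odak.learn.models`; `dim` should match the channel count of the input tensor.
import torch\nfrom odak.learn.models import normalization # Assumed import path.\n\nlayer = normalization(dim = 8)\nx = torch.randn(1, 8, 32, 32)\ny = layer(x) # Channel-wise normalized output, same shape as x.\n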
"},{"location":"odak/learn_wave/#odak.learn.wave.models.positional_encoder","title":"positional_encoder","text":"

Bases: Module

A positional encoder module.

Source code in odak/learn/models/components.py
class positional_encoder(torch.nn.Module):\n    \"\"\"\n    A positional encoder module.\n    \"\"\"\n\n    def __init__(self, L):\n        \"\"\"\n        A positional encoder module.\n\n        Parameters\n        ----------\n        L                   : int\n                              Positional encoding level.\n        \"\"\"\n        super(positional_encoder, self).__init__()\n        self.L = L\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x               : torch.tensor\n                          Input data.\n\n        Returns\n        ----------\n        result          : torch.tensor\n                          Result of the forward operation\n        \"\"\"\n        B, C = x.shape\n        x = x.view(B, C, 1)\n        results = [x]\n        for i in range(1, self.L + 1):\n            freq = (2 ** i) * math.pi\n            cos_x = torch.cos(freq * x)\n            sin_x = torch.sin(freq * x)\n            results.append(cos_x)\n            results.append(sin_x)\n        results = torch.cat(results, dim=2)\n        results = results.permute(0, 2, 1)\n        results = results.reshape(B, -1)\n        return results\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.positional_encoder.__init__","title":"__init__(L)","text":"

A positional encoder module.

Parameters:

  • L \u2013
                  Positional encoding level.\n
Source code in odak/learn/models/components.py
def __init__(self, L):\n    \"\"\"\n    A positional encoder module.\n\n    Parameters\n    ----------\n    L                   : int\n                          Positional encoding level.\n    \"\"\"\n    super(positional_encoder, self).__init__()\n    self.L = L\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.positional_encoder.forward","title":"forward(x)","text":"

Forward model.

Parameters:

  • x \u2013
              Input data.\n

Returns:

  • result ( tensor ) \u2013

    Result of the forward operation

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x               : torch.tensor\n                      Input data.\n\n    Returns\n    ----------\n    result          : torch.tensor\n                      Result of the forward operation\n    \"\"\"\n    B, C = x.shape\n    x = x.view(B, C, 1)\n    results = [x]\n    for i in range(1, self.L + 1):\n        freq = (2 ** i) * math.pi\n        cos_x = torch.cos(freq * x)\n        sin_x = torch.sin(freq * x)\n        results.append(cos_x)\n        results.append(sin_x)\n    results = torch.cat(results, dim=2)\n    results = results.permute(0, 2, 1)\n    results = results.reshape(B, -1)\n    return results\n
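Example use: a minimal sketch, assuming `positional_encoder` can be imported from `odak.learn.models`; for an input with C features and encoding level L, the output carries C * (2 * L + 1) features.
import torch\nfrom odak.learn.models import positional_encoder # Assumed import path.\n\nencoder = positional_encoder(L = 4)\nx = torch.rand(10, 2) # Ten 2D coordinates.\ny = encoder(x) # Encoded features with shape [10, 2 * (2 * 4 + 1)] = [10, 18].\n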
"},{"location":"odak/learn_wave/#odak.learn.wave.models.residual_attention_layer","title":"residual_attention_layer","text":"

Bases: Module

A residual block with an attention layer.

Source code in odak/learn/models/components.py
class residual_attention_layer(torch.nn.Module):\n    \"\"\"\n    A residual block with an attention layer.\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels = 2,\n                 output_channels = 2,\n                 kernel_size = 1,\n                 bias = False,\n                 activation = torch.nn.ReLU()\n                ):\n        \"\"\"\n        An attention layer class.\n\n\n        Parameters\n        ----------\n        input_channels  : int or optioal\n                          Number of input channels.\n        output_channels : int or optional\n                          Number of middle channels.\n        kernel_size     : int or optional\n                          Kernel size.\n        bias            : bool or optional\n                          Set to True to let convolutional layers have bias term.\n        activation      : torch.nn or optional\n                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n        \"\"\"\n        super().__init__()\n        self.activation = activation\n        self.convolution0 = torch.nn.Sequential(\n                                                torch.nn.Conv2d(\n                                                                input_channels,\n                                                                output_channels,\n                                                                kernel_size = kernel_size,\n                                                                padding = kernel_size // 2,\n                                                                bias = bias\n                                                               ),\n                                                torch.nn.BatchNorm2d(output_channels)\n                                               )\n        self.convolution1 = torch.nn.Sequential(\n                                                torch.nn.Conv2d(\n                                                                input_channels,\n                                                                output_channels,\n                                                                kernel_size = kernel_size,\n                                                                padding = kernel_size // 2,\n                                                                bias = bias\n                                                               ),\n                                                torch.nn.BatchNorm2d(output_channels)\n                                               )\n        self.final_layer = torch.nn.Sequential(\n                                               self.activation,\n                                               torch.nn.Conv2d(\n                                                               output_channels,\n                                                               output_channels,\n                                                               kernel_size = kernel_size,\n                                                               padding = kernel_size // 2,\n                                                               bias = bias\n                                                              )\n                                              )\n\n\n    def forward(self, x0, x1):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x0             : torch.tensor\n                         First input data.\n\n        x1             : torch.tensor\n          
               Seconnd input data.\n\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Estimated output.      \n        \"\"\"\n        y0 = self.convolution0(x0)\n        y1 = self.convolution1(x1)\n        y2 = torch.add(y0, y1)\n        result = self.final_layer(y2) * x0\n        return result\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.residual_attention_layer.__init__","title":"__init__(input_channels=2, output_channels=2, kernel_size=1, bias=False, activation=torch.nn.ReLU())","text":"

An attention layer class.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • output_channels (int or optional, default: 2 ) \u2013
              Number of middle channels.\n
  • kernel_size \u2013
              Kernel size.\n
  • bias \u2013
              Set to True to let convolutional layers have bias term.\n
  • activation \u2013
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels = 2,\n             output_channels = 2,\n             kernel_size = 1,\n             bias = False,\n             activation = torch.nn.ReLU()\n            ):\n    \"\"\"\n    An attention layer class.\n\n\n    Parameters\n    ----------\n    input_channels  : int or optioal\n                      Number of input channels.\n    output_channels : int or optional\n                      Number of middle channels.\n    kernel_size     : int or optional\n                      Kernel size.\n    bias            : bool or optional\n                      Set to True to let convolutional layers have bias term.\n    activation      : torch.nn or optional\n                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n    \"\"\"\n    super().__init__()\n    self.activation = activation\n    self.convolution0 = torch.nn.Sequential(\n                                            torch.nn.Conv2d(\n                                                            input_channels,\n                                                            output_channels,\n                                                            kernel_size = kernel_size,\n                                                            padding = kernel_size // 2,\n                                                            bias = bias\n                                                           ),\n                                            torch.nn.BatchNorm2d(output_channels)\n                                           )\n    self.convolution1 = torch.nn.Sequential(\n                                            torch.nn.Conv2d(\n                                                            input_channels,\n                                                            output_channels,\n                                                            kernel_size = kernel_size,\n                                                            padding = kernel_size // 2,\n                                                            bias = bias\n                                                           ),\n                                            torch.nn.BatchNorm2d(output_channels)\n                                           )\n    self.final_layer = torch.nn.Sequential(\n                                           self.activation,\n                                           torch.nn.Conv2d(\n                                                           output_channels,\n                                                           output_channels,\n                                                           kernel_size = kernel_size,\n                                                           padding = kernel_size // 2,\n                                                           bias = bias\n                                                          )\n                                          )\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.residual_attention_layer.forward","title":"forward(x0, x1)","text":"

Forward model.

Parameters:

  • x0 \u2013
             First input data.\n
  • x1 \u2013
             Second input data.\n

Returns:

  • result ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x0, x1):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x0             : torch.tensor\n                     First input data.\n\n    x1             : torch.tensor\n                     Seconnd input data.\n\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Estimated output.      \n    \"\"\"\n    y0 = self.convolution0(x0)\n    y1 = self.convolution1(x1)\n    y2 = torch.add(y0, y1)\n    result = self.final_layer(y2) * x0\n    return result\n
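Example use: a minimal sketch, assuming `residual_attention_layer` can be imported from `odak.learn.models`; the result is an attention-weighted version of the first input, so `output_channels` should match `input_channels` for the final element-wise product.
import torch\nfrom odak.learn.models import residual_attention_layer # Assumed import path.\n\nlayer = residual_attention_layer(input_channels = 2, output_channels = 2, kernel_size = 3)\nx0 = torch.randn(1, 2, 32, 32)\nx1 = torch.randn(1, 2, 32, 32)\ny = layer(x0, x1) # Attention-weighted version of x0, same shape as x0.\n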
"},{"location":"odak/learn_wave/#odak.learn.wave.models.residual_layer","title":"residual_layer","text":"

Bases: Module

A residual layer.

Source code in odak/learn/models/components.py
class residual_layer(torch.nn.Module):\n    \"\"\"\n    A residual layer.\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels = 2,\n                 mid_channels = 16,\n                 kernel_size = 3,\n                 bias = False,\n                 activation = torch.nn.ReLU()\n                ):\n        \"\"\"\n        A convolutional layer class.\n\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        mid_channels    : int\n                          Number of middle channels.\n        kernel_size     : int\n                          Kernel size.\n        bias            : bool \n                          Set to True to let convolutional layers have bias term.\n        activation      : torch.nn\n                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n        \"\"\"\n        super().__init__()\n        self.activation = activation\n        self.convolution = double_convolution(\n                                              input_channels,\n                                              mid_channels = mid_channels,\n                                              output_channels = input_channels,\n                                              kernel_size = kernel_size,\n                                              bias = bias,\n                                              activation = activation\n                                             )\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x             : torch.tensor\n                        Input data.\n\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Estimated output.      \n        \"\"\"\n        x0 = self.convolution(x)\n        return x + x0\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.residual_layer.__init__","title":"__init__(input_channels=2, mid_channels=16, kernel_size=3, bias=False, activation=torch.nn.ReLU())","text":"

A convolutional layer class.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • mid_channels \u2013
              Number of middle channels.\n
  • kernel_size \u2013
              Kernel size.\n
  • bias \u2013
              Set to True to let convolutional layers have bias term.\n
  • activation \u2013
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels = 2,\n             mid_channels = 16,\n             kernel_size = 3,\n             bias = False,\n             activation = torch.nn.ReLU()\n            ):\n    \"\"\"\n    A convolutional layer class.\n\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    mid_channels    : int\n                      Number of middle channels.\n    kernel_size     : int\n                      Kernel size.\n    bias            : bool \n                      Set to True to let convolutional layers have bias term.\n    activation      : torch.nn\n                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n    \"\"\"\n    super().__init__()\n    self.activation = activation\n    self.convolution = double_convolution(\n                                          input_channels,\n                                          mid_channels = mid_channels,\n                                          output_channels = input_channels,\n                                          kernel_size = kernel_size,\n                                          bias = bias,\n                                          activation = activation\n                                         )\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.residual_layer.forward","title":"forward(x)","text":"

Forward model.

Parameters:

  • x \u2013
            Input data.\n

Returns:

  • result ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x             : torch.tensor\n                    Input data.\n\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Estimated output.      \n    \"\"\"\n    x0 = self.convolution(x)\n    return x + x0\n
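Example use: a minimal sketch, assuming `residual_layer` can be imported from `odak.learn.models`; the output keeps the input's channel count and spatial size, as required by the residual addition.
import torch\nfrom odak.learn.models import residual_layer # Assumed import path.\n\nlayer = residual_layer(input_channels = 2, mid_channels = 16, kernel_size = 3)\nx = torch.randn(1, 2, 64, 64)\ny = layer(x) # Residual output, same shape as x.\n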
"},{"location":"odak/learn_wave/#odak.learn.wave.models.spatial_gate","title":"spatial_gate","text":"

Bases: Module

Spatial attention module that applies a convolution layer after channel pooling. This class is heavily inspired by https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py.

Source code in odak/learn/models/components.py
class spatial_gate(torch.nn.Module):\n    \"\"\"\n    Spatial attention module that applies a convolution layer after channel pooling.\n    This class is heavily inspired by https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py.\n    \"\"\"\n    def __init__(self):\n        \"\"\"\n        Initializes the spatial gate module.\n        \"\"\"\n        super().__init__()\n        kernel_size = 7\n        self.spatial = convolution_layer(2, 1, kernel_size, bias = False, activation = torch.nn.Identity())\n\n\n    def channel_pool(self, x):\n        \"\"\"\n        Applies max and average pooling on the channels.\n\n        Parameters\n        ----------\n        x             : torch.tensor\n                        Input tensor.\n\n        Returns\n        -------\n        output        : torch.tensor\n                        Output tensor.\n        \"\"\"\n        max_pool = torch.max(x, 1)[0].unsqueeze(1)\n        avg_pool = torch.mean(x, 1).unsqueeze(1)\n        output = torch.cat((max_pool, avg_pool), dim=1)\n        return output\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward pass of the SpatialGate module.\n\n        Applies spatial attention to the input tensor.\n\n        Parameters\n        ----------\n        x            : torch.tensor\n                       Input tensor to the SpatialGate module.\n\n        Returns\n        -------\n        scaled_x     : torch.tensor\n                       Output tensor after applying spatial attention.\n        \"\"\"\n        x_compress = self.channel_pool(x)\n        x_out = self.spatial(x_compress)\n        scale = torch.sigmoid(x_out)\n        scaled_x = x * scale\n        return scaled_x\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.spatial_gate.__init__","title":"__init__()","text":"

Initializes the spatial gate module.

Source code in odak/learn/models/components.py
def __init__(self):\n    \"\"\"\n    Initializes the spatial gate module.\n    \"\"\"\n    super().__init__()\n    kernel_size = 7\n    self.spatial = convolution_layer(2, 1, kernel_size, bias = False, activation = torch.nn.Identity())\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.spatial_gate.channel_pool","title":"channel_pool(x)","text":"

Applies max and average pooling on the channels.

Parameters:

  • x \u2013
            Input tensor.\n

Returns:

  • output ( tensor ) \u2013

    Output tensor.

Source code in odak/learn/models/components.py
def channel_pool(self, x):\n    \"\"\"\n    Applies max and average pooling on the channels.\n\n    Parameters\n    ----------\n    x             : torch.tensor\n                    Input tensor.\n\n    Returns\n    -------\n    output        : torch.tensor\n                    Output tensor.\n    \"\"\"\n    max_pool = torch.max(x, 1)[0].unsqueeze(1)\n    avg_pool = torch.mean(x, 1).unsqueeze(1)\n    output = torch.cat((max_pool, avg_pool), dim=1)\n    return output\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.spatial_gate.forward","title":"forward(x)","text":"

Forward pass of the SpatialGate module.

Applies spatial attention to the input tensor.

Parameters:

  • x \u2013
           Input tensor to the SpatialGate module.\n

Returns:

  • scaled_x ( tensor ) \u2013

    Output tensor after applying spatial attention.

Source code in odak/learn/models/components.py
def forward(self, x):\n    \"\"\"\n    Forward pass of the SpatialGate module.\n\n    Applies spatial attention to the input tensor.\n\n    Parameters\n    ----------\n    x            : torch.tensor\n                   Input tensor to the SpatialGate module.\n\n    Returns\n    -------\n    scaled_x     : torch.tensor\n                   Output tensor after applying spatial attention.\n    \"\"\"\n    x_compress = self.channel_pool(x)\n    x_out = self.spatial(x_compress)\n    scale = torch.sigmoid(x_out)\n    scaled_x = x * scale\n    return scaled_x\n
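Example use: a minimal sketch, assuming `spatial_gate` can be imported from `odak.learn.models`; the module takes no constructor arguments and rescales the input with a single-channel spatial attention map.
import torch\nfrom odak.learn.models import spatial_gate # Assumed import path.\n\ngate = spatial_gate()\nx = torch.randn(1, 8, 64, 64)\ny = gate(x) # Spatially reweighted input, same shape as x.\n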
"},{"location":"odak/learn_wave/#odak.learn.wave.models.spatially_adaptive_convolution","title":"spatially_adaptive_convolution","text":"

Bases: Module

A spatially adaptive convolution layer.

References

C. Zheng et al. \"Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions.\"
C. Xu et al. \"Squeezesegv3: Spatially-adaptive Convolution for Efficient Point-Cloud Segmentation.\"
C. Zheng et al. \"Windowing Decomposition Convolutional Neural Network for Image Enhancement.\"

Source code in odak/learn/models/components.py
class spatially_adaptive_convolution(torch.nn.Module):\n    \"\"\"\n    A spatially adaptive convolution layer.\n\n    References\n    ----------\n\n    C. Zheng et al. \"Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions.\"\n    C. Xu et al. \"Squeezesegv3: Spatially-adaptive Convolution for Efficient Point-Cloud Segmentation.\"\n    C. Zheng et al. \"Windowing Decomposition Convolutional Neural Network for Image Enhancement.\"\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels = 2,\n                 output_channels = 2,\n                 kernel_size = 3,\n                 stride = 1,\n                 padding = 1,\n                 bias = False,\n                 activation = torch.nn.LeakyReLU(0.2, inplace = True)\n                ):\n        \"\"\"\n        Initializes a spatially adaptive convolution layer.\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        output_channels : int\n                          Number of output channels.\n        kernel_size     : int\n                          Size of the convolution kernel.\n        stride          : int\n                          Stride of the convolution.\n        padding         : int\n                          Padding added to both sides of the input.\n        bias            : bool\n                          If True, includes a bias term in the convolution.\n        activation      : torch.nn.Module\n                          Activation function to apply. If None, no activation is applied.\n        \"\"\"\n        super(spatially_adaptive_convolution, self).__init__()\n        self.kernel_size = kernel_size\n        self.input_channels = input_channels\n        self.output_channels = output_channels\n        self.stride = stride\n        self.padding = padding\n        self.standard_convolution = torch.nn.Conv2d(\n                                                    in_channels = input_channels,\n                                                    out_channels = self.output_channels,\n                                                    kernel_size = kernel_size,\n                                                    stride = stride,\n                                                    padding = padding,\n                                                    bias = bias\n                                                   )\n        self.weight = torch.nn.Parameter(data = self.standard_convolution.weight, requires_grad = True)\n        self.activation = activation\n\n\n    def forward(self, x, sv_kernel_feature):\n        \"\"\"\n        Forward pass for the spatially adaptive convolution layer.\n\n        Parameters\n        ----------\n        x                  : torch.tensor\n                            Input data tensor.\n                            Dimension: (1, C, H, W)\n        sv_kernel_feature   : torch.tensor\n                            Spatially varying kernel features.\n                            Dimension: (1, C_i * kernel_size * kernel_size, H, W)\n\n        Returns\n        -------\n        sa_output          : torch.tensor\n                            Estimated output tensor.\n                            Dimension: (1, output_channels, H_out, W_out)\n        \"\"\"\n        # Pad input and sv_kernel_feature if necessary\n        if sv_kernel_feature.size(-1) * self.stride != x.size(-1) or sv_kernel_feature.size(\n                -2) * self.stride != x.size(-2):\n    
        diffY = sv_kernel_feature.size(-2) % self.stride\n            diffX = sv_kernel_feature.size(-1) % self.stride\n            sv_kernel_feature = torch.nn.functional.pad(sv_kernel_feature, (diffX // 2, diffX - diffX // 2,\n                                                                            diffY // 2, diffY - diffY // 2))\n            diffY = x.size(-2) % self.stride\n            diffX = x.size(-1) % self.stride\n            x = torch.nn.functional.pad(x, (diffX // 2, diffX - diffX // 2,\n                                            diffY // 2, diffY - diffY // 2))\n\n        # Unfold the input tensor for matrix multiplication\n        input_feature = torch.nn.functional.unfold(\n                                                   x,\n                                                   kernel_size = (self.kernel_size, self.kernel_size),\n                                                   stride = self.stride,\n                                                   padding = self.padding\n                                                  )\n\n        # Resize sv_kernel_feature to match the input feature\n        sv_kernel = sv_kernel_feature.reshape(\n                                              1,\n                                              self.input_channels * self.kernel_size * self.kernel_size,\n                                              (x.size(-2) // self.stride) * (x.size(-1) // self.stride)\n                                             )\n\n        # Resize weight to match the input channels and kernel size\n        si_kernel = self.weight.reshape(\n                                        self.output_channels,\n                                        self.input_channels * self.kernel_size * self.kernel_size\n                                       )\n\n        # Apply spatially varying kernels\n        sv_feature = input_feature * sv_kernel\n\n        # Perform matrix multiplication\n        sa_output = torch.matmul(si_kernel, sv_feature).reshape(\n                                                                1, self.output_channels,\n                                                                (x.size(-2) // self.stride),\n                                                                (x.size(-1) // self.stride)\n                                                               )\n        return sa_output\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.spatially_adaptive_convolution.__init__","title":"__init__(input_channels=2, output_channels=2, kernel_size=3, stride=1, padding=1, bias=False, activation=torch.nn.LeakyReLU(0.2, inplace=True))","text":"

Initializes a spatially adaptive convolution layer.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • output_channels (int, default: 2 ) \u2013
              Number of output channels.\n
  • kernel_size \u2013
              Size of the convolution kernel.\n
  • stride \u2013
              Stride of the convolution.\n
  • padding \u2013
              Padding added to both sides of the input.\n
  • bias \u2013
              If True, includes a bias term in the convolution.\n
  • activation \u2013
              Activation function to apply. If None, no activation is applied.\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels = 2,\n             output_channels = 2,\n             kernel_size = 3,\n             stride = 1,\n             padding = 1,\n             bias = False,\n             activation = torch.nn.LeakyReLU(0.2, inplace = True)\n            ):\n    \"\"\"\n    Initializes a spatially adaptive convolution layer.\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    output_channels : int\n                      Number of output channels.\n    kernel_size     : int\n                      Size of the convolution kernel.\n    stride          : int\n                      Stride of the convolution.\n    padding         : int\n                      Padding added to both sides of the input.\n    bias            : bool\n                      If True, includes a bias term in the convolution.\n    activation      : torch.nn.Module\n                      Activation function to apply. If None, no activation is applied.\n    \"\"\"\n    super(spatially_adaptive_convolution, self).__init__()\n    self.kernel_size = kernel_size\n    self.input_channels = input_channels\n    self.output_channels = output_channels\n    self.stride = stride\n    self.padding = padding\n    self.standard_convolution = torch.nn.Conv2d(\n                                                in_channels = input_channels,\n                                                out_channels = self.output_channels,\n                                                kernel_size = kernel_size,\n                                                stride = stride,\n                                                padding = padding,\n                                                bias = bias\n                                               )\n    self.weight = torch.nn.Parameter(data = self.standard_convolution.weight, requires_grad = True)\n    self.activation = activation\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.spatially_adaptive_convolution.forward","title":"forward(x, sv_kernel_feature)","text":"

Forward pass for the spatially adaptive convolution layer.

Parameters:

  • x \u2013
                Input data tensor.\n            Dimension: (1, C, H, W)\n
  • sv_kernel_feature \u2013
                Spatially varying kernel features.\n            Dimension: (1, C_i * kernel_size * kernel_size, H, W)\n

Returns:

  • sa_output ( tensor ) \u2013

    Estimated output tensor. Dimension: (1, output_channels, H_out, W_out)

Source code in odak/learn/models/components.py
def forward(self, x, sv_kernel_feature):\n    \"\"\"\n    Forward pass for the spatially adaptive convolution layer.\n\n    Parameters\n    ----------\n    x                  : torch.tensor\n                        Input data tensor.\n                        Dimension: (1, C, H, W)\n    sv_kernel_feature   : torch.tensor\n                        Spatially varying kernel features.\n                        Dimension: (1, C_i * kernel_size * kernel_size, H, W)\n\n    Returns\n    -------\n    sa_output          : torch.tensor\n                        Estimated output tensor.\n                        Dimension: (1, output_channels, H_out, W_out)\n    \"\"\"\n    # Pad input and sv_kernel_feature if necessary\n    if sv_kernel_feature.size(-1) * self.stride != x.size(-1) or sv_kernel_feature.size(\n            -2) * self.stride != x.size(-2):\n        diffY = sv_kernel_feature.size(-2) % self.stride\n        diffX = sv_kernel_feature.size(-1) % self.stride\n        sv_kernel_feature = torch.nn.functional.pad(sv_kernel_feature, (diffX // 2, diffX - diffX // 2,\n                                                                        diffY // 2, diffY - diffY // 2))\n        diffY = x.size(-2) % self.stride\n        diffX = x.size(-1) % self.stride\n        x = torch.nn.functional.pad(x, (diffX // 2, diffX - diffX // 2,\n                                        diffY // 2, diffY - diffY // 2))\n\n    # Unfold the input tensor for matrix multiplication\n    input_feature = torch.nn.functional.unfold(\n                                               x,\n                                               kernel_size = (self.kernel_size, self.kernel_size),\n                                               stride = self.stride,\n                                               padding = self.padding\n                                              )\n\n    # Resize sv_kernel_feature to match the input feature\n    sv_kernel = sv_kernel_feature.reshape(\n                                          1,\n                                          self.input_channels * self.kernel_size * self.kernel_size,\n                                          (x.size(-2) // self.stride) * (x.size(-1) // self.stride)\n                                         )\n\n    # Resize weight to match the input channels and kernel size\n    si_kernel = self.weight.reshape(\n                                    self.output_channels,\n                                    self.input_channels * self.kernel_size * self.kernel_size\n                                   )\n\n    # Apply spatially varying kernels\n    sv_feature = input_feature * sv_kernel\n\n    # Perform matrix multiplication\n    sa_output = torch.matmul(si_kernel, sv_feature).reshape(\n                                                            1, self.output_channels,\n                                                            (x.size(-2) // self.stride),\n                                                            (x.size(-1) // self.stride)\n                                                           )\n    return sa_output\n
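Example use: a minimal sketch, assuming `spatially_adaptive_convolution` can be imported from `odak.learn.models`; with the default stride of 1, `sv_kernel_feature` needs `input_channels * kernel_size * kernel_size` channels and the same spatial size as `x`. In practice such kernel features would come from a learned network rather than random noise.
import torch\nfrom odak.learn.models import spatially_adaptive_convolution # Assumed import path.\n\nlayer = spatially_adaptive_convolution(input_channels = 2, output_channels = 2, kernel_size = 3)\nx = torch.randn(1, 2, 32, 32)\nsv_kernel_feature = torch.randn(1, 2 * 3 * 3, 32, 32) # (1, C_i * kernel_size * kernel_size, H, W)\ny = layer(x, sv_kernel_feature) # Estimated output with shape (1, 2, 32, 32).\n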
"},{"location":"odak/learn_wave/#odak.learn.wave.models.spatially_adaptive_module","title":"spatially_adaptive_module","text":"

Bases: Module

A spatially adaptive module that combines learned spatially adaptive convolutions.

References

Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Ak\u015fit, \"Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions,\" SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24), December, 2024.

Source code in odak/learn/models/components.py
class spatially_adaptive_module(torch.nn.Module):\n    \"\"\"\n    A spatially adaptive module that combines learned spatially adaptive convolutions.\n\n    References\n    ----------\n\n    Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Ak\u015fit, \"Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions,\" SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24), December, 2024.\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels = 2,\n                 output_channels = 2,\n                 kernel_size = 3,\n                 stride = 1,\n                 padding = 1,\n                 bias = False,\n                 activation = torch.nn.LeakyReLU(0.2, inplace = True)\n                ):\n        \"\"\"\n        Initializes a spatially adaptive module.\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        output_channels : int\n                          Number of output channels.\n        kernel_size     : int\n                          Size of the convolution kernel.\n        stride          : int\n                          Stride of the convolution.\n        padding         : int\n                          Padding added to both sides of the input.\n        bias            : bool\n                          If True, includes a bias term in the convolution.\n        activation      : torch.nn\n                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n        \"\"\"\n        super(spatially_adaptive_module, self).__init__()\n        self.kernel_size = kernel_size\n        self.input_channels = input_channels\n        self.output_channels = output_channels\n        self.stride = stride\n        self.padding = padding\n        self.weight_output_channels = self.output_channels - 1\n        self.standard_convolution = torch.nn.Conv2d(\n                                                    in_channels = input_channels,\n                                                    out_channels = self.weight_output_channels,\n                                                    kernel_size = kernel_size,\n                                                    stride = stride,\n                                                    padding = padding,\n                                                    bias = bias\n                                                   )\n        self.weight = torch.nn.Parameter(data = self.standard_convolution.weight, requires_grad = True)\n        self.activation = activation\n\n\n    def forward(self, x, sv_kernel_feature):\n        \"\"\"\n        Forward pass for the spatially adaptive module.\n\n        Parameters\n        ----------\n        x                  : torch.tensor\n                            Input data tensor.\n                            Dimension: (1, C, H, W)\n        sv_kernel_feature   : torch.tensor\n                            Spatially varying kernel features.\n                            Dimension: (1, C_i * kernel_size * kernel_size, H, W)\n\n        Returns\n        -------\n        output             : torch.tensor\n                            Combined output tensor from standard and spatially adaptive convolutions.\n                            Dimension: (1, output_channels, H_out, W_out)\n        \"\"\"\n        # Pad input and sv_kernel_feature if necessary\n        if sv_kernel_feature.size(-1) * self.stride != x.size(-1) or 
sv_kernel_feature.size(\n                -2) * self.stride != x.size(-2):\n            diffY = sv_kernel_feature.size(-2) % self.stride\n            diffX = sv_kernel_feature.size(-1) % self.stride\n            sv_kernel_feature = torch.nn.functional.pad(sv_kernel_feature, (diffX // 2, diffX - diffX // 2,\n                                                                            diffY // 2, diffY - diffY // 2))\n            diffY = x.size(-2) % self.stride\n            diffX = x.size(-1) % self.stride\n            x = torch.nn.functional.pad(x, (diffX // 2, diffX - diffX // 2,\n                                            diffY // 2, diffY - diffY // 2))\n\n        # Unfold the input tensor for matrix multiplication\n        input_feature = torch.nn.functional.unfold(\n                                                   x,\n                                                   kernel_size = (self.kernel_size, self.kernel_size),\n                                                   stride = self.stride,\n                                                   padding = self.padding\n                                                  )\n\n        # Resize sv_kernel_feature to match the input feature\n        sv_kernel = sv_kernel_feature.reshape(\n                                              1,\n                                              self.input_channels * self.kernel_size * self.kernel_size,\n                                              (x.size(-2) // self.stride) * (x.size(-1) // self.stride)\n                                             )\n\n        # Apply sv_kernel to the input_feature\n        sv_feature = input_feature * sv_kernel\n\n        # Original spatially varying convolution output\n        sv_output = torch.sum(sv_feature, dim = 1).reshape(\n                                                           1,\n                                                            1,\n                                                            (x.size(-2) // self.stride),\n                                                            (x.size(-1) // self.stride)\n                                                           )\n\n        # Reshape weight for spatially adaptive convolution\n        si_kernel = self.weight.reshape(\n                                        self.weight_output_channels,\n                                        self.input_channels * self.kernel_size * self.kernel_size\n                                       )\n\n        # Apply si_kernel on sv convolution output\n        sa_output = torch.matmul(si_kernel, sv_feature).reshape(\n                                                                1, self.weight_output_channels,\n                                                                (x.size(-2) // self.stride),\n                                                                (x.size(-1) // self.stride)\n                                                               )\n\n        # Combine the outputs and apply activation function\n        output = self.activation(torch.cat((sv_output, sa_output), dim = 1))\n        return output\n
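Example use: a minimal sketch, assuming `spatially_adaptive_module` can be imported from `odak.learn.models`; the output concatenates one spatially varying channel with `output_channels - 1` spatially adaptive channels, giving `output_channels` channels in total.
import torch\nfrom odak.learn.models import spatially_adaptive_module # Assumed import path.\n\nmodule = spatially_adaptive_module(input_channels = 2, output_channels = 2, kernel_size = 3)\nx = torch.randn(1, 2, 32, 32)\nsv_kernel_feature = torch.randn(1, 2 * 3 * 3, 32, 32) # (1, C_i * kernel_size * kernel_size, H, W)\ny = module(x, sv_kernel_feature) # Estimated output with shape (1, 2, 32, 32).\n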
"},{"location":"odak/learn_wave/#odak.learn.wave.models.spatially_adaptive_module.__init__","title":"__init__(input_channels=2, output_channels=2, kernel_size=3, stride=1, padding=1, bias=False, activation=torch.nn.LeakyReLU(0.2, inplace=True))","text":"

Initializes a spatially adaptive module.

Parameters:

  • input_channels (int, default: 2) –
              Number of input channels.
  • output_channels (int, default: 2) –
              Number of output channels.
  • kernel_size (int, default: 3) –
              Size of the convolution kernel.
  • stride (int, default: 1) –
              Stride of the convolution.
  • padding (int, default: 1) –
              Padding added to both sides of the input.
  • bias (bool, default: False) –
              If True, includes a bias term in the convolution.
  • activation (torch.nn.Module, default: LeakyReLU(0.2, inplace=True)) –
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels = 2,\n             output_channels = 2,\n             kernel_size = 3,\n             stride = 1,\n             padding = 1,\n             bias = False,\n             activation = torch.nn.LeakyReLU(0.2, inplace = True)\n            ):\n    \"\"\"\n    Initializes a spatially adaptive module.\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    output_channels : int\n                      Number of output channels.\n    kernel_size     : int\n                      Size of the convolution kernel.\n    stride          : int\n                      Stride of the convolution.\n    padding         : int\n                      Padding added to both sides of the input.\n    bias            : bool\n                      If True, includes a bias term in the convolution.\n    activation      : torch.nn\n                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n    \"\"\"\n    super(spatially_adaptive_module, self).__init__()\n    self.kernel_size = kernel_size\n    self.input_channels = input_channels\n    self.output_channels = output_channels\n    self.stride = stride\n    self.padding = padding\n    self.weight_output_channels = self.output_channels - 1\n    self.standard_convolution = torch.nn.Conv2d(\n                                                in_channels = input_channels,\n                                                out_channels = self.weight_output_channels,\n                                                kernel_size = kernel_size,\n                                                stride = stride,\n                                                padding = padding,\n                                                bias = bias\n                                               )\n    self.weight = torch.nn.Parameter(data = self.standard_convolution.weight, requires_grad = True)\n    self.activation = activation\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.spatially_adaptive_module.forward","title":"forward(x, sv_kernel_feature)","text":"

Forward pass for the spatially adaptive module.

Parameters:

  • x (tensor) –
              Input data tensor. Dimension: (1, C, H, W)
  • sv_kernel_feature (tensor) –
              Spatially varying kernel features. Dimension: (1, C_i * kernel_size * kernel_size, H, W)

Returns:

  • output (tensor) –

    Combined output tensor from standard and spatially adaptive convolutions. Dimension: (1, output_channels, H_out, W_out)

Source code in odak/learn/models/components.py
def forward(self, x, sv_kernel_feature):\n    \"\"\"\n    Forward pass for the spatially adaptive module.\n\n    Parameters\n    ----------\n    x                  : torch.tensor\n                        Input data tensor.\n                        Dimension: (1, C, H, W)\n    sv_kernel_feature   : torch.tensor\n                        Spatially varying kernel features.\n                        Dimension: (1, C_i * kernel_size * kernel_size, H, W)\n\n    Returns\n    -------\n    output             : torch.tensor\n                        Combined output tensor from standard and spatially adaptive convolutions.\n                        Dimension: (1, output_channels, H_out, W_out)\n    \"\"\"\n    # Pad input and sv_kernel_feature if necessary\n    if sv_kernel_feature.size(-1) * self.stride != x.size(-1) or sv_kernel_feature.size(\n            -2) * self.stride != x.size(-2):\n        diffY = sv_kernel_feature.size(-2) % self.stride\n        diffX = sv_kernel_feature.size(-1) % self.stride\n        sv_kernel_feature = torch.nn.functional.pad(sv_kernel_feature, (diffX // 2, diffX - diffX // 2,\n                                                                        diffY // 2, diffY - diffY // 2))\n        diffY = x.size(-2) % self.stride\n        diffX = x.size(-1) % self.stride\n        x = torch.nn.functional.pad(x, (diffX // 2, diffX - diffX // 2,\n                                        diffY // 2, diffY - diffY // 2))\n\n    # Unfold the input tensor for matrix multiplication\n    input_feature = torch.nn.functional.unfold(\n                                               x,\n                                               kernel_size = (self.kernel_size, self.kernel_size),\n                                               stride = self.stride,\n                                               padding = self.padding\n                                              )\n\n    # Resize sv_kernel_feature to match the input feature\n    sv_kernel = sv_kernel_feature.reshape(\n                                          1,\n                                          self.input_channels * self.kernel_size * self.kernel_size,\n                                          (x.size(-2) // self.stride) * (x.size(-1) // self.stride)\n                                         )\n\n    # Apply sv_kernel to the input_feature\n    sv_feature = input_feature * sv_kernel\n\n    # Original spatially varying convolution output\n    sv_output = torch.sum(sv_feature, dim = 1).reshape(\n                                                       1,\n                                                        1,\n                                                        (x.size(-2) // self.stride),\n                                                        (x.size(-1) // self.stride)\n                                                       )\n\n    # Reshape weight for spatially adaptive convolution\n    si_kernel = self.weight.reshape(\n                                    self.weight_output_channels,\n                                    self.input_channels * self.kernel_size * self.kernel_size\n                                   )\n\n    # Apply si_kernel on sv convolution output\n    sa_output = torch.matmul(si_kernel, sv_feature).reshape(\n                                                            1, self.weight_output_channels,\n                                                            (x.size(-2) // self.stride),\n                                                            (x.size(-1) // self.stride)\n                               
                            )\n\n    # Combine the outputs and apply activation function\n    output = self.activation(torch.cat((sv_output, sa_output), dim = 1))\n    return output\n
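For intuition, the following hedged sketch reproduces the core idea of the forward pass above in plain PyTorch for a single channel: unfold gathers every kernel-sized neighbourhood, and an elementwise product with per-pixel kernels followed by a sum yields a convolution whose weights vary across the image. It is a simplified illustration, not a replacement for the module:

```python
import torch

# Single channel, one 3 x 3 kernel per pixel (stride 1, padding 1).
H, W, k = 8, 8, 3
x = torch.randn(1, 1, H, W)
per_pixel_kernels = torch.randn(1, 1 * k * k, H, W)

# Gather every k x k neighbourhood into a column: (1, C * k * k, H * W).
patches = torch.nn.functional.unfold(x, kernel_size = (k, k), padding = k // 2)

# Multiply each neighbourhood by its own kernel and sum over the kernel axis:
# a convolution whose weights change from pixel to pixel.
kernels = per_pixel_kernels.reshape(1, 1 * k * k, H * W)
sv_output = (patches * kernels).sum(dim = 1).reshape(1, 1, H, W)
```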
"},{"location":"odak/learn_wave/#odak.learn.wave.models.spatially_adaptive_unet","title":"spatially_adaptive_unet","text":"

Bases: Module

Spatially varying U-Net model based on spatially adaptive convolution.

References

Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Akşit, "Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions," SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24), December 2024.

Source code in odak/learn/models/models.py
class spatially_adaptive_unet(torch.nn.Module):\n    \"\"\"\n    Spatially varying U-Net model based on spatially adaptive convolution.\n\n    References\n    ----------\n\n    Chuanjun Zheng, Yicheng Zhan, Liang Shi, Ozan Cakmakci, and Kaan Ak\u015fit, \"Focal Surface Holographic Light Transport using Learned Spatially Adaptive Convolutions,\" SIGGRAPH Asia 2024 Technical Communications (SA Technical Communications '24), December, 2024.\n    \"\"\"\n    def __init__(\n                 self,\n                 depth=3,\n                 dimensions=8,\n                 input_channels=6,\n                 out_channels=6,\n                 kernel_size=3,\n                 bias=True,\n                 normalization=False,\n                 activation=torch.nn.LeakyReLU(0.2, inplace=True)\n                ):\n        \"\"\"\n        U-Net model.\n\n        Parameters\n        ----------\n        depth          : int\n                         Number of upsampling and downsampling layers.\n        dimensions     : int\n                         Number of dimensions.\n        input_channels : int\n                         Number of input channels.\n        out_channels   : int\n                         Number of output channels.\n        bias           : bool\n                         Set to True to let convolutional layers learn a bias term.\n        normalization  : bool\n                         If True, adds a Batch Normalization layer after the convolutional layer.\n        activation     : torch.nn\n                         Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).\n        \"\"\"\n        super().__init__()\n        self.depth = depth\n        self.out_channels = out_channels\n        self.inc = convolution_layer(\n                                     input_channels=input_channels,\n                                     output_channels=dimensions,\n                                     kernel_size=kernel_size,\n                                     bias=bias,\n                                     normalization=normalization,\n                                     activation=activation\n                                    )\n\n        self.encoder = torch.nn.ModuleList()\n        for i in range(self.depth + 1):  # Downsampling layers\n            down_in_channels = dimensions * (2 ** i)\n            down_out_channels = 2 * down_in_channels\n            pooling_layer = torch.nn.AvgPool2d(2)\n            double_convolution_layer = double_convolution(\n                                                          input_channels=down_in_channels,\n                                                          mid_channels=down_in_channels,\n                                                          output_channels=down_in_channels,\n                                                          kernel_size=kernel_size,\n                                                          bias=bias,\n                                                          normalization=normalization,\n                                                          activation=activation\n                                                         )\n            sam = spatially_adaptive_module(\n                                            input_channels=down_in_channels,\n                                            output_channels=down_out_channels,\n                                            kernel_size=kernel_size,\n                                            bias=bias,\n                                            
activation=activation\n                                           )\n            self.encoder.append(torch.nn.ModuleList([pooling_layer, double_convolution_layer, sam]))\n        self.global_feature_module = torch.nn.ModuleList()\n        double_convolution_layer = double_convolution(\n                                                      input_channels=dimensions * (2 ** (depth + 1)),\n                                                      mid_channels=dimensions * (2 ** (depth + 1)),\n                                                      output_channels=dimensions * (2 ** (depth + 1)),\n                                                      kernel_size=kernel_size,\n                                                      bias=bias,\n                                                      normalization=normalization,\n                                                      activation=activation\n                                                     )\n        global_feature_layer = global_feature_module(\n                                                     input_channels=dimensions * (2 ** (depth + 1)),\n                                                     mid_channels=dimensions * (2 ** (depth + 1)),\n                                                     output_channels=dimensions * (2 ** (depth + 1)),\n                                                     kernel_size=kernel_size,\n                                                     bias=bias,\n                                                     activation=torch.nn.LeakyReLU(0.2, inplace=True)\n                                                    )\n        self.global_feature_module.append(torch.nn.ModuleList([double_convolution_layer, global_feature_layer]))\n        self.decoder = torch.nn.ModuleList()\n        for i in range(depth, -1, -1):\n            up_in_channels = dimensions * (2 ** (i + 1))\n            up_mid_channels = up_in_channels // 2\n            if i == 0:\n                up_out_channels = self.out_channels\n                upsample_layer = upsample_convtranspose2d_layer(\n                                                                input_channels=up_in_channels,\n                                                                output_channels=up_mid_channels,\n                                                                kernel_size=2,\n                                                                stride=2,\n                                                                bias=bias,\n                                                               )\n                conv_layer = torch.nn.Sequential(\n                    convolution_layer(\n                                      input_channels=up_mid_channels,\n                                      output_channels=up_mid_channels,\n                                      kernel_size=kernel_size,\n                                      bias=bias,\n                                      normalization=normalization,\n                                      activation=activation,\n                                     ),\n                    convolution_layer(\n                                      input_channels=up_mid_channels,\n                                      output_channels=up_out_channels,\n                                      kernel_size=1,\n                                      bias=bias,\n                                      normalization=normalization,\n                                      activation=None,\n                                     )\n                )\n   
             self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))\n            else:\n                up_out_channels = up_in_channels // 2\n                upsample_layer = upsample_convtranspose2d_layer(\n                                                                input_channels=up_in_channels,\n                                                                output_channels=up_mid_channels,\n                                                                kernel_size=2,\n                                                                stride=2,\n                                                                bias=bias,\n                                                               )\n                conv_layer = double_convolution(\n                                                input_channels=up_mid_channels,\n                                                mid_channels=up_mid_channels,\n                                                output_channels=up_out_channels,\n                                                kernel_size=kernel_size,\n                                                bias=bias,\n                                                normalization=normalization,\n                                                activation=activation,\n                                               )\n                self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))\n\n\n    def forward(self, sv_kernel, field):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        sv_kernel : list of torch.tensor\n                    Learned spatially varying kernels.\n                    Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i),\n                    where C_i, H_i, and W_i represent the channel, height, and width\n                    of each feature at a certain scale.\n\n        field     : torch.tensor\n                    Input field data.\n                    Dimension: (1, 6, H, W)\n\n        Returns\n        -------\n        target_field : torch.tensor\n                       Estimated output.\n                       Dimension: (1, 6, H, W)\n        \"\"\"\n        x = self.inc(field)\n        downsampling_outputs = [x]\n        for i, down_layer in enumerate(self.encoder):\n            x_down = down_layer[0](downsampling_outputs[-1])\n            downsampling_outputs.append(x_down)\n            sam_output = down_layer[2](x_down + down_layer[1](x_down), sv_kernel[self.depth - i])\n            downsampling_outputs.append(sam_output)\n        global_feature = self.global_feature_module[0][0](downsampling_outputs[-1])\n        global_feature = self.global_feature_module[0][1](downsampling_outputs[-1], global_feature)\n        downsampling_outputs.append(global_feature)\n        x_up = downsampling_outputs[-1]\n        for i, up_layer in enumerate(self.decoder):\n            x_up = up_layer[0](x_up, downsampling_outputs[2 * (self.depth - i)])\n            x_up = up_layer[1](x_up)\n        result = x_up\n        return result\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.spatially_adaptive_unet.__init__","title":"__init__(depth=3, dimensions=8, input_channels=6, out_channels=6, kernel_size=3, bias=True, normalization=False, activation=torch.nn.LeakyReLU(0.2, inplace=True))","text":"

U-Net model.

Parameters:

  • depth (int, default: 3) –
             Number of upsampling and downsampling layers.
  • dimensions (int, default: 8) –
             Number of base feature channels (doubled at each downsampling level).
  • input_channels (int, default: 6) –
             Number of input channels.
  • out_channels (int, default: 6) –
             Number of output channels.
  • kernel_size (int, default: 3) –
             Size of the convolution kernel.
  • bias (bool, default: True) –
             Set to True to let convolutional layers learn a bias term.
  • normalization (bool, default: False) –
             If True, adds a Batch Normalization layer after the convolutional layer.
  • activation (torch.nn.Module, default: LeakyReLU(0.2, inplace=True)) –
             Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
Source code in odak/learn/models/models.py
def __init__(\n             self,\n             depth=3,\n             dimensions=8,\n             input_channels=6,\n             out_channels=6,\n             kernel_size=3,\n             bias=True,\n             normalization=False,\n             activation=torch.nn.LeakyReLU(0.2, inplace=True)\n            ):\n    \"\"\"\n    U-Net model.\n\n    Parameters\n    ----------\n    depth          : int\n                     Number of upsampling and downsampling layers.\n    dimensions     : int\n                     Number of dimensions.\n    input_channels : int\n                     Number of input channels.\n    out_channels   : int\n                     Number of output channels.\n    bias           : bool\n                     Set to True to let convolutional layers learn a bias term.\n    normalization  : bool\n                     If True, adds a Batch Normalization layer after the convolutional layer.\n    activation     : torch.nn\n                     Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).\n    \"\"\"\n    super().__init__()\n    self.depth = depth\n    self.out_channels = out_channels\n    self.inc = convolution_layer(\n                                 input_channels=input_channels,\n                                 output_channels=dimensions,\n                                 kernel_size=kernel_size,\n                                 bias=bias,\n                                 normalization=normalization,\n                                 activation=activation\n                                )\n\n    self.encoder = torch.nn.ModuleList()\n    for i in range(self.depth + 1):  # Downsampling layers\n        down_in_channels = dimensions * (2 ** i)\n        down_out_channels = 2 * down_in_channels\n        pooling_layer = torch.nn.AvgPool2d(2)\n        double_convolution_layer = double_convolution(\n                                                      input_channels=down_in_channels,\n                                                      mid_channels=down_in_channels,\n                                                      output_channels=down_in_channels,\n                                                      kernel_size=kernel_size,\n                                                      bias=bias,\n                                                      normalization=normalization,\n                                                      activation=activation\n                                                     )\n        sam = spatially_adaptive_module(\n                                        input_channels=down_in_channels,\n                                        output_channels=down_out_channels,\n                                        kernel_size=kernel_size,\n                                        bias=bias,\n                                        activation=activation\n                                       )\n        self.encoder.append(torch.nn.ModuleList([pooling_layer, double_convolution_layer, sam]))\n    self.global_feature_module = torch.nn.ModuleList()\n    double_convolution_layer = double_convolution(\n                                                  input_channels=dimensions * (2 ** (depth + 1)),\n                                                  mid_channels=dimensions * (2 ** (depth + 1)),\n                                                  output_channels=dimensions * (2 ** (depth + 1)),\n                                                  kernel_size=kernel_size,\n                                                  
bias=bias,\n                                                  normalization=normalization,\n                                                  activation=activation\n                                                 )\n    global_feature_layer = global_feature_module(\n                                                 input_channels=dimensions * (2 ** (depth + 1)),\n                                                 mid_channels=dimensions * (2 ** (depth + 1)),\n                                                 output_channels=dimensions * (2 ** (depth + 1)),\n                                                 kernel_size=kernel_size,\n                                                 bias=bias,\n                                                 activation=torch.nn.LeakyReLU(0.2, inplace=True)\n                                                )\n    self.global_feature_module.append(torch.nn.ModuleList([double_convolution_layer, global_feature_layer]))\n    self.decoder = torch.nn.ModuleList()\n    for i in range(depth, -1, -1):\n        up_in_channels = dimensions * (2 ** (i + 1))\n        up_mid_channels = up_in_channels // 2\n        if i == 0:\n            up_out_channels = self.out_channels\n            upsample_layer = upsample_convtranspose2d_layer(\n                                                            input_channels=up_in_channels,\n                                                            output_channels=up_mid_channels,\n                                                            kernel_size=2,\n                                                            stride=2,\n                                                            bias=bias,\n                                                           )\n            conv_layer = torch.nn.Sequential(\n                convolution_layer(\n                                  input_channels=up_mid_channels,\n                                  output_channels=up_mid_channels,\n                                  kernel_size=kernel_size,\n                                  bias=bias,\n                                  normalization=normalization,\n                                  activation=activation,\n                                 ),\n                convolution_layer(\n                                  input_channels=up_mid_channels,\n                                  output_channels=up_out_channels,\n                                  kernel_size=1,\n                                  bias=bias,\n                                  normalization=normalization,\n                                  activation=None,\n                                 )\n            )\n            self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))\n        else:\n            up_out_channels = up_in_channels // 2\n            upsample_layer = upsample_convtranspose2d_layer(\n                                                            input_channels=up_in_channels,\n                                                            output_channels=up_mid_channels,\n                                                            kernel_size=2,\n                                                            stride=2,\n                                                            bias=bias,\n                                                           )\n            conv_layer = double_convolution(\n                                            input_channels=up_mid_channels,\n                                            mid_channels=up_mid_channels,\n                           
                 output_channels=up_out_channels,\n                                            kernel_size=kernel_size,\n                                            bias=bias,\n                                            normalization=normalization,\n                                            activation=activation,\n                                           )\n            self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.spatially_adaptive_unet.forward","title":"forward(sv_kernel, field)","text":"

Forward model.

Parameters:

  • sv_kernel (list of torch.tensor) –
        Learned spatially varying kernels. Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i), where C_i, H_i, and W_i represent the channel, height, and width of each feature at a certain scale.
  • field (tensor) –
        Input field data. Dimension: (1, 6, H, W)

Returns:

  • target_field (tensor) –

    Estimated output. Dimension: (1, 6, H, W)

Source code in odak/learn/models/models.py
def forward(self, sv_kernel, field):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    sv_kernel : list of torch.tensor\n                Learned spatially varying kernels.\n                Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i),\n                where C_i, H_i, and W_i represent the channel, height, and width\n                of each feature at a certain scale.\n\n    field     : torch.tensor\n                Input field data.\n                Dimension: (1, 6, H, W)\n\n    Returns\n    -------\n    target_field : torch.tensor\n                   Estimated output.\n                   Dimension: (1, 6, H, W)\n    \"\"\"\n    x = self.inc(field)\n    downsampling_outputs = [x]\n    for i, down_layer in enumerate(self.encoder):\n        x_down = down_layer[0](downsampling_outputs[-1])\n        downsampling_outputs.append(x_down)\n        sam_output = down_layer[2](x_down + down_layer[1](x_down), sv_kernel[self.depth - i])\n        downsampling_outputs.append(sam_output)\n    global_feature = self.global_feature_module[0][0](downsampling_outputs[-1])\n    global_feature = self.global_feature_module[0][1](downsampling_outputs[-1], global_feature)\n    downsampling_outputs.append(global_feature)\n    x_up = downsampling_outputs[-1]\n    for i, up_layer in enumerate(self.decoder):\n        x_up = up_layer[0](x_up, downsampling_outputs[2 * (self.depth - i)])\n        x_up = up_layer[1](x_up)\n    result = x_up\n    return result\n
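As a hedged end-to-end sketch, the snippet below drives this forward pass with kernels produced by the spatially_varying_kernel_generation_model documented further down this page, using the default configurations of both models and dummy inputs; the import path is assumed from the source listing (odak/learn/models/models.py), and H and W are assumed to be divisible by 2**(depth + 1):

```python
import torch
from odak.learn.models.models import spatially_adaptive_unet, spatially_varying_kernel_generation_model

focal_surface = torch.randn(1, 1, 256, 256)           # (1, 1, H, W)
field = torch.randn(1, 6, 256, 256)                   # (1, 6, H, W)

# Generate one spatially varying kernel tensor per scale from the focal surface and the field.
kernel_generator = spatially_varying_kernel_generation_model()
sv_kernel = kernel_generator(focal_surface, field)

# Propagate the field with the spatially adaptive U-Net conditioned on those kernels.
propagation_network = spatially_adaptive_unet()
target_field = propagation_network(sv_kernel, field)  # (1, 6, H, W)
```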
"},{"location":"odak/learn_wave/#odak.learn.wave.models.spatially_varying_kernel_generation_model","title":"spatially_varying_kernel_generation_model","text":"

Bases: Module

A spatially varying kernel generation model, revised from RSGUnet: https://github.com/MTLab/rsgunet_image_enhance.

Refer to: J. Huang, P. Zhu, M. Geng et al., "Range Scaling Global U-Net for Perceptual Image Enhancement on Mobile Devices."

Source code in odak/learn/models/models.py
class spatially_varying_kernel_generation_model(torch.nn.Module):\n    \"\"\"\n    Spatially_varying_kernel_generation_model revised from RSGUnet:\n    https://github.com/MTLab/rsgunet_image_enhance.\n\n    Refer to:\n    J. Huang, P. Zhu, M. Geng et al. Range Scaling Global U-Net for Perceptual Image Enhancement on Mobile Devices.\n    \"\"\"\n\n    def __init__(\n                 self,\n                 depth = 3,\n                 dimensions = 8,\n                 input_channels = 7,\n                 kernel_size = 3,\n                 bias = True,\n                 normalization = False,\n                 activation = torch.nn.LeakyReLU(0.2, inplace = True)\n                ):\n        \"\"\"\n        U-Net model.\n\n        Parameters\n        ----------\n        depth          : int\n                         Number of upsampling and downsampling layers.\n        dimensions     : int\n                         Number of dimensions.\n        input_channels : int\n                         Number of input channels.\n        bias           : bool\n                         Set to True to let convolutional layers learn a bias term.\n        normalization  : bool\n                         If True, adds a Batch Normalization layer after the convolutional layer.\n        activation     : torch.nn\n                         Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).\n        \"\"\"\n        super().__init__()\n        self.depth = depth\n        self.inc = convolution_layer(\n                                     input_channels = input_channels,\n                                     output_channels = dimensions,\n                                     kernel_size = kernel_size,\n                                     bias = bias,\n                                     normalization = normalization,\n                                     activation = activation\n                                    )\n        self.encoder = torch.nn.ModuleList()\n        for i in range(depth + 1):  # downsampling layers\n            if i == 0:\n                in_channels = dimensions * (2 ** i)\n                out_channels = dimensions * (2 ** i)\n            elif i == depth:\n                in_channels = dimensions * (2 ** (i - 1))\n                out_channels = dimensions * (2 ** (i - 1))\n            else:\n                in_channels = dimensions * (2 ** (i - 1))\n                out_channels = 2 * in_channels\n            pooling_layer = torch.nn.AvgPool2d(2)\n            double_convolution_layer = double_convolution(\n                                                          input_channels = in_channels,\n                                                          mid_channels = in_channels,\n                                                          output_channels = out_channels,\n                                                          kernel_size = kernel_size,\n                                                          bias = bias,\n                                                          normalization = normalization,\n                                                          activation = activation\n                                                         )\n            self.encoder.append(pooling_layer)\n            self.encoder.append(double_convolution_layer)\n        self.spatially_varying_feature = torch.nn.ModuleList()  # for kernel generation\n        for i in range(depth, -1, -1):\n            if i == 1:\n                svf_in_channels = dimensions + 2 ** (self.depth + 
i) + 1\n            else:\n                svf_in_channels = 2 ** (self.depth + i) + 1\n            svf_out_channels = (2 ** (self.depth + i)) * (kernel_size * kernel_size)\n            svf_mid_channels = dimensions * (2 ** (self.depth - 1))\n            spatially_varying_kernel_generation = torch.nn.ModuleList()\n            for j in range(i, -1, -1):\n                pooling_layer = torch.nn.AvgPool2d(2 ** (j + 1))\n                spatially_varying_kernel_generation.append(pooling_layer)\n            kernel_generation_block = torch.nn.Sequential(\n                torch.nn.Conv2d(\n                                in_channels = svf_in_channels,\n                                out_channels = svf_mid_channels,\n                                kernel_size = kernel_size,\n                                padding = kernel_size // 2,\n                                bias = bias\n                               ),\n                activation,\n                torch.nn.Conv2d(\n                                in_channels = svf_mid_channels,\n                                out_channels = svf_mid_channels,\n                                kernel_size = kernel_size,\n                                padding = kernel_size // 2,\n                                bias = bias\n                               ),\n                activation,\n                torch.nn.Conv2d(\n                                in_channels = svf_mid_channels,\n                                out_channels = svf_out_channels,\n                                kernel_size = kernel_size,\n                                padding = kernel_size // 2,\n                                bias = bias\n                               ),\n            )\n            spatially_varying_kernel_generation.append(kernel_generation_block)\n            self.spatially_varying_feature.append(spatially_varying_kernel_generation)\n        self.decoder = torch.nn.ModuleList()\n        global_feature_layer = global_feature_module(  # global feature layer\n                                                     input_channels = dimensions * (2 ** (depth - 1)),\n                                                     mid_channels = dimensions * (2 ** (depth - 1)),\n                                                     output_channels = dimensions * (2 ** (depth - 1)),\n                                                     kernel_size = kernel_size,\n                                                     bias = bias,\n                                                     activation = torch.nn.LeakyReLU(0.2, inplace = True)\n                                                    )\n        self.decoder.append(global_feature_layer)\n        for i in range(depth, 0, -1):\n            if i == 2:\n                up_in_channels = (dimensions // 2) * (2 ** i)\n                up_out_channels = up_in_channels\n                up_mid_channels = up_in_channels\n            elif i == 1:\n                up_in_channels = dimensions * 2\n                up_out_channels = dimensions\n                up_mid_channels = up_out_channels\n            else:\n                up_in_channels = (dimensions // 2) * (2 ** i)\n                up_out_channels = up_in_channels // 2\n                up_mid_channels = up_in_channels\n            upsample_layer = upsample_convtranspose2d_layer(\n                                                            input_channels = up_in_channels,\n                                                            output_channels = up_mid_channels,\n                         
                                   kernel_size = 2,\n                                                            stride = 2,\n                                                            bias = bias,\n                                                           )\n            conv_layer = double_convolution(\n                                            input_channels = up_mid_channels,\n                                            output_channels = up_out_channels,\n                                            kernel_size = kernel_size,\n                                            bias = bias,\n                                            normalization = normalization,\n                                            activation = activation,\n                                           )\n            self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))\n\n\n    def forward(self, focal_surface, field):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        focal_surface : torch.tensor\n                        Input focal surface data.\n                        Dimension: (1, 1, H, W)\n\n        field         : torch.tensor\n                        Input field data.\n                        Dimension: (1, 6, H, W)\n\n        Returns\n        -------\n        sv_kernel : list of torch.tensor\n                    Learned spatially varying kernels.\n                    Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i),\n                    where C_i, H_i, and W_i represent the channel, height, and width\n                    of each feature at a certain scale.\n        \"\"\"\n        x = self.inc(torch.cat((focal_surface, field), dim = 1))\n        downsampling_outputs = [focal_surface]\n        downsampling_outputs.append(x)\n        for i, down_layer in enumerate(self.encoder):\n            x_down = down_layer(downsampling_outputs[-1])\n            downsampling_outputs.append(x_down)\n        sv_kernels = []\n        for i, (up_layer, svf_layer) in enumerate(zip(self.decoder, self.spatially_varying_feature)):\n            if i == 0:\n                global_feature = up_layer(downsampling_outputs[-2], downsampling_outputs[-1])\n                downsampling_outputs[-1] = global_feature\n                sv_feature = [global_feature, downsampling_outputs[0]]\n                for j in range(self.depth - i + 1):\n                    sv_feature[1] = svf_layer[self.depth - i](sv_feature[1])\n                    if j > 0:\n                        sv_feature.append(svf_layer[j](downsampling_outputs[2 * j]))\n                sv_feature = [sv_feature[0], sv_feature[1], sv_feature[4], sv_feature[2],\n                              sv_feature[3]]\n                sv_kernel = svf_layer[-1](torch.cat(sv_feature, dim = 1))\n                sv_kernels.append(sv_kernel)\n            else:\n                x_up = up_layer[0](downsampling_outputs[-1],\n                                   downsampling_outputs[2 * (self.depth + 1 - i) + 1])\n                x_up = up_layer[1](x_up)\n                downsampling_outputs[-1] = x_up\n                sv_feature = [x_up, downsampling_outputs[0]]\n                for j in range(self.depth - i + 1):\n                    sv_feature[1] = svf_layer[self.depth - i](sv_feature[1])\n                    if j > 0:\n                        sv_feature.append(svf_layer[j](downsampling_outputs[2 * j]))\n                if i == 1:\n                    sv_feature = [sv_feature[0], sv_feature[1], sv_feature[3], 
sv_feature[2]]\n                sv_kernel = svf_layer[-1](torch.cat(sv_feature, dim = 1))\n                sv_kernels.append(sv_kernel)\n        return sv_kernels\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.spatially_varying_kernel_generation_model.__init__","title":"__init__(depth=3, dimensions=8, input_channels=7, kernel_size=3, bias=True, normalization=False, activation=torch.nn.LeakyReLU(0.2, inplace=True))","text":"

U-Net model.

Parameters:

  • depth (int, default: 3) –
             Number of upsampling and downsampling layers.
  • dimensions (int, default: 8) –
             Number of base feature channels.
  • input_channels (int, default: 7) –
             Number of input channels.
  • kernel_size (int, default: 3) –
             Size of the convolution kernel.
  • bias (bool, default: True) –
             Set to True to let convolutional layers learn a bias term.
  • normalization (bool, default: False) –
             If True, adds a Batch Normalization layer after the convolutional layer.
  • activation (torch.nn.Module, default: LeakyReLU(0.2, inplace=True)) –
             Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).
Source code in odak/learn/models/models.py
def __init__(\n             self,\n             depth = 3,\n             dimensions = 8,\n             input_channels = 7,\n             kernel_size = 3,\n             bias = True,\n             normalization = False,\n             activation = torch.nn.LeakyReLU(0.2, inplace = True)\n            ):\n    \"\"\"\n    U-Net model.\n\n    Parameters\n    ----------\n    depth          : int\n                     Number of upsampling and downsampling layers.\n    dimensions     : int\n                     Number of dimensions.\n    input_channels : int\n                     Number of input channels.\n    bias           : bool\n                     Set to True to let convolutional layers learn a bias term.\n    normalization  : bool\n                     If True, adds a Batch Normalization layer after the convolutional layer.\n    activation     : torch.nn\n                     Non-linear activation layer (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).\n    \"\"\"\n    super().__init__()\n    self.depth = depth\n    self.inc = convolution_layer(\n                                 input_channels = input_channels,\n                                 output_channels = dimensions,\n                                 kernel_size = kernel_size,\n                                 bias = bias,\n                                 normalization = normalization,\n                                 activation = activation\n                                )\n    self.encoder = torch.nn.ModuleList()\n    for i in range(depth + 1):  # downsampling layers\n        if i == 0:\n            in_channels = dimensions * (2 ** i)\n            out_channels = dimensions * (2 ** i)\n        elif i == depth:\n            in_channels = dimensions * (2 ** (i - 1))\n            out_channels = dimensions * (2 ** (i - 1))\n        else:\n            in_channels = dimensions * (2 ** (i - 1))\n            out_channels = 2 * in_channels\n        pooling_layer = torch.nn.AvgPool2d(2)\n        double_convolution_layer = double_convolution(\n                                                      input_channels = in_channels,\n                                                      mid_channels = in_channels,\n                                                      output_channels = out_channels,\n                                                      kernel_size = kernel_size,\n                                                      bias = bias,\n                                                      normalization = normalization,\n                                                      activation = activation\n                                                     )\n        self.encoder.append(pooling_layer)\n        self.encoder.append(double_convolution_layer)\n    self.spatially_varying_feature = torch.nn.ModuleList()  # for kernel generation\n    for i in range(depth, -1, -1):\n        if i == 1:\n            svf_in_channels = dimensions + 2 ** (self.depth + i) + 1\n        else:\n            svf_in_channels = 2 ** (self.depth + i) + 1\n        svf_out_channels = (2 ** (self.depth + i)) * (kernel_size * kernel_size)\n        svf_mid_channels = dimensions * (2 ** (self.depth - 1))\n        spatially_varying_kernel_generation = torch.nn.ModuleList()\n        for j in range(i, -1, -1):\n            pooling_layer = torch.nn.AvgPool2d(2 ** (j + 1))\n            spatially_varying_kernel_generation.append(pooling_layer)\n        kernel_generation_block = torch.nn.Sequential(\n            torch.nn.Conv2d(\n                            in_channels = 
svf_in_channels,\n                            out_channels = svf_mid_channels,\n                            kernel_size = kernel_size,\n                            padding = kernel_size // 2,\n                            bias = bias\n                           ),\n            activation,\n            torch.nn.Conv2d(\n                            in_channels = svf_mid_channels,\n                            out_channels = svf_mid_channels,\n                            kernel_size = kernel_size,\n                            padding = kernel_size // 2,\n                            bias = bias\n                           ),\n            activation,\n            torch.nn.Conv2d(\n                            in_channels = svf_mid_channels,\n                            out_channels = svf_out_channels,\n                            kernel_size = kernel_size,\n                            padding = kernel_size // 2,\n                            bias = bias\n                           ),\n        )\n        spatially_varying_kernel_generation.append(kernel_generation_block)\n        self.spatially_varying_feature.append(spatially_varying_kernel_generation)\n    self.decoder = torch.nn.ModuleList()\n    global_feature_layer = global_feature_module(  # global feature layer\n                                                 input_channels = dimensions * (2 ** (depth - 1)),\n                                                 mid_channels = dimensions * (2 ** (depth - 1)),\n                                                 output_channels = dimensions * (2 ** (depth - 1)),\n                                                 kernel_size = kernel_size,\n                                                 bias = bias,\n                                                 activation = torch.nn.LeakyReLU(0.2, inplace = True)\n                                                )\n    self.decoder.append(global_feature_layer)\n    for i in range(depth, 0, -1):\n        if i == 2:\n            up_in_channels = (dimensions // 2) * (2 ** i)\n            up_out_channels = up_in_channels\n            up_mid_channels = up_in_channels\n        elif i == 1:\n            up_in_channels = dimensions * 2\n            up_out_channels = dimensions\n            up_mid_channels = up_out_channels\n        else:\n            up_in_channels = (dimensions // 2) * (2 ** i)\n            up_out_channels = up_in_channels // 2\n            up_mid_channels = up_in_channels\n        upsample_layer = upsample_convtranspose2d_layer(\n                                                        input_channels = up_in_channels,\n                                                        output_channels = up_mid_channels,\n                                                        kernel_size = 2,\n                                                        stride = 2,\n                                                        bias = bias,\n                                                       )\n        conv_layer = double_convolution(\n                                        input_channels = up_mid_channels,\n                                        output_channels = up_out_channels,\n                                        kernel_size = kernel_size,\n                                        bias = bias,\n                                        normalization = normalization,\n                                        activation = activation,\n                                       )\n        self.decoder.append(torch.nn.ModuleList([upsample_layer, conv_layer]))\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.spatially_varying_kernel_generation_model.forward","title":"forward(focal_surface, field)","text":"

Forward model.

Parameters:

  • focal_surface (tensor) –
            Input focal surface data. Dimension: (1, 1, H, W)
  • field (tensor) –
            Input field data. Dimension: (1, 6, H, W)

Returns:

  • sv_kernel (list of torch.tensor) –

    Learned spatially varying kernels. Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i), where C_i, H_i, and W_i represent the channel, height, and width of each feature at a certain scale.

Source code in odak/learn/models/models.py
def forward(self, focal_surface, field):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    focal_surface : torch.tensor\n                    Input focal surface data.\n                    Dimension: (1, 1, H, W)\n\n    field         : torch.tensor\n                    Input field data.\n                    Dimension: (1, 6, H, W)\n\n    Returns\n    -------\n    sv_kernel : list of torch.tensor\n                Learned spatially varying kernels.\n                Dimension of each element in the list: (1, C_i * kernel_size * kernel_size, H_i, W_i),\n                where C_i, H_i, and W_i represent the channel, height, and width\n                of each feature at a certain scale.\n    \"\"\"\n    x = self.inc(torch.cat((focal_surface, field), dim = 1))\n    downsampling_outputs = [focal_surface]\n    downsampling_outputs.append(x)\n    for i, down_layer in enumerate(self.encoder):\n        x_down = down_layer(downsampling_outputs[-1])\n        downsampling_outputs.append(x_down)\n    sv_kernels = []\n    for i, (up_layer, svf_layer) in enumerate(zip(self.decoder, self.spatially_varying_feature)):\n        if i == 0:\n            global_feature = up_layer(downsampling_outputs[-2], downsampling_outputs[-1])\n            downsampling_outputs[-1] = global_feature\n            sv_feature = [global_feature, downsampling_outputs[0]]\n            for j in range(self.depth - i + 1):\n                sv_feature[1] = svf_layer[self.depth - i](sv_feature[1])\n                if j > 0:\n                    sv_feature.append(svf_layer[j](downsampling_outputs[2 * j]))\n            sv_feature = [sv_feature[0], sv_feature[1], sv_feature[4], sv_feature[2],\n                          sv_feature[3]]\n            sv_kernel = svf_layer[-1](torch.cat(sv_feature, dim = 1))\n            sv_kernels.append(sv_kernel)\n        else:\n            x_up = up_layer[0](downsampling_outputs[-1],\n                               downsampling_outputs[2 * (self.depth + 1 - i) + 1])\n            x_up = up_layer[1](x_up)\n            downsampling_outputs[-1] = x_up\n            sv_feature = [x_up, downsampling_outputs[0]]\n            for j in range(self.depth - i + 1):\n                sv_feature[1] = svf_layer[self.depth - i](sv_feature[1])\n                if j > 0:\n                    sv_feature.append(svf_layer[j](downsampling_outputs[2 * j]))\n            if i == 1:\n                sv_feature = [sv_feature[0], sv_feature[1], sv_feature[3], sv_feature[2]]\n            sv_kernel = svf_layer[-1](torch.cat(sv_feature, dim = 1))\n            sv_kernels.append(sv_kernel)\n    return sv_kernels\n
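To make the returned list concrete, the hedged snippet below prints the per-scale kernel shapes for the default configuration and a 256 x 256 input; the commented numbers are derived by tracing the defaults (depth = 3, dimensions = 8, kernel_size = 3) and should be treated as an illustration rather than a guarantee:

```python
import torch
from odak.learn.models.models import spatially_varying_kernel_generation_model  # import path assumed

model = spatially_varying_kernel_generation_model()   # depth = 3, dimensions = 8, kernel_size = 3
sv_kernel = model(torch.randn(1, 1, 256, 256), torch.randn(1, 6, 256, 256))
for level, kernel in enumerate(sv_kernel):
    # Expected with the defaults, coarsest scale first, e.g.:
    # (1, 576, 16, 16), (1, 288, 32, 32), (1, 144, 64, 64), (1, 72, 128, 128)
    print(level, tuple(kernel.shape))
```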
"},{"location":"odak/learn_wave/#odak.learn.wave.models.unet","title":"unet","text":"

Bases: Module

A U-Net model, heavily inspired by https://github.com/milesial/Pytorch-UNet/tree/master/unet. For more details, see Ronneberger, Olaf, Philipp Fischer, and Thomas Brox, "U-net: Convolutional networks for biomedical image segmentation," Medical Image Computing and Computer-Assisted Intervention – MICCAI 2015: 18th International Conference, Munich, Germany, October 5-9, 2015, Proceedings, Part III, Springer International Publishing, 2015.

Source code in odak/learn/models/models.py
class unet(torch.nn.Module):\n    \"\"\"\n    A U-Net model, heavily inspired from `https://github.com/milesial/Pytorch-UNet/tree/master/unet` and more can be read from Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. \"U-net: Convolutional networks for biomedical image segmentation.\" Medical Image Computing and Computer-Assisted Intervention\u2013MICCAI 2015: 18th International Conference, Munich, Germany, October 5-9, 2015, Proceedings, Part III 18. Springer International Publishing, 2015.\n    \"\"\"\n\n    def __init__(\n                 self, \n                 depth = 4,\n                 dimensions = 64, \n                 input_channels = 2, \n                 output_channels = 1, \n                 bilinear = False,\n                 kernel_size = 3,\n                 bias = False,\n                 activation = torch.nn.ReLU(inplace = True),\n                ):\n        \"\"\"\n        U-Net model.\n\n        Parameters\n        ----------\n        depth             : int\n                            Number of upsampling and downsampling\n        dimensions        : int\n                            Number of dimensions.\n        input_channels    : int\n                            Number of input channels.\n        output_channels   : int\n                            Number of output channels.\n        bilinear          : bool\n                            Uses bilinear upsampling in upsampling layers when set True.\n        bias              : bool\n                            Set True to let convolutional layers learn a bias term.\n        activation        : torch.nn\n                            Non-linear activation layer to be used (e.g., torch.nn.ReLU(), torch.nn.Sigmoid().\n        \"\"\"\n        super(unet, self).__init__()\n        self.inc = double_convolution(\n                                      input_channels = input_channels,\n                                      mid_channels = dimensions,\n                                      output_channels = dimensions,\n                                      kernel_size = kernel_size,\n                                      bias = bias,\n                                      activation = activation\n                                     )      \n\n        self.downsampling_layers = torch.nn.ModuleList()\n        self.upsampling_layers = torch.nn.ModuleList()\n        for i in range(depth): # downsampling layers\n            in_channels = dimensions * (2 ** i)\n            out_channels = dimensions * (2 ** (i + 1))\n            down_layer = downsample_layer(in_channels,\n                                            out_channels,\n                                            kernel_size=kernel_size,\n                                            bias=bias,\n                                            activation=activation\n                                            )\n            self.downsampling_layers.append(down_layer)      \n\n        for i in range(depth - 1, -1, -1):  # upsampling layers\n            up_in_channels = dimensions * (2 ** (i + 1))  \n            up_out_channels = dimensions * (2 ** i) \n            up_layer = upsample_layer(up_in_channels, up_out_channels, kernel_size=kernel_size, bias=bias, activation=activation, bilinear=bilinear)\n            self.upsampling_layers.append(up_layer)\n        self.outc = torch.nn.Conv2d(\n                                    dimensions, \n                                    output_channels,\n                                    kernel_size = kernel_size,\n                    
                padding = kernel_size // 2,\n                                    bias = bias\n                                   )\n\n\n    def forward(self, x):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x             : torch.tensor\n                        Input data.\n\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Estimated output.      \n        \"\"\"\n        downsampling_outputs = [self.inc(x)]\n        for down_layer in self.downsampling_layers:\n            x_down = down_layer(downsampling_outputs[-1])\n            downsampling_outputs.append(x_down)\n        x_up = downsampling_outputs[-1]\n        for i, up_layer in enumerate((self.upsampling_layers)):\n            x_up = up_layer(x_up, downsampling_outputs[-(i + 2)])       \n        result = self.outc(x_up)\n        return result\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.unet.__init__","title":"__init__(depth=4, dimensions=64, input_channels=2, output_channels=1, bilinear=False, kernel_size=3, bias=False, activation=torch.nn.ReLU(inplace=True))","text":"

U-Net model.

Parameters:

  • depth \u2013
                Number of upsampling and downsampling\n
  • dimensions \u2013
                Number of dimensions.\n
  • input_channels \u2013
                Number of input channels.\n
  • output_channels \u2013
                Number of output channels.\n
  • bilinear \u2013
                Uses bilinear upsampling in upsampling layers when set True.\n
  • bias \u2013
                Set to True to let convolutional layers learn a bias term.\n
  • activation \u2013
                Non-linear activation layer to be used (e.g., torch.nn.ReLU(), torch.nn.Sigmoid()).\n
Source code in odak/learn/models/models.py
def __init__(\n             self, \n             depth = 4,\n             dimensions = 64, \n             input_channels = 2, \n             output_channels = 1, \n             bilinear = False,\n             kernel_size = 3,\n             bias = False,\n             activation = torch.nn.ReLU(inplace = True),\n            ):\n    \"\"\"\n    U-Net model.\n\n    Parameters\n    ----------\n    depth             : int\n                        Number of upsampling and downsampling\n    dimensions        : int\n                        Number of dimensions.\n    input_channels    : int\n                        Number of input channels.\n    output_channels   : int\n                        Number of output channels.\n    bilinear          : bool\n                        Uses bilinear upsampling in upsampling layers when set True.\n    bias              : bool\n                        Set True to let convolutional layers learn a bias term.\n    activation        : torch.nn\n                        Non-linear activation layer to be used (e.g., torch.nn.ReLU(), torch.nn.Sigmoid().\n    \"\"\"\n    super(unet, self).__init__()\n    self.inc = double_convolution(\n                                  input_channels = input_channels,\n                                  mid_channels = dimensions,\n                                  output_channels = dimensions,\n                                  kernel_size = kernel_size,\n                                  bias = bias,\n                                  activation = activation\n                                 )      \n\n    self.downsampling_layers = torch.nn.ModuleList()\n    self.upsampling_layers = torch.nn.ModuleList()\n    for i in range(depth): # downsampling layers\n        in_channels = dimensions * (2 ** i)\n        out_channels = dimensions * (2 ** (i + 1))\n        down_layer = downsample_layer(in_channels,\n                                        out_channels,\n                                        kernel_size=kernel_size,\n                                        bias=bias,\n                                        activation=activation\n                                        )\n        self.downsampling_layers.append(down_layer)      \n\n    for i in range(depth - 1, -1, -1):  # upsampling layers\n        up_in_channels = dimensions * (2 ** (i + 1))  \n        up_out_channels = dimensions * (2 ** i) \n        up_layer = upsample_layer(up_in_channels, up_out_channels, kernel_size=kernel_size, bias=bias, activation=activation, bilinear=bilinear)\n        self.upsampling_layers.append(up_layer)\n    self.outc = torch.nn.Conv2d(\n                                dimensions, \n                                output_channels,\n                                kernel_size = kernel_size,\n                                padding = kernel_size // 2,\n                                bias = bias\n                               )\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.unet.forward","title":"forward(x)","text":"

Forward model.

Parameters:

  • x \u2013
            Input data.\n

Returns:

  • result ( tensor ) \u2013

    Estimated output.

Source code in odak/learn/models/models.py
def forward(self, x):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x             : torch.tensor\n                    Input data.\n\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Estimated output.      \n    \"\"\"\n    downsampling_outputs = [self.inc(x)]\n    for down_layer in self.downsampling_layers:\n        x_down = down_layer(downsampling_outputs[-1])\n        downsampling_outputs.append(x_down)\n    x_up = downsampling_outputs[-1]\n    for i, up_layer in enumerate((self.upsampling_layers)):\n        x_up = up_layer(x_up, downsampling_outputs[-(i + 2)])       \n    result = self.outc(x_up)\n    return result\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.upsample_convtranspose2d_layer","title":"upsample_convtranspose2d_layer","text":"

Bases: Module

An upsampling convtranspose2d layer.

Source code in odak/learn/models/components.py
class upsample_convtranspose2d_layer(torch.nn.Module):\n    \"\"\"\n    An upsampling convtranspose2d layer.\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels,\n                 output_channels,\n                 kernel_size = 2,\n                 stride = 2,\n                 bias = False,\n                ):\n        \"\"\"\n        A downscaling component with a double convolution.\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        output_channels : int\n                          Number of output channels.\n        kernel_size     : int\n                          Kernel size.\n        bias            : bool\n                          Set to True to let convolutional layers have bias term.\n        \"\"\"\n        super().__init__()\n        self.up = torch.nn.ConvTranspose2d(\n                                           in_channels = input_channels,\n                                           out_channels = output_channels,\n                                           bias = bias,\n                                           kernel_size = kernel_size,\n                                           stride = stride\n                                          )\n\n    def forward(self, x1, x2):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x1             : torch.tensor\n                         First input data.\n        x2             : torch.tensor\n                         Second input data.\n\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Result of the forward operation\n        \"\"\"\n        x1 = self.up(x1)\n        diffY = x2.size()[2] - x1.size()[2]\n        diffX = x2.size()[3] - x1.size()[3]\n        x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,\n                                          diffY // 2, diffY - diffY // 2])\n        result = x1 + x2\n        return result\n
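To see the pad-and-add skip connection of forward(x1, x2) above in isolation, the same arithmetic can be reproduced with plain torch calls (a standalone sketch, no odak import required):

import torch

x1 = torch.randn(1, 8, 15, 15)   # upsampled feature map with an odd spatial size
x2 = torch.randn(1, 8, 16, 16)   # skip connection coming from the encoder side
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,
                                  diffY // 2, diffY - diffY // 2])
result = x1 + x2                 # element-wise sum, shape [1, 8, 16, 16]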
"},{"location":"odak/learn_wave/#odak.learn.wave.models.upsample_convtranspose2d_layer.__init__","title":"__init__(input_channels, output_channels, kernel_size=2, stride=2, bias=False)","text":"

An upsampling component based on a transposed convolution.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • output_channels (int) \u2013
              Number of output channels.\n
  • kernel_size \u2013
              Kernel size.\n
  • bias \u2013
              Set to True to let convolutional layers have a bias term.\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels,\n             output_channels,\n             kernel_size = 2,\n             stride = 2,\n             bias = False,\n            ):\n    \"\"\"\n    A downscaling component with a double convolution.\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    output_channels : int\n                      Number of output channels.\n    kernel_size     : int\n                      Kernel size.\n    bias            : bool\n                      Set to True to let convolutional layers have bias term.\n    \"\"\"\n    super().__init__()\n    self.up = torch.nn.ConvTranspose2d(\n                                       in_channels = input_channels,\n                                       out_channels = output_channels,\n                                       bias = bias,\n                                       kernel_size = kernel_size,\n                                       stride = stride\n                                      )\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.upsample_convtranspose2d_layer.forward","title":"forward(x1, x2)","text":"

Forward model.

Parameters:

  • x1 \u2013
             First input data.\n
  • x2 \u2013
             Second input data.\n

Returns:

  • result ( tensor ) \u2013

    Result of the forward operation

Source code in odak/learn/models/components.py
def forward(self, x1, x2):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x1             : torch.tensor\n                     First input data.\n    x2             : torch.tensor\n                     Second input data.\n\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Result of the forward operation\n    \"\"\"\n    x1 = self.up(x1)\n    diffY = x2.size()[2] - x1.size()[2]\n    diffX = x2.size()[3] - x1.size()[3]\n    x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,\n                                      diffY // 2, diffY - diffY // 2])\n    result = x1 + x2\n    return result\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.upsample_layer","title":"upsample_layer","text":"

Bases: Module

An upsampling convolutional layer.

Source code in odak/learn/models/components.py
class upsample_layer(torch.nn.Module):\n    \"\"\"\n    An upsampling convolutional layer.\n    \"\"\"\n    def __init__(\n                 self,\n                 input_channels,\n                 output_channels,\n                 kernel_size = 3,\n                 bias = False,\n                 activation = torch.nn.ReLU(),\n                 bilinear = True\n                ):\n        \"\"\"\n        A downscaling component with a double convolution.\n\n        Parameters\n        ----------\n        input_channels  : int\n                          Number of input channels.\n        output_channels : int\n                          Number of output channels.\n        kernel_size     : int\n                          Kernel size.\n        bias            : bool \n                          Set to True to let convolutional layers have bias term.\n        activation      : torch.nn\n                          Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n        bilinear        : bool\n                          If set to True, bilinear sampling is used.\n        \"\"\"\n        super(upsample_layer, self).__init__()\n        if bilinear:\n            self.up = torch.nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = True)\n            self.conv = double_convolution(\n                                           input_channels = input_channels + output_channels,\n                                           mid_channels = input_channels // 2,\n                                           output_channels = output_channels,\n                                           kernel_size = kernel_size,\n                                           bias = bias,\n                                           activation = activation\n                                          )\n        else:\n            self.up = torch.nn.ConvTranspose2d(input_channels , input_channels // 2, kernel_size = 2, stride = 2)\n            self.conv = double_convolution(\n                                           input_channels = input_channels,\n                                           mid_channels = output_channels,\n                                           output_channels = output_channels,\n                                           kernel_size = kernel_size,\n                                           bias = bias,\n                                           activation = activation\n                                          )\n\n\n    def forward(self, x1, x2):\n        \"\"\"\n        Forward model.\n\n        Parameters\n        ----------\n        x1             : torch.tensor\n                         First input data.\n        x2             : torch.tensor\n                         Second input data.\n\n\n        Returns\n        ----------\n        result        : torch.tensor\n                        Result of the forward operation\n        \"\"\" \n        x1 = self.up(x1)\n        diffY = x2.size()[2] - x1.size()[2]\n        diffX = x2.size()[3] - x1.size()[3]\n        x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,\n                                          diffY // 2, diffY - diffY // 2])\n        x = torch.cat([x2, x1], dim = 1)\n        result = self.conv(x)\n        return result\n
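A short sketch of how the channel counts line up when using upsample_layer with bilinear upsampling: the skip connection is concatenated to the upsampled tensor, so the double convolution receives input_channels + output_channels channels. The import path below is an assumption:

import torch
from odak.learn.models.components import upsample_layer  # assumed import path

up = upsample_layer(input_channels = 128, output_channels = 64, bilinear = True)
x1 = torch.randn(1, 128, 32, 32)   # deeper feature map to be upsampled by a factor of two
x2 = torch.randn(1, 64, 64, 64)    # skip connection already at the target resolution
y = up(x1, x2)                     # expected shape: [1, 64, 64, 64]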
"},{"location":"odak/learn_wave/#odak.learn.wave.models.upsample_layer.__init__","title":"__init__(input_channels, output_channels, kernel_size=3, bias=False, activation=torch.nn.ReLU(), bilinear=True)","text":"

An upsampling component with a double convolution.

Parameters:

  • input_channels \u2013
              Number of input channels.\n
  • output_channels (int) \u2013
              Number of output channels.\n
  • kernel_size \u2013
              Kernel size.\n
  • bias \u2013
              Set to True to let convolutional layers have a bias term.\n
  • activation \u2013
              Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n
  • bilinear \u2013
              If set to True, bilinear sampling is used.\n
Source code in odak/learn/models/components.py
def __init__(\n             self,\n             input_channels,\n             output_channels,\n             kernel_size = 3,\n             bias = False,\n             activation = torch.nn.ReLU(),\n             bilinear = True\n            ):\n    \"\"\"\n    A downscaling component with a double convolution.\n\n    Parameters\n    ----------\n    input_channels  : int\n                      Number of input channels.\n    output_channels : int\n                      Number of output channels.\n    kernel_size     : int\n                      Kernel size.\n    bias            : bool \n                      Set to True to let convolutional layers have bias term.\n    activation      : torch.nn\n                      Nonlinear activation layer to be used. If None, uses torch.nn.ReLU().\n    bilinear        : bool\n                      If set to True, bilinear sampling is used.\n    \"\"\"\n    super(upsample_layer, self).__init__()\n    if bilinear:\n        self.up = torch.nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = True)\n        self.conv = double_convolution(\n                                       input_channels = input_channels + output_channels,\n                                       mid_channels = input_channels // 2,\n                                       output_channels = output_channels,\n                                       kernel_size = kernel_size,\n                                       bias = bias,\n                                       activation = activation\n                                      )\n    else:\n        self.up = torch.nn.ConvTranspose2d(input_channels , input_channels // 2, kernel_size = 2, stride = 2)\n        self.conv = double_convolution(\n                                       input_channels = input_channels,\n                                       mid_channels = output_channels,\n                                       output_channels = output_channels,\n                                       kernel_size = kernel_size,\n                                       bias = bias,\n                                       activation = activation\n                                      )\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.upsample_layer.forward","title":"forward(x1, x2)","text":"

Forward model.

Parameters:

  • x1 \u2013
             First input data.\n
  • x2 \u2013
             Second input data.\n

Returns:

  • result ( tensor ) \u2013

    Result of the forward operation

Source code in odak/learn/models/components.py
def forward(self, x1, x2):\n    \"\"\"\n    Forward model.\n\n    Parameters\n    ----------\n    x1             : torch.tensor\n                     First input data.\n    x2             : torch.tensor\n                     Second input data.\n\n\n    Returns\n    ----------\n    result        : torch.tensor\n                    Result of the forward operation\n    \"\"\" \n    x1 = self.up(x1)\n    diffY = x2.size()[2] - x1.size()[2]\n    diffX = x2.size()[3] - x1.size()[3]\n    x1 = torch.nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,\n                                      diffY // 2, diffY - diffY // 2])\n    x = torch.cat([x2, x1], dim = 1)\n    result = self.conv(x)\n    return result\n
"},{"location":"odak/learn_wave/#odak.learn.wave.models.gaussian","title":"gaussian(x, multiplier=1.0)","text":"

A Gaussian non-linear activation. For more details: Ramasinghe, Sameera, and Simon Lucey. \"Beyond periodicity: Towards a unifying framework for activations in coordinate-mlps.\" In European Conference on Computer Vision, pp. 142-158. Cham: Springer Nature Switzerland, 2022.

Parameters:

  • x \u2013
           Input data.\n
  • multiplier \u2013
           Multiplier.\n

Returns:

  • result ( float or tensor ) \u2013

    Output data.

Source code in odak/learn/models/components.py
def gaussian(x, multiplier = 1.):\n    \"\"\"\n    A Gaussian non-linear activation.\n    For more details: Ramasinghe, Sameera, and Simon Lucey. \"Beyond periodicity: Towards a unifying framework for activations in coordinate-mlps.\" In European Conference on Computer Vision, pp. 142-158. Cham: Springer Nature Switzerland, 2022.\n\n    Parameters\n    ----------\n    x            : float or torch.tensor\n                   Input data.\n    multiplier   : float or torch.tensor\n                   Multiplier.\n\n    Returns\n    -------\n    result       : float or torch.tensor\n                   Ouput data.\n    \"\"\"\n    result = torch.exp(- (multiplier * x) ** 2)\n    return result\n
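The formula above, result = exp(-(multiplier * x)^2), can be checked numerically without importing odak at all:

import torch

x = torch.linspace(-3., 3., 7)
multiplier = 1.
result = torch.exp(- (multiplier * x) ** 2)   # equals 1 at x = 0 and decays towards 0 for |x| >> 1 / multiplier
print(result)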
"},{"location":"odak/learn_wave/#odak.learn.wave.models.swish","title":"swish(x)","text":"

A swish non-linear activation. For more details: https://en.wikipedia.org/wiki/Swish_function

Parameters:

  • x \u2013
             Input.\n

Returns:

  • out ( float or tensor ) \u2013

    Output.

Source code in odak/learn/models/components.py
def swish(x):\n    \"\"\"\n    A swish non-linear activation.\n    For more details: https://en.wikipedia.org/wiki/Swish_function\n\n    Parameters\n    -----------\n    x              : float or torch.tensor\n                     Input.\n\n    Returns\n    -------\n    out            : float or torch.tensor\n                     Output.\n    \"\"\"\n    out = x * torch.sigmoid(x)\n    return out\n
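Similarly, swish(x) = x * sigmoid(x) can be reproduced with plain torch calls:

import torch

x = torch.linspace(-3., 3., 7)
out = x * torch.sigmoid(x)   # approaches 0 for large negative x and approaches x for large positive x
print(out)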
"},{"location":"odak/learn_wave/#odak.learn.wave.optimizers.multi_color_hologram_optimizer","title":"multi_color_hologram_optimizer","text":"

A class for optimizing single or multi-color holograms. For more details, see Kavakl\u0131 et al., SIGGRAPH ASIA 2023, Multi-color Holograms Improve Brightness in Holographic Displays.

Source code in odak/learn/wave/optimizers.py
class multi_color_hologram_optimizer():\n    \"\"\"\n    A class for optimizing single or multi color holograms.\n    For more details, see Kavakl\u0131 et al., SIGGRAPH ASIA 2023, Multi-color Holograms Improve Brightness in HOlographic Displays.\n    \"\"\"\n    def __init__(self,\n                 wavelengths,\n                 resolution,\n                 targets,\n                 propagator,\n                 number_of_frames = 3,\n                 number_of_depth_layers = 1,\n                 learning_rate = 2e-2,\n                 learning_rate_floor = 5e-3,\n                 double_phase = True,\n                 scale_factor = 1,\n                 method = 'multi-color',\n                 channel_power_filename = '',\n                 device = None,\n                 loss_function = None,\n                 peak_amplitude = 1.0,\n                 optimize_peak_amplitude = False,\n                 img_loss_thres = 2e-3,\n                 reduction = 'sum'\n                ):\n        self.device = device\n        if isinstance(self.device, type(None)):\n            self.device = torch.device(\"cpu\")\n        torch.cuda.empty_cache()\n        torch.random.seed()\n        self.wavelengths = wavelengths\n        self.resolution = resolution\n        self.targets = targets\n        if propagator.propagation_type != 'Impulse Response Fresnel':\n            scale_factor = 1\n        self.scale_factor = scale_factor\n        self.propagator = propagator\n        self.learning_rate = learning_rate\n        self.learning_rate_floor = learning_rate_floor\n        self.number_of_channels = len(self.wavelengths)\n        self.number_of_frames = number_of_frames\n        self.number_of_depth_layers = number_of_depth_layers\n        self.double_phase = double_phase\n        self.channel_power_filename = channel_power_filename\n        self.method = method\n        if self.method != 'conventional' and self.method != 'multi-color':\n           logging.warning('Unknown optimization method. 
Options are conventional or multi-color.')\n           import sys\n           sys.exit()\n        self.peak_amplitude = peak_amplitude\n        self.optimize_peak_amplitude = optimize_peak_amplitude\n        if self.optimize_peak_amplitude:\n            self.init_peak_amplitude_scale()\n        self.img_loss_thres = img_loss_thres\n        self.kernels = []\n        self.init_phase()\n        self.init_channel_power()\n        self.init_loss_function(loss_function, reduction = reduction)\n        self.init_amplitude()\n        self.init_phase_scale()\n\n\n    def init_peak_amplitude_scale(self):\n        \"\"\"\n        Internal function to set the phase scale.\n        \"\"\"\n        self.peak_amplitude = torch.tensor(\n                                           self.peak_amplitude,\n                                           requires_grad = True,\n                                           device=self.device\n                                          )\n\n\n    def init_phase_scale(self):\n        \"\"\"\n        Internal function to set the phase scale.\n        \"\"\"\n        if self.method == 'conventional':\n            self.phase_scale = torch.tensor(\n                                            [\n                                             1.,\n                                             1.,\n                                             1.\n                                            ],\n                                            requires_grad = False,\n                                            device = self.device\n                                           )\n        if self.method == 'multi-color':\n            self.phase_scale = torch.tensor(\n                                            [\n                                             1.,\n                                             1.,\n                                             1.\n                                            ],\n                                            requires_grad = False,\n                                            device = self.device\n                                           )\n\n\n    def init_amplitude(self):\n        \"\"\"\n        Internal function to set the amplitude of the illumination source.\n        \"\"\"\n        self.amplitude = torch.zeros(\n                                     self.resolution[0] * self.scale_factor,\n                                     self.resolution[1] * self.scale_factor,\n                                     requires_grad = False,\n                                     device = self.device\n                                    )\n        self.amplitude[::self.scale_factor, ::self.scale_factor] = 1.\n\n\n    def init_phase(self):\n        \"\"\"\n        Internal function to set the starting phase of the phase-only hologram.\n        \"\"\"\n        self.phase = torch.zeros(\n                                 self.number_of_frames,\n                                 self.resolution[0],\n                                 self.resolution[1],\n                                 device = self.device,\n                                 requires_grad = True\n                                )\n        self.offset = torch.rand_like(self.phase, requires_grad = True, device = self.device)\n\n\n    def init_channel_power(self):\n        \"\"\"\n        Internal function to set the starting phase of the phase-only hologram.\n        \"\"\"\n        if self.method == 'conventional':\n            logging.warning('Scheme: Conventional')\n            self.channel_power = 
torch.eye(\n                                           self.number_of_frames,\n                                           self.number_of_channels,\n                                           device = self.device,\n                                           requires_grad = False\n                                          )\n\n        elif self.method == 'multi-color':\n            logging.warning('Scheme: Multi-color')\n            self.channel_power = torch.ones(\n                                            self.number_of_frames,\n                                            self.number_of_channels,\n                                            device = self.device,\n                                            requires_grad = True\n                                           )\n        if self.channel_power_filename != '':\n            self.channel_power = torch_load(self.channel_power_filename).to(self.device)\n            self.channel_power.requires_grad = False\n            self.channel_power[self.channel_power < 0.] = 0.\n            self.channel_power[self.channel_power > 1.] = 1.\n            if self.method == 'multi-color':\n                self.channel_power.requires_grad = True\n            if self.method == 'conventional':\n                self.channel_power = torch.abs(torch.cos(self.channel_power))\n            logging.warning('Channel powers:')\n            logging.warning(self.channel_power)\n            logging.warning('Channel powers loaded from {}.'.format(self.channel_power_filename))\n        self.propagator.set_laser_powers(self.channel_power)\n\n\n\n    def init_optimizer(self):\n        \"\"\"\n        Internal function to set the optimizer.\n        \"\"\"\n        optimization_variables = [self.phase, self.offset]\n        if self.optimize_peak_amplitude:\n            optimization_variables.append(self.peak_amplitude)\n        if self.method == 'multi-color':\n            optimization_variables.append(self.propagator.channel_power)\n        self.optimizer = torch.optim.Adam(optimization_variables, lr=self.learning_rate)\n\n\n    def init_loss_function(self, loss_function, reduction = 'sum'):\n        \"\"\"\n        Internal function to set the loss function.\n        \"\"\"\n        self.l2_loss = torch.nn.MSELoss(reduction = reduction)\n        self.loss_type = 'custom'\n        self.loss_function = loss_function\n        if isinstance(self.loss_function, type(None)):\n            self.loss_type = 'conventional'\n            self.loss_function = torch.nn.MSELoss(reduction = reduction)\n\n\n\n    def evaluate(self, input_image, target_image, plane_id = 0):\n        \"\"\"\n        Internal function to evaluate the loss.\n        \"\"\"\n        if self.loss_type == 'conventional':\n            loss = self.loss_function(input_image, target_image)\n        elif self.loss_type == 'custom':\n            loss = 0\n            for i in range(len(self.wavelengths)):\n                loss += self.loss_function(\n                                           input_image[i],\n                                           target_image[i],\n                                           plane_id = plane_id\n                                          )\n        return loss\n\n\n    def double_phase_constrain(self, phase, phase_offset):\n        \"\"\"\n        Internal function to constrain a given phase similarly to double phase encoding.\n\n        Parameters\n        ----------\n        phase                      : torch.tensor\n                                     Input phase values to 
be constrained.\n        phase_offset               : torch.tensor\n                                     Input phase offset value.\n\n        Returns\n        -------\n        phase_only                 : torch.tensor\n                                     Constrained output phase.\n        \"\"\"\n        phase_zero_mean = phase - torch.mean(phase)\n        phase_low = torch.nan_to_num(phase_zero_mean - phase_offset, nan = 2 * np.pi)\n        phase_high = torch.nan_to_num(phase_zero_mean + phase_offset, nan = 2 * np.pi)\n        loss = multi_scale_total_variation_loss(phase_low, levels = 6)\n        loss += multi_scale_total_variation_loss(phase_high, levels = 6)\n        loss += torch.std(phase_low)\n        loss += torch.std(phase_high)\n        phase_only = torch.zeros_like(phase)\n        phase_only[0::2, 0::2] = phase_low[0::2, 0::2]\n        phase_only[0::2, 1::2] = phase_high[0::2, 1::2]\n        phase_only[1::2, 0::2] = phase_high[1::2, 0::2]\n        phase_only[1::2, 1::2] = phase_low[1::2, 1::2]\n        return phase_only, loss\n\n\n    def direct_phase_constrain(self, phase, phase_offset):\n        \"\"\"\n        Internal function to constrain a given phase.\n\n        Parameters\n        ----------\n        phase                      : torch.tensor\n                                     Input phase values to be constrained.\n        phase_offset               : torch.tensor\n                                     Input phase offset value.\n\n        Returns\n        -------\n        phase_only                 : torch.tensor\n                                     Constrained output phase.\n        \"\"\"\n        phase_only = torch.nan_to_num(phase - phase_offset, nan = 2 * np.pi)\n        loss = multi_scale_total_variation_loss(phase, levels = 6)\n        loss += multi_scale_total_variation_loss(phase_offset, levels = 6)\n        return phase_only, loss\n\n\n    def gradient_descent(self, number_of_iterations=100, weights=[1., 1., 0., 0.]):\n        \"\"\"\n        Function to optimize multiplane phase-only holograms using stochastic gradient descent.\n\n        Parameters\n        ----------\n        number_of_iterations       : float\n                                     Number of iterations.\n        weights                    : list\n                                     Weights used in the loss function.\n\n        Returns\n        -------\n        hologram                   : torch.tensor\n                                     Optimised hologram.\n        \"\"\"\n        hologram_phases = torch.zeros(\n                                      self.number_of_frames,\n                                      self.resolution[0],\n                                      self.resolution[1],\n                                      device = self.device\n                                     )\n        t = tqdm(range(number_of_iterations), leave = False, dynamic_ncols = True)\n        if self.optimize_peak_amplitude:\n            peak_amp_cache = self.peak_amplitude.item()\n        for step in t:\n            for g in self.optimizer.param_groups:\n                g['lr'] -= (self.learning_rate - self.learning_rate_floor) / number_of_iterations\n                if g['lr'] < self.learning_rate_floor:\n                    g['lr'] = self.learning_rate_floor\n                learning_rate = g['lr']\n            total_loss = 0\n            t_depth = tqdm(range(self.targets.shape[0]), leave = False, dynamic_ncols = True)\n            for depth_id in t_depth:\n                
self.optimizer.zero_grad()\n                depth_target = self.targets[depth_id]\n                reconstruction_intensities = torch.zeros(\n                                                         self.number_of_frames,\n                                                         self.number_of_channels,\n                                                         self.resolution[0] * self.scale_factor,\n                                                         self.resolution[1] * self.scale_factor,\n                                                         device = self.device\n                                                        )\n                loss_variation_hologram = 0\n                laser_powers = self.propagator.get_laser_powers()\n                for frame_id in range(self.number_of_frames):\n                    if self.double_phase:\n                        phase, loss_phase = self.double_phase_constrain(\n                                                                        self.phase[frame_id],\n                                                                        self.offset[frame_id]\n                                                                       )\n                    else:\n                        phase, loss_phase = self.direct_phase_constrain(\n                                                                        self.phase[frame_id],\n                                                                        self.offset[frame_id]\n                                                                       )\n                    loss_variation_hologram += loss_phase\n                    for channel_id in range(self.number_of_channels):\n                        phase_scaled = torch.zeros_like(self.amplitude)\n                        phase_scaled[::self.scale_factor, ::self.scale_factor] = phase\n                        laser_power = laser_powers[frame_id][channel_id]\n                        hologram = generate_complex_field(\n                                                          laser_power * self.amplitude,\n                                                          phase_scaled * self.phase_scale[channel_id]\n                                                         )\n                        reconstruction_field = self.propagator(hologram, channel_id, depth_id)\n                        intensity = calculate_amplitude(reconstruction_field) ** 2\n                        reconstruction_intensities[frame_id, channel_id] += intensity\n                    hologram_phases[frame_id] = phase.detach().clone()\n                loss_laser = self.l2_loss(\n                                          torch.amax(depth_target, dim = (1, 2)) * self.peak_amplitude,\n                                          torch.sum(laser_powers, dim = 0)\n                                         )\n                loss_laser += self.l2_loss(\n                                           torch.tensor([self.number_of_frames * self.peak_amplitude]).to(self.device),\n                                           torch.sum(laser_powers).view(1,)\n                                          )\n                loss_laser += torch.cos(torch.min(torch.sum(laser_powers, dim = 1)))\n                reconstruction_intensity = torch.sum(reconstruction_intensities, dim=0)\n                loss_image = self.evaluate(\n                                           reconstruction_intensity,\n                                           depth_target * self.peak_amplitude,\n                                           
plane_id = depth_id\n                                          )\n                loss = weights[0] * loss_image\n                loss += weights[1] * loss_laser\n                loss += weights[2] * loss_variation_hologram\n                include_pa_loss_flag = self.optimize_peak_amplitude and loss_image < self.img_loss_thres\n                if include_pa_loss_flag:\n                    loss -= self.peak_amplitude * 1.\n                if self.method == 'conventional':\n                    loss.backward()\n                else:\n                    loss.backward(retain_graph = True)\n                self.optimizer.step()\n                if include_pa_loss_flag:\n                    peak_amp_cache = self.peak_amplitude.item()\n                else:\n                    with torch.no_grad():\n                        if self.optimize_peak_amplitude:\n                            self.peak_amplitude.view([1])[0] = peak_amp_cache\n                total_loss += loss.detach().item()\n                loss_image = loss_image.detach()\n                del loss_laser\n                del loss_variation_hologram\n                del loss\n            description = \"Loss:{:.3f} Loss Image:{:.3f} Peak Amp:{:.1f} Learning rate:{:.4f}\".format(total_loss, loss_image.item(), self.peak_amplitude, learning_rate)\n            t.set_description(description)\n            del total_loss\n            del loss_image\n            del reconstruction_field\n            del reconstruction_intensities\n            del intensity\n            del phase\n            del hologram\n        logging.warning(description)\n        return hologram_phases.detach()\n\n\n    def optimize(self, number_of_iterations=100, weights=[1., 1., 1.], bits = 8):\n        \"\"\"\n        Function to optimize multiplane phase-only holograms.\n\n        Parameters\n        ----------\n        number_of_iterations       : int\n                                     Number of iterations.\n        weights                    : list\n                                     Loss weights.\n        bits                       : int\n                                     Quantizes the hologram using the given bits and reconstructs.\n\n        Returns\n        -------\n        hologram_phases            : torch.tensor\n                                     Phases of the optimized phase-only hologram.\n        reconstruction_intensities : torch.tensor\n                                     Intensities of the images reconstructed at each plane with the optimized phase-only hologram.\n        \"\"\"\n        self.init_optimizer()\n        hologram_phases = self.gradient_descent(\n                                                number_of_iterations=number_of_iterations,\n                                                weights=weights\n                                               )\n        hologram_phases = quantize(hologram_phases % (2 * np.pi), bits = bits, limits = [0., 2 * np.pi]) / 2 ** bits * 2 * np.pi\n        torch.no_grad()\n        reconstruction_intensities = self.propagator.reconstruct(hologram_phases)\n        laser_powers = self.propagator.get_laser_powers()\n        channel_powers = self.propagator.channel_power\n        logging.warning(\"Final peak amplitude: {}\".format(self.peak_amplitude))\n        logging.warning('Laser powers: {}'.format(laser_powers))\n        return hologram_phases, reconstruction_intensities, laser_powers, channel_powers, float(self.peak_amplitude)\n
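A heavily hedged usage sketch for this optimizer. It assumes that both multi_color_hologram_optimizer and the propagator class documented further below can be imported from odak.learn.wave (the import paths are assumptions), and that targets is a stack of intensity images shaped [number of depth layers, number of color channels, height, width]; all numerical values are illustrative only:

import torch
from odak.learn.wave import propagator, multi_color_hologram_optimizer  # assumed import paths

resolution = [256, 256]
wavelengths = [639e-9, 515e-9, 473e-9]
targets = torch.rand(1, 3, 256, 256)                 # one depth layer, three color channels
prop = propagator(
                  resolution = resolution,
                  wavelengths = wavelengths,
                  pixel_pitch = 8e-6,
                  number_of_frames = 3,
                  number_of_depth_layers = 1
                 )
optimizer = multi_color_hologram_optimizer(
                                           wavelengths = wavelengths,
                                           resolution = resolution,
                                           targets = targets,
                                           propagator = prop,
                                           number_of_frames = 3,
                                           method = 'multi-color'
                                          )
hologram_phases, reconstructions, laser_powers, channel_powers, peak_amplitude = optimizer.optimize(number_of_iterations = 10)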
"},{"location":"odak/learn_wave/#odak.learn.wave.optimizers.multi_color_hologram_optimizer.direct_phase_constrain","title":"direct_phase_constrain(phase, phase_offset)","text":"

Internal function to constrain a given phase.

Parameters:

  • phase \u2013
                         Input phase values to be constrained.\n
  • phase_offset \u2013
                         Input phase offset value.\n

Returns:

  • phase_only ( tensor ) \u2013

    Constrained output phase.

Source code in odak/learn/wave/optimizers.py
def direct_phase_constrain(self, phase, phase_offset):\n    \"\"\"\n    Internal function to constrain a given phase.\n\n    Parameters\n    ----------\n    phase                      : torch.tensor\n                                 Input phase values to be constrained.\n    phase_offset               : torch.tensor\n                                 Input phase offset value.\n\n    Returns\n    -------\n    phase_only                 : torch.tensor\n                                 Constrained output phase.\n    \"\"\"\n    phase_only = torch.nan_to_num(phase - phase_offset, nan = 2 * np.pi)\n    loss = multi_scale_total_variation_loss(phase, levels = 6)\n    loss += multi_scale_total_variation_loss(phase_offset, levels = 6)\n    return phase_only, loss\n
"},{"location":"odak/learn_wave/#odak.learn.wave.optimizers.multi_color_hologram_optimizer.double_phase_constrain","title":"double_phase_constrain(phase, phase_offset)","text":"

Internal function to constrain a given phase similarly to double phase encoding.

Parameters:

  • phase \u2013
                         Input phase values to be constrained.\n
  • phase_offset \u2013
                         Input phase offset value.\n

Returns:

  • phase_only ( tensor ) \u2013

    Constrained output phase.

Source code in odak/learn/wave/optimizers.py
def double_phase_constrain(self, phase, phase_offset):\n    \"\"\"\n    Internal function to constrain a given phase similarly to double phase encoding.\n\n    Parameters\n    ----------\n    phase                      : torch.tensor\n                                 Input phase values to be constrained.\n    phase_offset               : torch.tensor\n                                 Input phase offset value.\n\n    Returns\n    -------\n    phase_only                 : torch.tensor\n                                 Constrained output phase.\n    \"\"\"\n    phase_zero_mean = phase - torch.mean(phase)\n    phase_low = torch.nan_to_num(phase_zero_mean - phase_offset, nan = 2 * np.pi)\n    phase_high = torch.nan_to_num(phase_zero_mean + phase_offset, nan = 2 * np.pi)\n    loss = multi_scale_total_variation_loss(phase_low, levels = 6)\n    loss += multi_scale_total_variation_loss(phase_high, levels = 6)\n    loss += torch.std(phase_low)\n    loss += torch.std(phase_high)\n    phase_only = torch.zeros_like(phase)\n    phase_only[0::2, 0::2] = phase_low[0::2, 0::2]\n    phase_only[0::2, 1::2] = phase_high[0::2, 1::2]\n    phase_only[1::2, 0::2] = phase_high[1::2, 0::2]\n    phase_only[1::2, 1::2] = phase_low[1::2, 1::2]\n    return phase_only, loss\n
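The checkerboard interleaving performed by double_phase_constrain can be viewed in isolation with the standalone sketch below (plain torch and numpy only; the total-variation and standard-deviation loss terms are left out):

import torch
import numpy as np

phase = torch.rand(4, 4) * 2 * np.pi
offset = torch.rand(4, 4) * 2 * np.pi
phase_zero_mean = phase - torch.mean(phase)
phase_low = phase_zero_mean - offset
phase_high = phase_zero_mean + offset
phase_only = torch.zeros_like(phase)
phase_only[0::2, 0::2] = phase_low[0::2, 0::2]    # low phase on one checkerboard color
phase_only[0::2, 1::2] = phase_high[0::2, 1::2]   # high phase on the other
phase_only[1::2, 0::2] = phase_high[1::2, 0::2]
phase_only[1::2, 1::2] = phase_low[1::2, 1::2]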
"},{"location":"odak/learn_wave/#odak.learn.wave.optimizers.multi_color_hologram_optimizer.evaluate","title":"evaluate(input_image, target_image, plane_id=0)","text":"

Internal function to evaluate the loss.

Source code in odak/learn/wave/optimizers.py
def evaluate(self, input_image, target_image, plane_id = 0):\n    \"\"\"\n    Internal function to evaluate the loss.\n    \"\"\"\n    if self.loss_type == 'conventional':\n        loss = self.loss_function(input_image, target_image)\n    elif self.loss_type == 'custom':\n        loss = 0\n        for i in range(len(self.wavelengths)):\n            loss += self.loss_function(\n                                       input_image[i],\n                                       target_image[i],\n                                       plane_id = plane_id\n                                      )\n    return loss\n
"},{"location":"odak/learn_wave/#odak.learn.wave.optimizers.multi_color_hologram_optimizer.gradient_descent","title":"gradient_descent(number_of_iterations=100, weights=[1.0, 1.0, 0.0, 0.0])","text":"

Function to optimize multiplane phase-only holograms using stochastic gradient descent.

Parameters:

  • number_of_iterations \u2013
                         Number of iterations.\n
  • weights \u2013
                         Weights used in the loss function.\n

Returns:

  • hologram ( tensor ) \u2013

    Optimised hologram.

Source code in odak/learn/wave/optimizers.py
def gradient_descent(self, number_of_iterations=100, weights=[1., 1., 0., 0.]):\n    \"\"\"\n    Function to optimize multiplane phase-only holograms using stochastic gradient descent.\n\n    Parameters\n    ----------\n    number_of_iterations       : float\n                                 Number of iterations.\n    weights                    : list\n                                 Weights used in the loss function.\n\n    Returns\n    -------\n    hologram                   : torch.tensor\n                                 Optimised hologram.\n    \"\"\"\n    hologram_phases = torch.zeros(\n                                  self.number_of_frames,\n                                  self.resolution[0],\n                                  self.resolution[1],\n                                  device = self.device\n                                 )\n    t = tqdm(range(number_of_iterations), leave = False, dynamic_ncols = True)\n    if self.optimize_peak_amplitude:\n        peak_amp_cache = self.peak_amplitude.item()\n    for step in t:\n        for g in self.optimizer.param_groups:\n            g['lr'] -= (self.learning_rate - self.learning_rate_floor) / number_of_iterations\n            if g['lr'] < self.learning_rate_floor:\n                g['lr'] = self.learning_rate_floor\n            learning_rate = g['lr']\n        total_loss = 0\n        t_depth = tqdm(range(self.targets.shape[0]), leave = False, dynamic_ncols = True)\n        for depth_id in t_depth:\n            self.optimizer.zero_grad()\n            depth_target = self.targets[depth_id]\n            reconstruction_intensities = torch.zeros(\n                                                     self.number_of_frames,\n                                                     self.number_of_channels,\n                                                     self.resolution[0] * self.scale_factor,\n                                                     self.resolution[1] * self.scale_factor,\n                                                     device = self.device\n                                                    )\n            loss_variation_hologram = 0\n            laser_powers = self.propagator.get_laser_powers()\n            for frame_id in range(self.number_of_frames):\n                if self.double_phase:\n                    phase, loss_phase = self.double_phase_constrain(\n                                                                    self.phase[frame_id],\n                                                                    self.offset[frame_id]\n                                                                   )\n                else:\n                    phase, loss_phase = self.direct_phase_constrain(\n                                                                    self.phase[frame_id],\n                                                                    self.offset[frame_id]\n                                                                   )\n                loss_variation_hologram += loss_phase\n                for channel_id in range(self.number_of_channels):\n                    phase_scaled = torch.zeros_like(self.amplitude)\n                    phase_scaled[::self.scale_factor, ::self.scale_factor] = phase\n                    laser_power = laser_powers[frame_id][channel_id]\n                    hologram = generate_complex_field(\n                                                      laser_power * self.amplitude,\n                                                      phase_scaled * 
self.phase_scale[channel_id]\n                                                     )\n                    reconstruction_field = self.propagator(hologram, channel_id, depth_id)\n                    intensity = calculate_amplitude(reconstruction_field) ** 2\n                    reconstruction_intensities[frame_id, channel_id] += intensity\n                hologram_phases[frame_id] = phase.detach().clone()\n            loss_laser = self.l2_loss(\n                                      torch.amax(depth_target, dim = (1, 2)) * self.peak_amplitude,\n                                      torch.sum(laser_powers, dim = 0)\n                                     )\n            loss_laser += self.l2_loss(\n                                       torch.tensor([self.number_of_frames * self.peak_amplitude]).to(self.device),\n                                       torch.sum(laser_powers).view(1,)\n                                      )\n            loss_laser += torch.cos(torch.min(torch.sum(laser_powers, dim = 1)))\n            reconstruction_intensity = torch.sum(reconstruction_intensities, dim=0)\n            loss_image = self.evaluate(\n                                       reconstruction_intensity,\n                                       depth_target * self.peak_amplitude,\n                                       plane_id = depth_id\n                                      )\n            loss = weights[0] * loss_image\n            loss += weights[1] * loss_laser\n            loss += weights[2] * loss_variation_hologram\n            include_pa_loss_flag = self.optimize_peak_amplitude and loss_image < self.img_loss_thres\n            if include_pa_loss_flag:\n                loss -= self.peak_amplitude * 1.\n            if self.method == 'conventional':\n                loss.backward()\n            else:\n                loss.backward(retain_graph = True)\n            self.optimizer.step()\n            if include_pa_loss_flag:\n                peak_amp_cache = self.peak_amplitude.item()\n            else:\n                with torch.no_grad():\n                    if self.optimize_peak_amplitude:\n                        self.peak_amplitude.view([1])[0] = peak_amp_cache\n            total_loss += loss.detach().item()\n            loss_image = loss_image.detach()\n            del loss_laser\n            del loss_variation_hologram\n            del loss\n        description = \"Loss:{:.3f} Loss Image:{:.3f} Peak Amp:{:.1f} Learning rate:{:.4f}\".format(total_loss, loss_image.item(), self.peak_amplitude, learning_rate)\n        t.set_description(description)\n        del total_loss\n        del loss_image\n        del reconstruction_field\n        del reconstruction_intensities\n        del intensity\n        del phase\n        del hologram\n    logging.warning(description)\n    return hologram_phases.detach()\n
"},{"location":"odak/learn_wave/#odak.learn.wave.optimizers.multi_color_hologram_optimizer.init_amplitude","title":"init_amplitude()","text":"

Internal function to set the amplitude of the illumination source.

Source code in odak/learn/wave/optimizers.py
def init_amplitude(self):\n    \"\"\"\n    Internal function to set the amplitude of the illumination source.\n    \"\"\"\n    self.amplitude = torch.zeros(\n                                 self.resolution[0] * self.scale_factor,\n                                 self.resolution[1] * self.scale_factor,\n                                 requires_grad = False,\n                                 device = self.device\n                                )\n    self.amplitude[::self.scale_factor, ::self.scale_factor] = 1.\n
"},{"location":"odak/learn_wave/#odak.learn.wave.optimizers.multi_color_hologram_optimizer.init_channel_power","title":"init_channel_power()","text":"

Internal function to set the starting phase of the phase-only hologram.

Source code in odak/learn/wave/optimizers.py
def init_channel_power(self):\n    \"\"\"\n    Internal function to set the starting phase of the phase-only hologram.\n    \"\"\"\n    if self.method == 'conventional':\n        logging.warning('Scheme: Conventional')\n        self.channel_power = torch.eye(\n                                       self.number_of_frames,\n                                       self.number_of_channels,\n                                       device = self.device,\n                                       requires_grad = False\n                                      )\n\n    elif self.method == 'multi-color':\n        logging.warning('Scheme: Multi-color')\n        self.channel_power = torch.ones(\n                                        self.number_of_frames,\n                                        self.number_of_channels,\n                                        device = self.device,\n                                        requires_grad = True\n                                       )\n    if self.channel_power_filename != '':\n        self.channel_power = torch_load(self.channel_power_filename).to(self.device)\n        self.channel_power.requires_grad = False\n        self.channel_power[self.channel_power < 0.] = 0.\n        self.channel_power[self.channel_power > 1.] = 1.\n        if self.method == 'multi-color':\n            self.channel_power.requires_grad = True\n        if self.method == 'conventional':\n            self.channel_power = torch.abs(torch.cos(self.channel_power))\n        logging.warning('Channel powers:')\n        logging.warning(self.channel_power)\n        logging.warning('Channel powers loaded from {}.'.format(self.channel_power_filename))\n    self.propagator.set_laser_powers(self.channel_power)\n
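The two initialization schemes above differ only in the starting channel-power matrix; for three frames and three color channels they look as follows (a plain torch illustration, not an odak call):

import torch

number_of_frames, number_of_channels = 3, 3
conventional = torch.eye(number_of_frames, number_of_channels)                          # one color per frame, kept fixed
multi_color = torch.ones(number_of_frames, number_of_channels, requires_grad = True)    # every color in every frame, learnable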
"},{"location":"odak/learn_wave/#odak.learn.wave.optimizers.multi_color_hologram_optimizer.init_loss_function","title":"init_loss_function(loss_function, reduction='sum')","text":"

Internal function to set the loss function.

Source code in odak/learn/wave/optimizers.py
def init_loss_function(self, loss_function, reduction = 'sum'):\n    \"\"\"\n    Internal function to set the loss function.\n    \"\"\"\n    self.l2_loss = torch.nn.MSELoss(reduction = reduction)\n    self.loss_type = 'custom'\n    self.loss_function = loss_function\n    if isinstance(self.loss_function, type(None)):\n        self.loss_type = 'conventional'\n        self.loss_function = torch.nn.MSELoss(reduction = reduction)\n
"},{"location":"odak/learn_wave/#odak.learn.wave.optimizers.multi_color_hologram_optimizer.init_optimizer","title":"init_optimizer()","text":"

Internal function to set the optimizer.

Source code in odak/learn/wave/optimizers.py
def init_optimizer(self):\n    \"\"\"\n    Internal function to set the optimizer.\n    \"\"\"\n    optimization_variables = [self.phase, self.offset]\n    if self.optimize_peak_amplitude:\n        optimization_variables.append(self.peak_amplitude)\n    if self.method == 'multi-color':\n        optimization_variables.append(self.propagator.channel_power)\n    self.optimizer = torch.optim.Adam(optimization_variables, lr=self.learning_rate)\n
"},{"location":"odak/learn_wave/#odak.learn.wave.optimizers.multi_color_hologram_optimizer.init_peak_amplitude_scale","title":"init_peak_amplitude_scale()","text":"

Internal function to set the phase scale.

Source code in odak/learn/wave/optimizers.py
def init_peak_amplitude_scale(self):\n    \"\"\"\n    Internal function to set the phase scale.\n    \"\"\"\n    self.peak_amplitude = torch.tensor(\n                                       self.peak_amplitude,\n                                       requires_grad = True,\n                                       device=self.device\n                                      )\n
"},{"location":"odak/learn_wave/#odak.learn.wave.optimizers.multi_color_hologram_optimizer.init_phase","title":"init_phase()","text":"

Internal function to set the starting phase of the phase-only hologram.

Source code in odak/learn/wave/optimizers.py
def init_phase(self):\n    \"\"\"\n    Internal function to set the starting phase of the phase-only hologram.\n    \"\"\"\n    self.phase = torch.zeros(\n                             self.number_of_frames,\n                             self.resolution[0],\n                             self.resolution[1],\n                             device = self.device,\n                             requires_grad = True\n                            )\n    self.offset = torch.rand_like(self.phase, requires_grad = True, device = self.device)\n
"},{"location":"odak/learn_wave/#odak.learn.wave.optimizers.multi_color_hologram_optimizer.init_phase_scale","title":"init_phase_scale()","text":"

Internal function to set the phase scale.

Source code in odak/learn/wave/optimizers.py
def init_phase_scale(self):\n    \"\"\"\n    Internal function to set the phase scale.\n    \"\"\"\n    if self.method == 'conventional':\n        self.phase_scale = torch.tensor(\n                                        [\n                                         1.,\n                                         1.,\n                                         1.\n                                        ],\n                                        requires_grad = False,\n                                        device = self.device\n                                       )\n    if self.method == 'multi-color':\n        self.phase_scale = torch.tensor(\n                                        [\n                                         1.,\n                                         1.,\n                                         1.\n                                        ],\n                                        requires_grad = False,\n                                        device = self.device\n                                       )\n
"},{"location":"odak/learn_wave/#odak.learn.wave.optimizers.multi_color_hologram_optimizer.optimize","title":"optimize(number_of_iterations=100, weights=[1.0, 1.0, 1.0], bits=8)","text":"

Function to optimize multiplane phase-only holograms.

Parameters:

  • number_of_iterations \u2013
                         Number of iterations.\n
  • weights \u2013
                         Loss weights.\n
  • bits \u2013
                         Quantizes the hologram using the given bits and reconstructs.\n

Returns:

  • hologram_phases ( tensor ) \u2013

    Phases of the optimized phase-only hologram.

  • reconstruction_intensities ( tensor ) \u2013

    Intensities of the images reconstructed at each plane with the optimized phase-only hologram.

Source code in odak/learn/wave/optimizers.py
def optimize(self, number_of_iterations=100, weights=[1., 1., 1.], bits = 8):\n    \"\"\"\n    Function to optimize multiplane phase-only holograms.\n\n    Parameters\n    ----------\n    number_of_iterations       : int\n                                 Number of iterations.\n    weights                    : list\n                                 Loss weights.\n    bits                       : int\n                                 Quantizes the hologram using the given bits and reconstructs.\n\n    Returns\n    -------\n    hologram_phases            : torch.tensor\n                                 Phases of the optimized phase-only hologram.\n    reconstruction_intensities : torch.tensor\n                                 Intensities of the images reconstructed at each plane with the optimized phase-only hologram.\n    \"\"\"\n    self.init_optimizer()\n    hologram_phases = self.gradient_descent(\n                                            number_of_iterations=number_of_iterations,\n                                            weights=weights\n                                           )\n    hologram_phases = quantize(hologram_phases % (2 * np.pi), bits = bits, limits = [0., 2 * np.pi]) / 2 ** bits * 2 * np.pi\n    torch.no_grad()\n    reconstruction_intensities = self.propagator.reconstruct(hologram_phases)\n    laser_powers = self.propagator.get_laser_powers()\n    channel_powers = self.propagator.channel_power\n    logging.warning(\"Final peak amplitude: {}\".format(self.peak_amplitude))\n    logging.warning('Laser powers: {}'.format(laser_powers))\n    return hologram_phases, reconstruction_intensities, laser_powers, channel_powers, float(self.peak_amplitude)\n
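A hedged usage sketch: the docstring above lists two return values, but the source returns five. Assuming a configured multi_color_hologram_optimizer instance (hypothetically named optimizer here; its constructor arguments are not shown in this entry), a call could unpack them as follows.

# `optimizer` is a hypothetical, pre-configured multi_color_hologram_optimizer instance.
results = optimizer.optimize(number_of_iterations = 200, weights = [1., 1., 1.], bits = 8)
hologram_phases, reconstruction_intensities, laser_powers, channel_powers, peak_amplitude = results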
"},{"location":"odak/learn_wave/#odak.learn.wave.propagators.propagator","title":"propagator","text":"

A light propagation model that propagates light to a desired image plane with two separate propagations. We use this class in various works, including Kavakl\u0131 et al., Realistic Defocus Blur for Multiplane Computer-Generated Holography.

Source code in odak/learn/wave/propagators.py
class propagator():\n    \"\"\"\n    A light propagation model that propagates light to desired image plane with two separate propagations. \n    We use this class in our various works including `Kavakl\u0131 et al., Realistic Defocus Blur for Multiplane Computer-Generated Holography`.\n    \"\"\"\n    def __init__(\n                 self,\n                 resolution = [1920, 1080],\n                 wavelengths = [515e-9,],\n                 pixel_pitch = 8e-6,\n                 resolution_factor = 1,\n                 number_of_frames = 1,\n                 number_of_depth_layers = 1,\n                 volume_depth = 1e-2,\n                 image_location_offset = 5e-3,\n                 propagation_type = 'Bandlimited Angular Spectrum',\n                 propagator_type = 'back and forth',\n                 back_and_forth_distance = 0.3,\n                 laser_channel_power = None,\n                 aperture = None,\n                 aperture_size = None,\n                 distances = None,\n                 aperture_samples = [20, 20, 5, 5],\n                 method = 'conventional',\n                 device = torch.device('cpu')\n                ):\n        \"\"\"\n        Parameters\n        ----------\n        resolution              : list\n                                  Resolution.\n        wavelengths             : float\n                                  Wavelength of light in meters.\n        pixel_pitch             : float\n                                  Pixel pitch in meters.\n        resolution_factor       : int\n                                  Resolution factor for scaled simulations.\n        number_of_frames        : int\n                                  Number of hologram frames.\n                                  Typically, there are three frames, each one for a single color primary.\n        number_of_depth_layers  : int\n                                  Equ-distance number of depth layers within the desired volume. If `distances` parameter is passed, this value will be automatically set to the length of the `distances` verson provided.\n        volume_depth            : float\n                                  Width of the volume along the propagation direction.\n        image_location_offset   : float\n                                  Center of the volume along the propagation direction.\n        propagation_type        : str\n                                  Propagation type. \n                                  See ropagate_beam() and odak.learn.wave.get_propagation_kernel() for more.\n        propagator_type         : str\n                                  Propagator type.\n                                  The options are `back and forth` and `forward` propagators.\n        back_and_forth_distance : float\n                                  Zero mode distance for `back and forth` propagator type.\n        laser_channel_power     : torch.tensor\n                                  Laser channel powers for given number of frames and number of wavelengths.\n        aperture                : torch.tensor\n                                  Aperture at the Fourier plane.\n        aperture_size           : float\n                                  Aperture width for a circular aperture.\n        aperture_samples        : list\n                                  When using `Impulse Response Fresnel` propagation, these sample counts along X and Y will be used to represent a rectangular aperture. 
First two is for hologram plane pixel and the last two is for image plane pixel.\n        distances               : torch.tensor\n                                  Propagation distances in meters.\n        method                  : str\n                                  Hologram type conventional or multi-color.\n        device                  : torch.device\n                                  Device to be used for computation. For more see torch.device().\n        \"\"\"\n        self.device = device\n        self.pixel_pitch = pixel_pitch\n        self.wavelengths = wavelengths\n        self.resolution = resolution\n        self.propagation_type = propagation_type\n        if self.propagation_type != 'Impulse Response Fresnel':\n            resolution_factor = 1\n        self.resolution_factor = resolution_factor\n        self.number_of_frames = number_of_frames\n        self.number_of_depth_layers = number_of_depth_layers\n        self.number_of_channels = len(self.wavelengths)\n        self.volume_depth = volume_depth\n        self.image_location_offset = image_location_offset\n        self.propagator_type = propagator_type\n        self.aperture_samples = aperture_samples\n        self.zero_mode_distance = torch.tensor(back_and_forth_distance, device = device)\n        self.method = method\n        self.aperture = aperture\n        self.init_distances(distances)\n        self.init_kernels()\n        self.init_channel_power(laser_channel_power)\n        self.init_phase_scale()\n        self.set_aperture(aperture, aperture_size)\n\n\n    def init_distances(self, distances):\n        \"\"\"\n        Internal function to initialize distances.\n\n        Parameters\n        ----------\n        distances               : torch.tensor\n                                  Propagation distances.\n        \"\"\"\n        if isinstance(distances, type(None)):\n            self.distances = torch.linspace(-self.volume_depth / 2., self.volume_depth / 2., self.number_of_depth_layers) + self.image_location_offset\n        else:\n            self.distances = torch.as_tensor(distances)\n            self.number_of_depth_layers = self.distances.shape[0]\n        logging.warning('Distances: {}'.format(self.distances))\n\n\n    def init_kernels(self):\n        \"\"\"\n        Internal function to initialize kernels.\n        \"\"\"\n        self.generated_kernels = torch.zeros(\n                                             self.number_of_depth_layers,\n                                             self.number_of_channels,\n                                             device = self.device\n                                            )\n        self.kernels = torch.zeros(\n                                   self.number_of_depth_layers,\n                                   self.number_of_channels,\n                                   self.resolution[0] * self.resolution_factor * 2,\n                                   self.resolution[1] * self.resolution_factor * 2,\n                                   dtype = torch.complex64,\n                                   device = self.device\n                                  )\n\n\n    def init_channel_power(self, channel_power):\n        \"\"\"\n        Internal function to set the starting phase of the phase-only hologram.\n        \"\"\"\n        self.channel_power = channel_power\n        if isinstance(self.channel_power, type(None)):\n            self.channel_power = torch.eye(\n                                           self.number_of_frames,\n                          
                 self.number_of_channels,\n                                           device = self.device,\n                                           requires_grad = False\n                                          )\n\n\n    def init_phase_scale(self):\n        \"\"\"\n        Internal function to set the phase scale.\n        In some cases, you may want to modify this init to ratio phases for different color primaries as an SLM is configured for a specific central wavelength.\n        \"\"\"\n        self.phase_scale = torch.tensor(\n                                        [\n                                         1.,\n                                         1.,\n                                         1.\n                                        ],\n                                        requires_grad = False,\n                                        device = self.device\n                                       )\n\n\n    def set_aperture(self, aperture = None, aperture_size = None):\n        \"\"\"\n        Set aperture in the Fourier plane.\n\n\n        Parameters\n        ----------\n        aperture        : torch.tensor\n                          Aperture at the original resolution of a hologram.\n                          If aperture is provided as None, it will assign a circular aperture at the size of the short edge (width or height).\n        aperture_size   : int\n                          If no aperture is provided, this will determine the size of the circular aperture.\n        \"\"\"\n        if isinstance(aperture, type(None)):\n            if isinstance(aperture_size, type(None)):\n                aperture_size = torch.max(\n                                          torch.tensor([\n                                                        self.resolution[0] * self.resolution_factor, \n                                                        self.resolution[1] * self.resolution_factor\n                                                       ])\n                                         )\n            self.aperture = circular_binary_mask(\n                                                 self.resolution[0] * self.resolution_factor * 2,\n                                                 self.resolution[1] * self.resolution_factor * 2,\n                                                 aperture_size,\n                                                ).to(self.device) * 1.\n        else:\n            self.aperture = zero_pad(aperture).to(self.device) * 1.\n\n\n    def get_laser_powers(self):\n        \"\"\"\n        Internal function to get the laser powers.\n\n        Returns\n        -------\n        laser_power      : torch.tensor\n                           Laser powers.\n        \"\"\"\n        if self.method == 'conventional':\n            laser_power = self.channel_power\n        if self.method == 'multi-color':\n            laser_power = torch.abs(torch.cos(self.channel_power))\n        return laser_power\n\n\n    def set_laser_powers(self, laser_power):\n        \"\"\"\n        Internal function to set the laser powers.\n\n        Parameters\n        -------\n        laser_power      : torch.tensor\n                           Laser powers.\n        \"\"\"\n        self.channel_power = laser_power\n\n\n\n    def get_kernels(self):\n        \"\"\"\n        Function to return the kernels used in the light transport.\n\n        Returns\n        -------\n        kernels           : torch.tensor\n                            Kernel amplitudes.\n        \"\"\"\n        h = 
torch.fft.ifftshift(torch.fft.ifft2(torch.fft.ifftshift(self.kernels)))\n        kernels_amplitude = calculate_amplitude(h)\n        kernels_phase = calculate_phase(h)\n        return kernels_amplitude, kernels_phase\n\n\n    def __call__(self, input_field, channel_id, depth_id):\n        \"\"\"\n        Function that represents the forward model in hologram optimization.\n\n        Parameters\n        ----------\n        input_field         : torch.tensor\n                              Input complex input field.\n        channel_id          : int\n                              Identifying the color primary to be used.\n        depth_id            : int\n                              Identifying the depth layer to be used.\n\n        Returns\n        -------\n        output_field        : torch.tensor\n                              Propagated output complex field.\n        \"\"\"\n        distance = self.distances[depth_id]\n        if not self.generated_kernels[depth_id, channel_id]:\n            if self.propagator_type == 'forward':\n                H = get_propagation_kernel(\n                                           nu = self.resolution[0] * 2,\n                                           nv = self.resolution[1] * 2,\n                                           dx = self.pixel_pitch,\n                                           wavelength = self.wavelengths[channel_id],\n                                           distance = distance,\n                                           device = self.device,\n                                           propagation_type = self.propagation_type,\n                                           samples = self.aperture_samples,\n                                           scale = self.resolution_factor\n                                          )\n            elif self.propagator_type == 'back and forth':\n                H_forward = get_propagation_kernel(\n                                                   nu = self.resolution[0] * 2,\n                                                   nv = self.resolution[1] * 2,\n                                                   dx = self.pixel_pitch,\n                                                   wavelength = self.wavelengths[channel_id],\n                                                   distance = self.zero_mode_distance,\n                                                   device = self.device,\n                                                   propagation_type = self.propagation_type,\n                                                   samples = self.aperture_samples,\n                                                   scale = self.resolution_factor\n                                                  )\n                distance_back = -(self.zero_mode_distance + self.image_location_offset - distance)\n                H_back = get_propagation_kernel(\n                                                nu = self.resolution[0] * 2,\n                                                nv = self.resolution[1] * 2,\n                                                dx = self.pixel_pitch,\n                                                wavelength = self.wavelengths[channel_id],\n                                                distance = distance_back,\n                                                device = self.device,\n                                                propagation_type = self.propagation_type,\n                                                samples = self.aperture_samples,\n                                 
               scale = self.resolution_factor\n                                               )\n                H = H_forward * H_back\n            self.kernels[depth_id, channel_id] = H\n            self.generated_kernels[depth_id, channel_id] = True\n        else:\n            H = self.kernels[depth_id, channel_id].detach().clone()\n        field_scale = input_field\n        field_scale_padded = zero_pad(field_scale)\n        output_field_padded = custom(field_scale_padded, H, aperture = self.aperture)\n        output_field = crop_center(output_field_padded)\n        return output_field\n\n\n    def reconstruct(self, hologram_phases, amplitude = None, no_grad = True, get_complex = False):\n        \"\"\"\n        Internal function to reconstruct a given hologram.\n\n\n        Parameters\n        ----------\n        hologram_phases            : torch.tensor\n                                     Hologram phases [ch x m x n].\n        amplitude                  : torch.tensor\n                                     Amplitude profiles for each color primary [ch x m x n]\n        no_grad                    : bool\n                                     If set True, uses torch.no_grad in reconstruction.\n        get_complex                : bool\n                                     If set True, reconstructor returns the complex field but not the intensities.\n\n        Returns\n        -------\n        reconstructions            : torch.tensor\n                                     Reconstructed frames.\n        \"\"\"\n        if no_grad:\n            torch.no_grad()\n        if len(hologram_phases.shape) > 3:\n            hologram_phases = hologram_phases.squeeze(0)\n        if get_complex == True:\n            reconstruction_type = torch.complex64\n        else:\n            reconstruction_type = torch.float32\n        reconstructions = torch.zeros(\n                                      self.number_of_frames,\n                                      self.number_of_depth_layers,\n                                      self.number_of_channels,\n                                      self.resolution[0] * self.resolution_factor,\n                                      self.resolution[1] * self.resolution_factor,\n                                      dtype = reconstruction_type,\n                                      device = self.device\n                                     )\n        if isinstance(amplitude, type(None)):\n            amplitude = torch.zeros(\n                                    self.number_of_channels,\n                                    self.resolution[0] * self.resolution_factor,\n                                    self.resolution[1] * self.resolution_factor,\n                                    device = self.device\n                                   )\n            amplitude[:, ::self.resolution_factor, ::self.resolution_factor] = 1.\n        if self.resolution_factor != 1:\n            hologram_phases_scaled = torch.zeros_like(amplitude)\n            hologram_phases_scaled[\n                                   :,\n                                   ::self.resolution_factor,\n                                   ::self.resolution_factor\n                                  ] = hologram_phases\n        else:\n            hologram_phases_scaled = hologram_phases\n        for frame_id in range(self.number_of_frames):\n            for depth_id in range(self.number_of_depth_layers):\n                for channel_id in range(self.number_of_channels):\n                    laser_power = 
self.get_laser_powers()[frame_id][channel_id]\n                    phase = hologram_phases_scaled[frame_id]\n                    hologram = generate_complex_field(\n                                                      laser_power * amplitude[channel_id],\n                                                      phase * self.phase_scale[channel_id]\n                                                     )\n                    reconstruction_field = self.__call__(hologram, channel_id, depth_id)\n                    if get_complex == True:\n                        result = reconstruction_field\n                    else:\n                        result = calculate_amplitude(reconstruction_field) ** 2\n                    reconstructions[\n                                    frame_id,\n                                    depth_id,\n                                    channel_id\n                                   ] = result.detach().clone()\n        return reconstructions\n
"},{"location":"odak/learn_wave/#odak.learn.wave.propagators.propagator.__call__","title":"__call__(input_field, channel_id, depth_id)","text":"

Function that represents the forward model in hologram optimization.

Parameters:

  • input_field \u2013
                  Input complex input field.\n
  • channel_id \u2013
                  Identifying the color primary to be used.\n
  • depth_id \u2013
                  Identifying the depth layer to be used.\n

Returns:

  • output_field ( tensor ) \u2013

    Propagated output complex field.

Source code in odak/learn/wave/propagators.py
def __call__(self, input_field, channel_id, depth_id):\n    \"\"\"\n    Function that represents the forward model in hologram optimization.\n\n    Parameters\n    ----------\n    input_field         : torch.tensor\n                          Input complex input field.\n    channel_id          : int\n                          Identifying the color primary to be used.\n    depth_id            : int\n                          Identifying the depth layer to be used.\n\n    Returns\n    -------\n    output_field        : torch.tensor\n                          Propagated output complex field.\n    \"\"\"\n    distance = self.distances[depth_id]\n    if not self.generated_kernels[depth_id, channel_id]:\n        if self.propagator_type == 'forward':\n            H = get_propagation_kernel(\n                                       nu = self.resolution[0] * 2,\n                                       nv = self.resolution[1] * 2,\n                                       dx = self.pixel_pitch,\n                                       wavelength = self.wavelengths[channel_id],\n                                       distance = distance,\n                                       device = self.device,\n                                       propagation_type = self.propagation_type,\n                                       samples = self.aperture_samples,\n                                       scale = self.resolution_factor\n                                      )\n        elif self.propagator_type == 'back and forth':\n            H_forward = get_propagation_kernel(\n                                               nu = self.resolution[0] * 2,\n                                               nv = self.resolution[1] * 2,\n                                               dx = self.pixel_pitch,\n                                               wavelength = self.wavelengths[channel_id],\n                                               distance = self.zero_mode_distance,\n                                               device = self.device,\n                                               propagation_type = self.propagation_type,\n                                               samples = self.aperture_samples,\n                                               scale = self.resolution_factor\n                                              )\n            distance_back = -(self.zero_mode_distance + self.image_location_offset - distance)\n            H_back = get_propagation_kernel(\n                                            nu = self.resolution[0] * 2,\n                                            nv = self.resolution[1] * 2,\n                                            dx = self.pixel_pitch,\n                                            wavelength = self.wavelengths[channel_id],\n                                            distance = distance_back,\n                                            device = self.device,\n                                            propagation_type = self.propagation_type,\n                                            samples = self.aperture_samples,\n                                            scale = self.resolution_factor\n                                           )\n            H = H_forward * H_back\n        self.kernels[depth_id, channel_id] = H\n        self.generated_kernels[depth_id, channel_id] = True\n    else:\n        H = self.kernels[depth_id, channel_id].detach().clone()\n    field_scale = input_field\n    field_scale_padded = zero_pad(field_scale)\n    output_field_padded = 
custom(field_scale_padded, H, aperture = self.aperture)\n    output_field = crop_center(output_field_padded)\n    return output_field\n
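A minimal sketch of calling the forward model directly, assuming odak is installed and that propagator and generate_complex_field import from the module paths listed in this reference; the small resolution is only to keep the example light.

import torch
from odak.learn.wave.propagators import propagator
from odak.learn.wave.util import generate_complex_field

pr = propagator(resolution = [256, 256], wavelengths = [515e-9], pixel_pitch = 8e-6)
amplitude = torch.ones(256, 256)
phase = torch.rand(256, 256) * 2. * torch.pi
input_field = generate_complex_field(amplitude, phase)
output_field = pr(input_field, channel_id = 0, depth_id = 0)  # invokes __call__ for the first wavelength and depth layer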
"},{"location":"odak/learn_wave/#odak.learn.wave.propagators.propagator.__init__","title":"__init__(resolution=[1920, 1080], wavelengths=[5.15e-07], pixel_pitch=8e-06, resolution_factor=1, number_of_frames=1, number_of_depth_layers=1, volume_depth=0.01, image_location_offset=0.005, propagation_type='Bandlimited Angular Spectrum', propagator_type='back and forth', back_and_forth_distance=0.3, laser_channel_power=None, aperture=None, aperture_size=None, distances=None, aperture_samples=[20, 20, 5, 5], method='conventional', device=torch.device('cpu'))","text":"

Parameters:

  • resolution \u2013
                      Resolution.\n
  • wavelengths \u2013
                      Wavelength of light in meters.\n
  • pixel_pitch \u2013
                      Pixel pitch in meters.\n
  • resolution_factor \u2013
                      Resolution factor for scaled simulations.\n
  • number_of_frames \u2013
                      Number of hologram frames.\n                  Typically, there are three frames, each one for a single color primary.\n
  • number_of_depth_layers \u2013
                       Number of equidistant depth layers within the desired volume. If the `distances` parameter is passed, this value will automatically be set to the length of the provided `distances`.\n
  • volume_depth \u2013
                      Width of the volume along the propagation direction.\n
  • image_location_offset \u2013
                      Center of the volume along the propagation direction.\n
  • propagation_type \u2013
                      Propagation type.\n                  See propagate_beam() and odak.learn.wave.get_propagation_kernel() for more.\n
  • propagator_type \u2013
                      Propagator type.\n                  The options are `back and forth` and `forward` propagators.\n
  • back_and_forth_distance (float, default: 0.3 ) \u2013
                      Zero mode distance for `back and forth` propagator type.\n
  • laser_channel_power \u2013
                      Laser channel powers for given number of frames and number of wavelengths.\n
  • aperture \u2013
                      Aperture at the Fourier plane.\n
  • aperture_size \u2013
                      Aperture width for a circular aperture.\n
  • aperture_samples \u2013
                       When using `Impulse Response Fresnel` propagation, these sample counts along X and Y will be used to represent a rectangular aperture. The first two are for the hologram plane pixel and the last two are for the image plane pixel.\n
  • distances \u2013
                      Propagation distances in meters.\n
  • method \u2013
                      Hologram type conventional or multi-color.\n
  • device \u2013
                      Device to be used for computation. For more see torch.device().\n
Source code in odak/learn/wave/propagators.py
def __init__(\n             self,\n             resolution = [1920, 1080],\n             wavelengths = [515e-9,],\n             pixel_pitch = 8e-6,\n             resolution_factor = 1,\n             number_of_frames = 1,\n             number_of_depth_layers = 1,\n             volume_depth = 1e-2,\n             image_location_offset = 5e-3,\n             propagation_type = 'Bandlimited Angular Spectrum',\n             propagator_type = 'back and forth',\n             back_and_forth_distance = 0.3,\n             laser_channel_power = None,\n             aperture = None,\n             aperture_size = None,\n             distances = None,\n             aperture_samples = [20, 20, 5, 5],\n             method = 'conventional',\n             device = torch.device('cpu')\n            ):\n    \"\"\"\n    Parameters\n    ----------\n    resolution              : list\n                              Resolution.\n    wavelengths             : float\n                              Wavelength of light in meters.\n    pixel_pitch             : float\n                              Pixel pitch in meters.\n    resolution_factor       : int\n                              Resolution factor for scaled simulations.\n    number_of_frames        : int\n                              Number of hologram frames.\n                              Typically, there are three frames, each one for a single color primary.\n    number_of_depth_layers  : int\n                              Equ-distance number of depth layers within the desired volume. If `distances` parameter is passed, this value will be automatically set to the length of the `distances` verson provided.\n    volume_depth            : float\n                              Width of the volume along the propagation direction.\n    image_location_offset   : float\n                              Center of the volume along the propagation direction.\n    propagation_type        : str\n                              Propagation type. \n                              See ropagate_beam() and odak.learn.wave.get_propagation_kernel() for more.\n    propagator_type         : str\n                              Propagator type.\n                              The options are `back and forth` and `forward` propagators.\n    back_and_forth_distance : float\n                              Zero mode distance for `back and forth` propagator type.\n    laser_channel_power     : torch.tensor\n                              Laser channel powers for given number of frames and number of wavelengths.\n    aperture                : torch.tensor\n                              Aperture at the Fourier plane.\n    aperture_size           : float\n                              Aperture width for a circular aperture.\n    aperture_samples        : list\n                              When using `Impulse Response Fresnel` propagation, these sample counts along X and Y will be used to represent a rectangular aperture. First two is for hologram plane pixel and the last two is for image plane pixel.\n    distances               : torch.tensor\n                              Propagation distances in meters.\n    method                  : str\n                              Hologram type conventional or multi-color.\n    device                  : torch.device\n                              Device to be used for computation. 
For more see torch.device().\n    \"\"\"\n    self.device = device\n    self.pixel_pitch = pixel_pitch\n    self.wavelengths = wavelengths\n    self.resolution = resolution\n    self.propagation_type = propagation_type\n    if self.propagation_type != 'Impulse Response Fresnel':\n        resolution_factor = 1\n    self.resolution_factor = resolution_factor\n    self.number_of_frames = number_of_frames\n    self.number_of_depth_layers = number_of_depth_layers\n    self.number_of_channels = len(self.wavelengths)\n    self.volume_depth = volume_depth\n    self.image_location_offset = image_location_offset\n    self.propagator_type = propagator_type\n    self.aperture_samples = aperture_samples\n    self.zero_mode_distance = torch.tensor(back_and_forth_distance, device = device)\n    self.method = method\n    self.aperture = aperture\n    self.init_distances(distances)\n    self.init_kernels()\n    self.init_channel_power(laser_channel_power)\n    self.init_phase_scale()\n    self.set_aperture(aperture, aperture_size)\n
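A hedged end-to-end sketch using the defaults shown in the signature above (import path as documented here); a reduced resolution and three assumed RGB wavelengths keep the example light.

import torch
from odak.learn.wave.propagators import propagator

pr = propagator(
                resolution = [256, 256],
                wavelengths = [639e-9, 515e-9, 473e-9],
                pixel_pitch = 8e-6,
                number_of_frames = 3,
                number_of_depth_layers = 2,
                volume_depth = 1e-2,
                image_location_offset = 5e-3
               )
hologram_phases = torch.rand(3, 256, 256) * 2. * torch.pi  # one phase frame per color primary
reconstructions = pr.reconstruct(hologram_phases)
print(reconstructions.shape)  # expected [3, 2, 3, 256, 256]: frames x depth layers x channels x height x width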
"},{"location":"odak/learn_wave/#odak.learn.wave.propagators.propagator.get_kernels","title":"get_kernels()","text":"

Function to return the kernels used in the light transport.

Returns:

  • kernels ( tensor ) \u2013

    Kernel amplitudes.

Source code in odak/learn/wave/propagators.py
def get_kernels(self):\n    \"\"\"\n    Function to return the kernels used in the light transport.\n\n    Returns\n    -------\n    kernels           : torch.tensor\n                        Kernel amplitudes.\n    \"\"\"\n    h = torch.fft.ifftshift(torch.fft.ifft2(torch.fft.ifftshift(self.kernels)))\n    kernels_amplitude = calculate_amplitude(h)\n    kernels_phase = calculate_phase(h)\n    return kernels_amplitude, kernels_phase\n
"},{"location":"odak/learn_wave/#odak.learn.wave.propagators.propagator.get_laser_powers","title":"get_laser_powers()","text":"

Internal function to get the laser powers.

Returns:

  • laser_power ( tensor ) \u2013

    Laser powers.

Source code in odak/learn/wave/propagators.py
def get_laser_powers(self):\n    \"\"\"\n    Internal function to get the laser powers.\n\n    Returns\n    -------\n    laser_power      : torch.tensor\n                       Laser powers.\n    \"\"\"\n    if self.method == 'conventional':\n        laser_power = self.channel_power\n    if self.method == 'multi-color':\n        laser_power = torch.abs(torch.cos(self.channel_power))\n    return laser_power\n
"},{"location":"odak/learn_wave/#odak.learn.wave.propagators.propagator.init_channel_power","title":"init_channel_power(channel_power)","text":"

Internal function to set the laser channel powers.

Source code in odak/learn/wave/propagators.py
def init_channel_power(self, channel_power):\n    \"\"\"\n    Internal function to set the starting phase of the phase-only hologram.\n    \"\"\"\n    self.channel_power = channel_power\n    if isinstance(self.channel_power, type(None)):\n        self.channel_power = torch.eye(\n                                       self.number_of_frames,\n                                       self.number_of_channels,\n                                       device = self.device,\n                                       requires_grad = False\n                                      )\n
"},{"location":"odak/learn_wave/#odak.learn.wave.propagators.propagator.init_distances","title":"init_distances(distances)","text":"

Internal function to initialize distances.

Parameters:

  • distances \u2013
                      Propagation distances.\n
Source code in odak/learn/wave/propagators.py
def init_distances(self, distances):\n    \"\"\"\n    Internal function to initialize distances.\n\n    Parameters\n    ----------\n    distances               : torch.tensor\n                              Propagation distances.\n    \"\"\"\n    if isinstance(distances, type(None)):\n        self.distances = torch.linspace(-self.volume_depth / 2., self.volume_depth / 2., self.number_of_depth_layers) + self.image_location_offset\n    else:\n        self.distances = torch.as_tensor(distances)\n        self.number_of_depth_layers = self.distances.shape[0]\n    logging.warning('Distances: {}'.format(self.distances))\n
"},{"location":"odak/learn_wave/#odak.learn.wave.propagators.propagator.init_kernels","title":"init_kernels()","text":"

Internal function to initialize kernels.

Source code in odak/learn/wave/propagators.py
def init_kernels(self):\n    \"\"\"\n    Internal function to initialize kernels.\n    \"\"\"\n    self.generated_kernels = torch.zeros(\n                                         self.number_of_depth_layers,\n                                         self.number_of_channels,\n                                         device = self.device\n                                        )\n    self.kernels = torch.zeros(\n                               self.number_of_depth_layers,\n                               self.number_of_channels,\n                               self.resolution[0] * self.resolution_factor * 2,\n                               self.resolution[1] * self.resolution_factor * 2,\n                               dtype = torch.complex64,\n                               device = self.device\n                              )\n
"},{"location":"odak/learn_wave/#odak.learn.wave.propagators.propagator.init_phase_scale","title":"init_phase_scale()","text":"

Internal function to set the phase scale. In some cases, you may want to modify this initialization to scale the phase differently for each color primary, since an SLM is typically configured for a specific central wavelength.

Source code in odak/learn/wave/propagators.py
def init_phase_scale(self):\n    \"\"\"\n    Internal function to set the phase scale.\n    In some cases, you may want to modify this init to ratio phases for different color primaries as an SLM is configured for a specific central wavelength.\n    \"\"\"\n    self.phase_scale = torch.tensor(\n                                    [\n                                     1.,\n                                     1.,\n                                     1.\n                                    ],\n                                    requires_grad = False,\n                                    device = self.device\n                                   )\n
"},{"location":"odak/learn_wave/#odak.learn.wave.propagators.propagator.reconstruct","title":"reconstruct(hologram_phases, amplitude=None, no_grad=True, get_complex=False)","text":"

Internal function to reconstruct a given hologram.

Parameters:

  • hologram_phases \u2013
                         Hologram phases [ch x m x n].\n
  • amplitude \u2013
                         Amplitude profiles for each color primary [ch x m x n]\n
  • no_grad \u2013
                         If set True, uses torch.no_grad in reconstruction.\n
  • get_complex \u2013
                         If set True, reconstructor returns the complex field but not the intensities.\n

Returns:

  • reconstructions ( tensor ) \u2013

    Reconstructed frames.

Source code in odak/learn/wave/propagators.py
def reconstruct(self, hologram_phases, amplitude = None, no_grad = True, get_complex = False):\n    \"\"\"\n    Internal function to reconstruct a given hologram.\n\n\n    Parameters\n    ----------\n    hologram_phases            : torch.tensor\n                                 Hologram phases [ch x m x n].\n    amplitude                  : torch.tensor\n                                 Amplitude profiles for each color primary [ch x m x n]\n    no_grad                    : bool\n                                 If set True, uses torch.no_grad in reconstruction.\n    get_complex                : bool\n                                 If set True, reconstructor returns the complex field but not the intensities.\n\n    Returns\n    -------\n    reconstructions            : torch.tensor\n                                 Reconstructed frames.\n    \"\"\"\n    if no_grad:\n        torch.no_grad()\n    if len(hologram_phases.shape) > 3:\n        hologram_phases = hologram_phases.squeeze(0)\n    if get_complex == True:\n        reconstruction_type = torch.complex64\n    else:\n        reconstruction_type = torch.float32\n    reconstructions = torch.zeros(\n                                  self.number_of_frames,\n                                  self.number_of_depth_layers,\n                                  self.number_of_channels,\n                                  self.resolution[0] * self.resolution_factor,\n                                  self.resolution[1] * self.resolution_factor,\n                                  dtype = reconstruction_type,\n                                  device = self.device\n                                 )\n    if isinstance(amplitude, type(None)):\n        amplitude = torch.zeros(\n                                self.number_of_channels,\n                                self.resolution[0] * self.resolution_factor,\n                                self.resolution[1] * self.resolution_factor,\n                                device = self.device\n                               )\n        amplitude[:, ::self.resolution_factor, ::self.resolution_factor] = 1.\n    if self.resolution_factor != 1:\n        hologram_phases_scaled = torch.zeros_like(amplitude)\n        hologram_phases_scaled[\n                               :,\n                               ::self.resolution_factor,\n                               ::self.resolution_factor\n                              ] = hologram_phases\n    else:\n        hologram_phases_scaled = hologram_phases\n    for frame_id in range(self.number_of_frames):\n        for depth_id in range(self.number_of_depth_layers):\n            for channel_id in range(self.number_of_channels):\n                laser_power = self.get_laser_powers()[frame_id][channel_id]\n                phase = hologram_phases_scaled[frame_id]\n                hologram = generate_complex_field(\n                                                  laser_power * amplitude[channel_id],\n                                                  phase * self.phase_scale[channel_id]\n                                                 )\n                reconstruction_field = self.__call__(hologram, channel_id, depth_id)\n                if get_complex == True:\n                    result = reconstruction_field\n                else:\n                    result = calculate_amplitude(reconstruction_field) ** 2\n                reconstructions[\n                                frame_id,\n                                depth_id,\n                                channel_id\n    
                           ] = result.detach().clone()\n    return reconstructions\n
"},{"location":"odak/learn_wave/#odak.learn.wave.propagators.propagator.set_aperture","title":"set_aperture(aperture=None, aperture_size=None)","text":"

Set aperture in the Fourier plane.

Parameters:

  • aperture \u2013
              Aperture at the original resolution of a hologram.\n          If aperture is provided as None, it will assign a circular aperture at the size of the short edge (width or height).\n
  • aperture_size \u2013
              If no aperture is provided, this will determine the size of the circular aperture.\n
Source code in odak/learn/wave/propagators.py
def set_aperture(self, aperture = None, aperture_size = None):\n    \"\"\"\n    Set aperture in the Fourier plane.\n\n\n    Parameters\n    ----------\n    aperture        : torch.tensor\n                      Aperture at the original resolution of a hologram.\n                      If aperture is provided as None, it will assign a circular aperture at the size of the short edge (width or height).\n    aperture_size   : int\n                      If no aperture is provided, this will determine the size of the circular aperture.\n    \"\"\"\n    if isinstance(aperture, type(None)):\n        if isinstance(aperture_size, type(None)):\n            aperture_size = torch.max(\n                                      torch.tensor([\n                                                    self.resolution[0] * self.resolution_factor, \n                                                    self.resolution[1] * self.resolution_factor\n                                                   ])\n                                     )\n        self.aperture = circular_binary_mask(\n                                             self.resolution[0] * self.resolution_factor * 2,\n                                             self.resolution[1] * self.resolution_factor * 2,\n                                             aperture_size,\n                                            ).to(self.device) * 1.\n    else:\n        self.aperture = zero_pad(aperture).to(self.device) * 1.\n
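A small sketch of restricting the Fourier-plane aperture after construction; the aperture_size value is assumed to be in pixels at the scaled hologram resolution.

from odak.learn.wave.propagators import propagator

pr = propagator(resolution = [256, 256], wavelengths = [515e-9], pixel_pitch = 8e-6)
pr.set_aperture(aperture_size = 128)  # replaces the default circular aperture with a tighter one (size in pixels, assumed)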
"},{"location":"odak/learn_wave/#odak.learn.wave.propagators.propagator.set_laser_powers","title":"set_laser_powers(laser_power)","text":"

Internal function to set the laser powers.

Parameters:

  • laser_power \u2013
               Laser powers.\n
Source code in odak/learn/wave/propagators.py
def set_laser_powers(self, laser_power):\n    \"\"\"\n    Internal function to set the laser powers.\n\n    Parameters\n    -------\n    laser_power      : torch.tensor\n                       Laser powers.\n    \"\"\"\n    self.channel_power = laser_power\n
"},{"location":"odak/learn_wave/#odak.learn.wave.util.calculate_amplitude","title":"calculate_amplitude(field)","text":"

Definition to calculate amplitude of a single or multiple given electric field(s).

Parameters:

  • field \u2013
           Electric fields or an electric field.\n

Returns:

  • amplitude ( float ) \u2013

    Amplitude or amplitudes of electric field(s).

Source code in odak/learn/wave/util.py
def calculate_amplitude(field):\n    \"\"\" \n    Definition to calculate amplitude of a single or multiple given electric field(s).\n\n    Parameters\n    ----------\n    field        : torch.cfloat\n                   Electric fields or an electric field.\n\n    Returns\n    -------\n    amplitude    : torch.float\n                   Amplitude or amplitudes of electric field(s).\n    \"\"\"\n    amplitude = torch.abs(field)\n    return amplitude\n
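A quick sanity check, assuming the import path listed in this reference:

import torch
from odak.learn.wave.util import calculate_amplitude

field = torch.tensor([3. + 4.j, 0. + 1.j])
print(calculate_amplitude(field))  # tensor([5., 1.])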
"},{"location":"odak/learn_wave/#odak.learn.wave.util.calculate_phase","title":"calculate_phase(field, deg=False)","text":"

Definition to calculate phase of a single or multiple given electric field(s).

Parameters:

  • field \u2013
           Electric fields or an electric field.\n
  • deg \u2013
           If set True, the angles will be returned in degrees.\n

Returns:

  • phase ( float ) \u2013

    Phase or phases of electric field(s) in radians.

Source code in odak/learn/wave/util.py
def calculate_phase(field, deg = False):\n    \"\"\" \n    Definition to calculate phase of a single or multiple given electric field(s).\n\n    Parameters\n    ----------\n    field        : torch.cfloat\n                   Electric fields or an electric field.\n    deg          : bool\n                   If set True, the angles will be returned in degrees.\n\n    Returns\n    -------\n    phase        : torch.float\n                   Phase or phases of electric field(s) in radians.\n    \"\"\"\n    phase = field.imag.atan2(field.real)\n    if deg:\n        phase *= 180. / np.pi\n    return phase\n
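A quick sanity check, assuming the import path listed in this reference:

import torch
from odak.learn.wave.util import calculate_phase

field = torch.tensor([1. + 0.j, 0. + 1.j])
print(calculate_phase(field))              # tensor([0.0000, 1.5708]) in radians
print(calculate_phase(field, deg = True))  # tensor([ 0., 90.])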
"},{"location":"odak/learn_wave/#odak.learn.wave.util.generate_complex_field","title":"generate_complex_field(amplitude, phase)","text":"

Definition to generate a complex field with a given amplitude and phase.

Parameters:

  • amplitude \u2013
                Amplitude of the field.\n            The expected size is [m x n] or [1 x m x n].\n
  • phase \u2013
                Phase of the field.\n            The expected size is [m x n] or [1 x m x n].\n

Returns:

  • field ( tensor ) \u2013

    Complex field. Depending on the input, the expected size is [m x n] or [1 x m x n].

Source code in odak/learn/wave/util.py
def generate_complex_field(amplitude, phase):\n    \"\"\"\n    Definition to generate a complex field with a given amplitude and phase.\n\n    Parameters\n    ----------\n    amplitude         : torch.tensor\n                        Amplitude of the field.\n                        The expected size is [m x n] or [1 x m x n].\n    phase             : torch.tensor\n                        Phase of the field.\n                        The expected size is [m x n] or [1 x m x n].\n\n    Returns\n    -------\n    field             : ndarray\n                        Complex field.\n                        Depending on the input, the expected size is [m x n] or [1 x m x n].\n    \"\"\"\n    field = amplitude * torch.cos(phase) + 1j * amplitude * torch.sin(phase)\n    return field\n
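A minimal round-trip sketch combining this helper with calculate_amplitude and calculate_phase, assuming the import paths listed in this reference:

import torch
from odak.learn.wave.util import generate_complex_field, calculate_amplitude, calculate_phase

amplitude = torch.ones(2, 2)
phase = torch.full((2, 2), torch.pi / 2.)
field = generate_complex_field(amplitude, phase)
print(calculate_amplitude(field))  # all ones
print(calculate_phase(field))      # all approximately pi / 2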
"},{"location":"odak/learn_wave/#odak.learn.wave.util.set_amplitude","title":"set_amplitude(field, amplitude)","text":"

Definition to keep phase as is and change the amplitude of a given field.

Parameters:

  • field \u2013
           Complex field.\n
  • amplitude \u2013
           Amplitudes.\n

Returns:

  • new_field ( cfloat ) \u2013

    Complex field.

Source code in odak/learn/wave/util.py
def set_amplitude(field, amplitude):\n    \"\"\"\n    Definition to keep phase as is and change the amplitude of a given field.\n\n    Parameters\n    ----------\n    field        : torch.cfloat\n                   Complex field.\n    amplitude    : torch.cfloat or torch.float\n                   Amplitudes.\n\n    Returns\n    -------\n    new_field    : torch.cfloat\n                   Complex field.\n    \"\"\"\n    amplitude = calculate_amplitude(amplitude)\n    phase = calculate_phase(field)\n    new_field = amplitude * torch.cos(phase) + 1j * amplitude * torch.sin(phase)\n    return new_field\n
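A minimal sketch showing that the phase is preserved while the amplitude is replaced, assuming the import paths listed in this reference:

import torch
from odak.learn.wave.util import set_amplitude, calculate_amplitude, calculate_phase

field = torch.tensor([0. + 2.j])                        # amplitude 2, phase pi / 2
new_field = set_amplitude(field, torch.tensor([5.]))
print(calculate_amplitude(new_field))                   # tensor([5.])
print(calculate_phase(new_field))                       # approximately tensor([1.5708]), phase preserved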
"},{"location":"odak/learn_wave/#odak.learn.wave.util.wavenumber","title":"wavenumber(wavelength)","text":"

Definition for calculating the wavenumber of a plane wave.

Parameters:

  • wavelength \u2013
           Wavelength of a wave in mm.\n

Returns:

  • k ( float ) \u2013

    Wave number for a given wavelength.

Source code in odak/learn/wave/util.py
def wavenumber(wavelength):\n    \"\"\"\n    Definition for calculating the wavenumber of a plane wave.\n\n    Parameters\n    ----------\n    wavelength   : float\n                   Wavelength of a wave in mm.\n\n    Returns\n    -------\n    k            : float\n                   Wave number for a given wavelength.\n    \"\"\"\n    k = 2 * np.pi / wavelength\n    return k\n
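A quick sketch of the formula k = 2 pi / wavelength. The formula itself is unit-agnostic, so the returned wavenumber is in radians per whatever length unit the wavelength is given in (the docstring above says mm, while the rest of this reference uses meters).

from odak.learn.wave.util import wavenumber

k = wavenumber(515e-9)  # 2 * pi / wavelength
print(k)                # about 1.22e7 radians per meter for a 515 nm wavelength given in meters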
"},{"location":"odak/raytracing/","title":"odak.raytracing","text":"

odak.raytracing

Provides necessary definitions for geometric optics. See \"General Ray-Tracing Procedure\" by G.H. Spencer and M.V.R.K. Murty for the theoretical explanation.

"},{"location":"odak/raytracing/#odak.raytracing.bring_plane_to_origin","title":"bring_plane_to_origin(point, plane, shape=[10.0, 10.0], center=[0.0, 0.0, 0.0], angles=[0.0, 0.0, 0.0], mode='XYZ')","text":"

Definition to bring points back to the reference origin with respect to a plane.

Parameters:

  • point \u2013
                 Point(s) to be tested.\n
  • shape \u2013
                 Dimensions of the rectangle along X and Y axes.\n
  • center \u2013
                 Center of the rectangle.\n
  • angles \u2013
                 Rotation angle of the rectangle.\n
  • mode \u2013
                 Rotation mode of the rectangle, for more see odak.tools.rotate_point and odak.tools.rotate_points.\n

Returns:

  • transformed_points ( ndarray ) \u2013

    Point(s) that are brought back to reference origin with respect to given plane.

Source code in odak/raytracing/primitives.py
def bring_plane_to_origin(point, plane, shape=[10., 10.], center=[0., 0., 0.], angles=[0., 0., 0.], mode='XYZ'):\n    \"\"\"\n    Definition to bring points back to reference origin with respect to a plane.\n\n    Parameters\n    ----------\n    point              : ndarray\n                         Point(s) to be tested.\n    shape              : list\n                         Dimensions of the rectangle along X and Y axes.\n    center             : list\n                         Center of the rectangle.\n    angles             : list\n                         Rotation angle of the rectangle.\n    mode               : str\n                         Rotation mode of the rectangle, for more see odak.tools.rotate_point and odak.tools.rotate_points.\n\n    Returns\n    ----------\n    transformed_points : ndarray\n                         Point(s) that are brought back to reference origin with respect to given plane.\n    \"\"\"\n    if point.shape[0] == 3:\n        point = point.reshape((1, 3))\n    reverse_mode = mode[::-1]\n    angles = [-angles[0], -angles[1], -angles[2]]\n    center = np.asarray(center).reshape((1, 3))\n    transformed_points = point-center\n    transformed_points = rotate_points(\n        transformed_points,\n        angles=angles,\n        mode=reverse_mode,\n    )\n    if transformed_points.shape[0] == 1:\n        transformed_points = transformed_points.reshape((3,))\n    return transformed_points\n
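A hedged sketch with no rotation, so the result is simply the point expressed relative to the plane center; note that the plane argument is not used in the source shown above, so None is passed here.

import numpy as np
from odak.raytracing import bring_plane_to_origin

point = np.array([1., 2., 10.])
transformed = bring_plane_to_origin(point, None, center = [0., 0., 10.], angles = [0., 0., 0.])
print(transformed)  # [1. 2. 0.]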
"},{"location":"odak/raytracing/#odak.raytracing.calculate_intersection_of_two_rays","title":"calculate_intersection_of_two_rays(ray0, ray1)","text":"

Definition to calculate the intersection of two rays.

Parameters:

  • ray0 \u2013
         A ray.\n
  • ray1 \u2013
         A ray.\n

Returns:

  • point ( ndarray ) \u2013

    Point in X,Y,Z.

  • distances ( ndarray ) \u2013

    Distances.

Source code in odak/raytracing/ray.py
def calculate_intersection_of_two_rays(ray0, ray1):\n    \"\"\"\n    Definition to calculate the intersection of two rays.\n\n    Parameters\n    ----------\n    ray0       : ndarray\n                 A ray.\n    ray1       : ndarray\n                 A ray.\n\n    Returns\n    ----------\n    point      : ndarray\n                 Point in X,Y,Z.\n    distances  : ndarray\n                 Distances.\n    \"\"\"\n    A = np.array([\n        [float(ray0[1][0]), float(ray1[1][0])],\n        [float(ray0[1][1]), float(ray1[1][1])],\n        [float(ray0[1][2]), float(ray1[1][2])]\n    ])\n    B = np.array([\n        ray0[0][0]-ray1[0][0],\n        ray0[0][1]-ray1[0][1],\n        ray0[0][2]-ray1[0][2]\n    ])\n    distances = np.linalg.lstsq(A, B, rcond=None)[0]\n    if np.allclose(np.dot(A, distances), B) == False:\n        distances = np.array([0, 0])\n    distances = distances[np.argsort(-distances)]\n    point = propagate_a_ray(ray0, distances[0])[0]\n    return point, distances\n
"},{"location":"odak/raytracing/#odak.raytracing.center_of_triangle","title":"center_of_triangle(triangle)","text":"

Definition to calculate the center of a triangle.

Parameters:

  • triangle \u2013
             An array that contains three points defining a triangle (M x 3). It can also process many triangles in parallel (N x M x 3).\n
Source code in odak/raytracing/primitives.py
def center_of_triangle(triangle):\n    \"\"\"\n    Definition to calculate center of a triangle.\n\n    Parameters\n    ----------\n    triangle      : ndarray\n                    An array that contains three points defining a triangle (Mx3). It can also parallel process many triangles (NxMx3).\n    \"\"\"\n    if len(triangle.shape) == 2:\n        triangle = triangle.reshape((1, 3, 3))\n    center = np.mean(triangle, axis=1)\n    return center\n
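A quick sketch; the center is the mean of the three vertices.

import numpy as np
from odak.raytracing import center_of_triangle

triangle = np.array([[0., 0., 0.], [3., 0., 0.], [0., 3., 0.]])
print(center_of_triangle(triangle))  # [[1. 1. 0.]]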
"},{"location":"odak/raytracing/#odak.raytracing.closest_point_to_a_ray","title":"closest_point_to_a_ray(point, ray)","text":"

Definition to calculate the point on a ray that is closest to a given point.

Parameters:

  • point \u2013
            Given point in X,Y,Z.\n
  • ray \u2013
            Given ray.\n

Returns:

  • closest_point ( ndarray ) \u2013

    Calculated closest point.

Source code in odak/tools/vector.py
def closest_point_to_a_ray(point, ray):\n    \"\"\"\n    Definition to calculate the point on a ray that is closest to given point.\n\n    Parameters\n    ----------\n    point         : list\n                    Given point in X,Y,Z.\n    ray           : ndarray\n                    Given ray.\n\n    Returns\n    ---------\n    closest_point : ndarray\n                    Calculated closest point.\n    \"\"\"\n    from odak.raytracing import propagate_a_ray\n    if len(ray.shape) == 2:\n        ray = ray.reshape((1, 2, 3))\n    p0 = ray[:, 0]\n    p1 = propagate_a_ray(ray, 1.)\n    if len(p1.shape) == 2:\n        p1 = p1.reshape((1, 2, 3))\n    p1 = p1[:, 0]\n    p1 = p1.reshape(3)\n    p0 = p0.reshape(3)\n    point = point.reshape(3)\n    closest_distance = -np.dot((p0-point), (p1-p0))/np.sum((p1-p0)**2)\n    closest_point = propagate_a_ray(ray, closest_distance)[0]\n    return closest_point\n
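A hedged sketch: for a ray along +Z from the origin, the closest point on the ray to (1, 0, 5) should be near (0, 0, 5). The function is listed here under odak.raytracing, but its source is odak/tools/vector.py, so the import below assumes it is re-exported from odak.tools.

import numpy as np
from odak.tools import closest_point_to_a_ray  # assumed re-export; source lives in odak/tools/vector.py
from odak.raytracing import create_ray_from_two_points

ray = create_ray_from_two_points([0., 0., 0.], [0., 0., 1.])  # ray along +Z
point = np.array([1., 0., 5.])
print(closest_point_to_a_ray(point, ray))  # expected close to [0. 0. 5.]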
"},{"location":"odak/raytracing/#odak.raytracing.create_ray","title":"create_ray(x0y0z0, abg)","text":"

Definition to create a ray.

Parameters:

  • x0y0z0 \u2013
           List that contains X,Y and Z start locations of a ray.\n
  • abg \u2013
           List that contains angles in degrees with respect to the X, Y and Z axes.\n

Returns:

  • ray ( ndarray ) \u2013

    Array that contains starting points and cosines of a created ray.

Source code in odak/raytracing/ray.py
def create_ray(x0y0z0, abg):\n    \"\"\"\n    Definition to create a ray.\n\n    Parameters\n    ----------\n    x0y0z0       : list\n                   List that contains X,Y and Z start locations of a ray.\n    abg          : list\n                   List that contaings angles in degrees with respect to the X,Y and Z axes.\n\n    Returns\n    ----------\n    ray          : ndarray\n                   Array that contains starting points and cosines of a created ray.\n    \"\"\"\n    # Due to Python 2 -> Python 3.\n    x0, y0, z0 = x0y0z0\n    alpha, beta, gamma = abg\n    # Create a vector with the given points and angles in each direction\n    point = np.array([x0, y0, z0], dtype=np.float64)\n    alpha = np.cos(np.radians(alpha))\n    beta = np.cos(np.radians(beta))\n    gamma = np.cos(np.radians(gamma))\n    # Cosines vector.\n    cosines = np.array([alpha, beta, gamma], dtype=np.float64)\n    ray = np.array([point, cosines], dtype=np.float64)\n    return ray\n
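Example use (a minimal sketch; angles are in degrees, as documented):

from odak.raytracing import create_ray

# Direction angles of 90, 90 and 0 degrees with respect to the X, Y and Z axes
# give direction cosines close to [0., 0., 1.], i.e. a ray along +Z.
ray = create_ray([0., 0., 0.], [90., 90., 0.])
print(ray)  # first row: start point, second row: direction cosines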
"},{"location":"odak/raytracing/#odak.raytracing.create_ray_from_angles","title":"create_ray_from_angles(point, angles, mode='XYZ')","text":"

Definition to create a ray from a point and angles.

Parameters:

  • point \u2013
         Point in X,Y and Z.\n
  • angles \u2013
         Angles with X,Y,Z axes in degrees. All zeros point Z axis.\n
  • mode \u2013
         Rotation mode determines the ordering of the rotations about each axis. There are XYZ, YXZ, ZXY and ZYX modes.\n

Returns:

  • ray ( ndarray ) \u2013

    Created ray.

Source code in odak/raytracing/ray.py
def create_ray_from_angles(point, angles, mode='XYZ'):\n    \"\"\"\n    Definition to create a ray from a point and angles.\n\n    Parameters\n    ----------\n    point      : ndarray\n                 Point in X,Y and Z.\n    angles     : ndarray\n                 Angles with X,Y,Z axes in degrees. All zeros point Z axis.\n    mode       : str\n                 Rotation mode determines ordering of the rotations at each axis. There are XYZ,YXZ    ,ZXY and ZYX modes.\n\n    Returns\n    ----------\n    ray        : ndarray\n                 Created ray.\n    \"\"\"\n    if len(point.shape) == 1:\n        point = point.reshape((1, 3))\n    new_point = np.zeros(point.shape)\n    new_point[:, 2] += 5.\n    new_point = rotate_points(new_point, angles, mode=mode, offset=point[:, 0])\n    ray = create_ray_from_two_points(point, new_point)\n    if ray.shape[0] == 1:\n        ray = ray.reshape((2, 3))\n    return ray\n
"},{"location":"odak/raytracing/#odak.raytracing.create_ray_from_two_points","title":"create_ray_from_two_points(x0y0z0, x1y1z1)","text":"

Definition to create a ray from two given points. Note that both inputs must match in shape.

Parameters:

  • x0y0z0 \u2013
           List that contains X,Y and Z start locations of a ray (3). It can also be a list of points as well (mx3). This is the starting point.\n
  • x1y1z1 \u2013
           List that contains X,Y and Z ending locations of a ray (3). It can also be a list of points as well (mx3). This is the end point.\n

Returns:

  • ray ( ndarray ) \u2013

    Array that contains starting points and cosines of a created ray.

Source code in odak/raytracing/ray.py
def create_ray_from_two_points(x0y0z0, x1y1z1):\n    \"\"\"\n    Definition to create a ray from two given points. Note that both inputs must match in shape.\n\n    Parameters\n    ----------\n    x0y0z0       : list\n                   List that contains X,Y and Z start locations of a ray (3). It can also be a list of points as well (mx3). This is the starting point.\n    x1y1z1       : list\n                   List that contains X,Y and Z ending locations of a ray (3). It can also be a list of points as well (mx3). This is the end point.\n\n    Returns\n    ----------\n    ray          : ndarray\n                   Array that contains starting points and cosines of a created ray.\n    \"\"\"\n    x0y0z0 = np.asarray(x0y0z0, dtype=np.float64)\n    x1y1z1 = np.asarray(x1y1z1, dtype=np.float64)\n    if len(x0y0z0.shape) == 1:\n        x0y0z0 = x0y0z0.reshape((1, 3))\n    if len(x1y1z1.shape) == 1:\n        x1y1z1 = x1y1z1.reshape((1, 3))\n    xdiff = x1y1z1[:, 0] - x0y0z0[:, 0]\n    ydiff = x1y1z1[:, 1] - x0y0z0[:, 1]\n    zdiff = x1y1z1[:, 2] - x0y0z0[:, 2]\n    s = np.sqrt(xdiff ** 2 + ydiff ** 2 + zdiff ** 2)\n    s[s == 0] = np.nan\n    cosines = np.zeros((xdiff.shape[0], 3))\n    cosines[:, 0] = xdiff/s\n    cosines[:, 1] = ydiff/s\n    cosines[:, 2] = zdiff/s\n    ray = np.zeros((xdiff.shape[0], 2, 3), dtype=np.float64)\n    ray[:, 0] = x0y0z0\n    ray[:, 1] = cosines\n    if ray.shape[0] == 1:\n        ray = ray.reshape((2, 3))\n    return ray\n
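Example use, showing both the single-ray and the batched form (illustrative values):

import numpy as np
from odak.raytracing import create_ray_from_two_points

# A single ray from the origin towards [0., 0., 10.].
ray = create_ray_from_two_points([0., 0., 0.], [0., 0., 10.])
print(ray.shape)  # (2, 3): start point and unit direction cosines

# Batched form: m start points paired with m end points return an (m, 2, 3) array.
starts = np.zeros((3, 3))
ends = np.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])
rays = create_ray_from_two_points(starts, ends)
print(rays.shape)  # (3, 2, 3)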
"},{"location":"odak/raytracing/#odak.raytracing.cylinder_function","title":"cylinder_function(point, cylinder)","text":"

Definition of a cylinder function. Evaluates a point against a cylinder. Inspired by https://mathworld.wolfram.com/Point-LineDistance3-Dimensional.html

Parameters:

  • cylinder \u2013
         Cylinder parameters, XYZ center and radius.\n
  • point \u2013
         Point in XYZ.\n
Returns:

  • result ( float ) \u2013

    Result of the evaluation. Zero if the point is on the cylinder surface.

Source code in odak/raytracing/primitives.py
def cylinder_function(point, cylinder):\n    \"\"\"\n    Definition of a cylinder function. Evaluate a point against a cylinder function. Inspired from https://mathworld.wolfram.com/Point-LineDistance3-Dimensional.html\n\n    Parameters\n    ----------\n    cylinder   : ndarray\n                 Cylinder parameters, XYZ center and radius.\n    point      : ndarray\n                 Point in XYZ.\n\n    Return\n    ----------\n    result     : float\n                 Result of the evaluation. Zero if point is on sphere.\n    \"\"\"\n    point = np.asarray(point)\n    if len(point.shape) == 1:\n        point = point.reshape((1, 3))\n    distance = point_to_ray_distance(\n        point,\n        np.array([cylinder[0], cylinder[1], cylinder[2]], dtype=np.float64),\n        np.array([cylinder[4], cylinder[5], cylinder[6]], dtype=np.float64)\n    )\n    r = cylinder[3]\n    result = distance - r ** 2\n    return result\n
"},{"location":"odak/raytracing/#odak.raytracing.define_circle","title":"define_circle(center, radius, angles)","text":"

Definition to describe a circle in a single variable packed form.

Parameters:

  • center \u2013
      Center of a circle to be defined.\n
  • radius \u2013
      Radius of a circle to be defined.\n
  • angles \u2013
      Angular tilt of a circle.\n

Returns:

  • circle ( list ) \u2013

    Single variable packed form.

Source code in odak/raytracing/primitives.py
def define_circle(center, radius, angles):\n    \"\"\"\n    Definition to describe a circle in a single variable packed form.\n\n    Parameters\n    ----------\n    center  : float\n              Center of a circle to be defined.\n    radius  : float\n              Radius of a circle to be defined.\n    angles  : float\n              Angular tilt of a circle.\n\n    Returns\n    ----------\n    circle  : list\n              Single variable packed form.\n    \"\"\"\n    points = define_plane(center, angles=angles)\n    circle = [\n        points,\n        center,\n        radius\n    ]\n    return circle\n
"},{"location":"odak/raytracing/#odak.raytracing.define_cylinder","title":"define_cylinder(center, radius, rotation=[0.0, 0.0, 0.0])","text":"

Definition to define a cylinder.

Parameters:

  • center \u2013
         Center of a cylinder in X,Y,Z.\n
  • radius \u2013
         Radius of a cylinder along X axis.\n
  • rotation \u2013
         Direction angles in degrees for the orientation of a cylinder.\n

Returns:

  • cylinder ( ndarray ) \u2013

    Single variable packed form.

Source code in odak/raytracing/primitives.py
def define_cylinder(center, radius, rotation=[0., 0., 0.]):\n    \"\"\"\n    Definition to define a cylinder\n\n    Parameters\n    ----------\n    center     : ndarray\n                 Center of a cylinder in X,Y,Z.\n    radius     : float\n                 Radius of a cylinder along X axis.\n    rotation   : list\n                 Direction angles in degrees for the orientation of a cylinder.\n\n    Returns\n    ----------\n    cylinder   : ndarray\n                 Single variable packed form.\n    \"\"\"\n    cylinder_ray = create_ray_from_angles(\n        np.asarray(center), np.asarray(rotation))\n    cylinder = np.array(\n        [\n            center[0],\n            center[1],\n            center[2],\n            radius,\n            center[0]+cylinder_ray[1, 0],\n            center[1]+cylinder_ray[1, 1],\n            center[2]+cylinder_ray[1, 2]\n        ],\n        dtype=np.float64\n    )\n    return cylinder\n
"},{"location":"odak/raytracing/#odak.raytracing.define_plane","title":"define_plane(point, angles=[0.0, 0.0, 0.0])","text":"

Definition to define a plane from a center point and rotation angles.

Parameters:

  • point \u2013
           A point that is at the center of a plane.\n
  • angles \u2013
           Rotation angles in degrees.\n

Returns:

  • plane ( ndarray ) \u2013

    Points defining plane.

Source code in odak/raytracing/primitives.py
def define_plane(point, angles=[0., 0., 0.]):\n    \"\"\" \n    Definition to generate a rotation matrix along X axis.\n\n    Parameters\n    ----------\n    point        : ndarray\n                   A point that is at the center of a plane.\n    angles       : list\n                   Rotation angles in degrees.\n\n    Returns\n    ----------\n    plane        : ndarray\n                   Points defining plane.\n    \"\"\"\n    plane = np.array([\n        [10., 10., 0.],\n        [0., 10., 0.],\n        [0.,  0., 0.]\n    ], dtype=np.float64)\n    point = np.asarray(point)\n    for i in range(0, plane.shape[0]):\n        plane[i], _, _, _ = rotate_point(plane[i], angles=angles)\n        plane[i] = plane[i]+point\n    return plane\n
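Example use (illustrative values):

from odak.raytracing import define_plane

# A plane centred at z = 10 with no tilt; the result is three points lying on that plane.
plane_points = define_plane([0., 0., 10.], angles = [0., 0., 0.])
print(plane_points)  # three points, all with z = 10 for zero tilt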
"},{"location":"odak/raytracing/#odak.raytracing.define_sphere","title":"define_sphere(center, radius)","text":"

Definition to define a sphere.

Parameters:

  • center \u2013
         Center of a sphere in X,Y,Z.\n
  • radius \u2013
         Radius of a sphere.\n

Returns:

  • sphere ( ndarray ) \u2013

    Single variable packed form.

Source code in odak/raytracing/primitives.py
def define_sphere(center, radius):\n    \"\"\"\n    Definition to define a sphere.\n\n    Parameters\n    ----------\n    center     : ndarray\n                 Center of a sphere in X,Y,Z.\n    radius     : float\n                 Radius of a sphere.\n\n    Returns\n    ----------\n    sphere     : ndarray\n                 Single variable packed form.\n    \"\"\"\n    sphere = np.array(\n        [center[0], center[1], center[2], radius], dtype=np.float64)\n    return sphere\n
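Example use (illustrative values):

from odak.raytracing import define_sphere

sphere = define_sphere([0., 0., 10.], 3.)
print(sphere)  # packed as [center_x, center_y, center_z, radius] -> [0., 0., 10., 3.]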
"},{"location":"odak/raytracing/#odak.raytracing.distance_between_two_points","title":"distance_between_two_points(point1, point2)","text":"

Definition to calculate distance between two given points.

Parameters:

  • point1 \u2013
          First point in X,Y,Z.\n
  • point2 \u2013
          Second point in X,Y,Z.\n

Returns:

  • distance ( float ) \u2013

    Distance in between given two points.

Source code in odak/tools/vector.py
def distance_between_two_points(point1, point2):\n    \"\"\"\n    Definition to calculate distance between two given points.\n\n    Parameters\n    ----------\n    point1      : list\n                  First point in X,Y,Z.\n    point2      : list\n                  Second point in X,Y,Z.\n\n    Returns\n    ----------\n    distance    : float\n                  Distance in between given two points.\n    \"\"\"\n    point1 = np.asarray(point1)\n    point2 = np.asarray(point2)\n    if len(point1.shape) == 1 and len(point2.shape) == 1:\n        distance = np.sqrt(np.sum((point1-point2)**2))\n    elif len(point1.shape) == 2 or len(point2.shape) == 2:\n        distance = np.sqrt(np.sum((point1-point2)**2, axis=1))\n    return distance\n
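Example use, covering single points and batched points (illustrative values):

import numpy as np
from odak.raytracing import distance_between_two_points

print(distance_between_two_points([0., 0., 0.], [3., 4., 0.]))  # 5.0

# Batched form: row-wise distances between two (m, 3) arrays.
p0 = np.zeros((2, 3))
p1 = np.array([[1., 0., 0.], [0., 0., 2.]])
print(distance_between_two_points(p0, p1))  # [1., 2.]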
"},{"location":"odak/raytracing/#odak.raytracing.find_nearest_points","title":"find_nearest_points(ray0, ray1)","text":"

Find the nearest points on given rays with respect to the other ray.

Parameters:

  • ray0 \u2013
         A ray.\n
  • ray1 \u2013
         A ray.\n

Returns:

  • c0 ( ndarray ) \u2013

    Closest point on ray0.

  • c1 ( ndarray ) \u2013

    Closest point on ray1.

Source code in odak/raytracing/ray.py
def find_nearest_points(ray0, ray1):\n    \"\"\"\n    Find the nearest points on given rays with respect to the other ray.\n\n    Parameters\n    ----------\n    ray0       : ndarray\n                 A ray.\n    ray1       : ndarray\n                 A ray.\n\n    Returns\n    ----------\n    c0         : ndarray\n                 Closest point on ray0.\n    c1         : ndarray\n                 Closest point on ray1.\n    \"\"\"\n    p0 = ray0[0].reshape(3,)\n    d0 = ray0[1].reshape(3,)\n    p1 = ray1[0].reshape(3,)\n    d1 = ray1[1].reshape(3,)\n    n = np.cross(d0, d1)\n    if np.all(n) == 0:\n        point, distances = calculate_intersection_of_two_rays(ray0, ray1)\n        c0 = c1 = point\n    else:\n        n0 = np.cross(d0, n)\n        n1 = np.cross(d1, n)\n        c0 = p0+(np.dot((p1-p0), n1)/np.dot(d0, n1))*d0\n        c1 = p1+(np.dot((p0-p1), n0)/np.dot(d1, n0))*d1\n    return c0, c1\n
"},{"location":"odak/raytracing/#odak.raytracing.get_cylinder_normal","title":"get_cylinder_normal(point, cylinder)","text":"

Definition to get a normal of a point on a given cylinder.

Parameters:

  • point \u2013
            Point on a cylinder defined in X,Y,Z.\n

Returns:

  • normal_vector ( ndarray ) \u2013

    Normal vector.

Source code in odak/raytracing/boundary.py
def get_cylinder_normal(point, cylinder):\n    \"\"\"\n    Parameters\n    ----------\n    point         : ndarray\n                    Point on a cylinder defined in X,Y,Z.\n\n    Returns\n    ----------\n    normal_vector : ndarray\n                    Normal vector.\n    \"\"\"\n    cylinder_ray = create_ray_from_two_points(cylinder[0:3], cylinder[4:7])\n    closest_point = closest_point_to_a_ray(\n        point,\n        cylinder_ray\n    )\n    normal_vector = create_ray_from_two_points(closest_point, point)\n    return normal_vector\n
"},{"location":"odak/raytracing/#odak.raytracing.get_sphere_normal","title":"get_sphere_normal(point, sphere)","text":"

Definition to get a normal of a point on a given sphere.

Parameters:

  • point \u2013
            Point on sphere in X,Y,Z.\n
  • sphere \u2013
            Center defined in X,Y,Z and radius.\n

Returns:

  • normal_vector ( ndarray ) \u2013

    Normal vector.

Source code in odak/raytracing/boundary.py
def get_sphere_normal(point, sphere):\n    \"\"\"\n    Definition to get a normal of a point on a given sphere.\n\n    Parameters\n    ----------\n    point         : ndarray\n                    Point on sphere in X,Y,Z.\n    sphere        : ndarray\n                    Center defined in X,Y,Z and radius.\n\n    Returns\n    ----------\n    normal_vector : ndarray\n                    Normal vector.\n    \"\"\"\n    if len(point.shape) == 1:\n        point = point.reshape((1, 3))\n    normal_vector = create_ray_from_two_points(point, sphere[0:3])\n    return normal_vector\n
"},{"location":"odak/raytracing/#odak.raytracing.get_triangle_normal","title":"get_triangle_normal(triangle, triangle_center=None)","text":"

Definition to calculate surface normal of a triangle.

Parameters:

  • triangle \u2013
              Set of points in X,Y and Z to define a planar surface (3,3). It can also be a list of triangles (mx3x3).\n
  • triangle_center (ndarray, default: None ) \u2013
              Center point of the given triangle. See odak.raytracing.center_of_triangle for more. In many scenarios you can accelerate things by precomputing triangle centers.\n

Returns:

  • normal ( ndarray ) \u2013

    Surface normal at the point of intersection.

Source code in odak/raytracing/boundary.py
def get_triangle_normal(triangle, triangle_center=None):\n    \"\"\"\n    Definition to calculate surface normal of a triangle.\n\n    Parameters\n    ----------\n    triangle        : ndarray\n                      Set of points in X,Y and Z to define a planar surface (3,3). It can also be list of triangles (mx3x3).\n    triangle_center : ndarray\n                      Center point of the given triangle. See odak.raytracing.center_of_triangle for more. In many scenarios you can accelerate things by precomputing triangle centers.\n\n    Returns\n    ----------\n    normal          : ndarray\n                      Surface normal at the point of intersection.\n    \"\"\"\n    triangle = np.asarray(triangle)\n    if len(triangle.shape) == 2:\n        triangle = triangle.reshape((1, 3, 3))\n    normal = np.zeros((triangle.shape[0], 2, 3))\n    direction = np.cross(\n        triangle[:, 0]-triangle[:, 1], triangle[:, 2]-triangle[:, 1])\n    if type(triangle_center) == type(None):\n        normal[:, 0] = center_of_triangle(triangle)\n    else:\n        normal[:, 0] = triangle_center\n    normal[:, 1] = direction/np.sum(direction, axis=1)[0]\n    if normal.shape[0] == 1:\n        normal = normal.reshape((2, 3))\n    return normal\n
"},{"location":"odak/raytracing/#odak.raytracing.intersect_parametric","title":"intersect_parametric(ray, parametric_surface, surface_function, surface_normal_function, target_error=1e-08, iter_no_limit=100000)","text":"

Definition to intersect a ray with a parametric surface.

Parameters:

  • ray \u2013
                      Ray.\n
  • parametric_surface \u2013
                      Parameters of the surfaces.\n
  • surface_function \u2013
                      Function to evaluate a point against a surface.\n
  • surface_normal_function (function) \u2013
                      Function to calculate surface normal for a given point on a surface.\n
  • target_error \u2013
                      Target error that defines the precision.\n
  • iter_no_limit \u2013
                      Maximum number of iterations.\n

Returns:

  • distance ( float ) \u2013

    Propagation distance.

  • normal ( ndarray ) \u2013

    Ray that defines a surface normal for the intersection.

Source code in odak/raytracing/boundary.py
def intersect_parametric(ray, parametric_surface, surface_function, surface_normal_function, target_error=0.00000001, iter_no_limit=100000):\n    \"\"\"\n    Definition to intersect a ray with a parametric surface.\n\n    Parameters\n    ----------\n    ray                     : ndarray\n                              Ray.\n    parametric_surface      : ndarray\n                              Parameters of the surfaces.\n    surface_function        : function\n                              Function to evaluate a point against a surface.\n    surface_normal_function : function\n                              Function to calculate surface normal for a given point on a surface.\n    target_error            : float\n                              Target error that defines the precision.  \n    iter_no_limit           : int\n                              Maximum number of iterations.\n\n    Returns\n    ----------\n    distance                : float\n                              Propagation distance.\n    normal                  : ndarray\n                              Ray that defines a surface normal for the intersection.\n    \"\"\"\n    if len(ray.shape) == 2:\n        ray = ray.reshape((1, 2, 3))\n    error = [150, 100]\n    distance = [0, 0.1]\n    iter_no = 0\n    while np.abs(np.max(np.asarray(error[1]))) > target_error:\n        error[1], point = intersection_kernel_for_parametric_surfaces(\n            distance[1],\n            ray,\n            parametric_surface,\n            surface_function\n        )\n        distance, error = propagate_parametric_intersection_error(\n            distance,\n            error\n        )\n        iter_no += 1\n        if iter_no > iter_no_limit:\n            return False, False\n        if np.isnan(np.sum(point)):\n            return False, False\n    normal = surface_normal_function(\n        point,\n        parametric_surface\n    )\n    return distance[1], normal\n
"},{"location":"odak/raytracing/#odak.raytracing.intersect_w_circle","title":"intersect_w_circle(ray, circle)","text":"

Definition to find the intersection point of a ray with a circle. Returns the distance as zero if the ray does not intersect with the given circle.

Parameters:

  • ray \u2013
           A vector/ray.\n
  • circle \u2013
           A list that contains (0) Set of points in X,Y and Z to define plane of a circle, (1) circle center, and (2) circle radius.\n

Returns:

  • normal ( ndarray ) \u2013

    Surface normal at the point of intersection.

  • distance ( float ) \u2013

    Distance in between a starting point of a ray and the intersection point with a given triangle.

Source code in odak/raytracing/boundary.py
def intersect_w_circle(ray, circle):\n    \"\"\"\n    Definition to find intersection point of a ray with a circle. Returns False for each variable if the ray doesn't intersect with a given circle. Returns distance as zero if there isn't an intersection.\n\n    Parameters\n    ----------\n    ray          : ndarray\n                   A vector/ray.\n    circle       : list\n                   A list that contains (0) Set of points in X,Y and Z to define plane of a circle, (1) circle center, and (2) circle radius.\n\n    Returns\n    ----------\n    normal       : ndarray\n                   Surface normal at the point of intersection.\n    distance     : float\n                   Distance in between a starting point of a ray and the intersection point with a given triangle.\n    \"\"\"\n    normal, distance = intersect_w_surface(ray, circle[0])\n    if len(normal.shape) == 2:\n        normal = normal.reshape((1, 2, 3))\n    distance_to_center = distance_between_two_points(normal[:, 0], circle[1])\n    distance[np.nonzero(distance_to_center > circle[2])] = 0\n    if len(ray.shape) == 2:\n        normal = normal.reshape((2, 3))\n    return normal, distance\n
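Example use, building the circle with define_circle (illustrative values):

import numpy as np
from odak.raytracing import create_ray_from_two_points, define_circle, intersect_w_circle

circle = define_circle(np.array([0., 0., 10.]), 5., [0., 0., 0.])
ray = create_ray_from_two_points([0., 0., 0.], [0., 0., 1.])
normal, distance = intersect_w_circle(ray, circle)
print(normal[0])  # intersection point, here [0., 0., 10.]
print(distance)   # distance along the ray, here 10; zero would mean no intersection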
"},{"location":"odak/raytracing/#odak.raytracing.intersect_w_cylinder","title":"intersect_w_cylinder(ray, cylinder)","text":"

Definition to intersect a ray with a cylinder.

Parameters:

  • ray \u2013
         A ray definition.\n
  • cylinder \u2013
         A cylinder defined with a center in XYZ and radius of curvature.\n

Returns:

  • normal ( ndarray ) \u2013

    A ray defining surface normal at the point of intersection.

  • distance ( float ) \u2013

    Total optical propagation distance.

Source code in odak/raytracing/boundary.py
def intersect_w_cylinder(ray, cylinder):\n    \"\"\"\n    Definition to intersect a ray with a cylinder.\n\n    Parameters\n    ----------\n    ray        : ndarray\n                 A ray definition.\n    cylinder   : ndarray\n                 A cylinder defined with a center in XYZ and radius of curvature.\n\n    Returns\n    ----------\n    normal     : ndarray\n                 A ray defining surface normal at the point of intersection.\n    distance   : float\n                 Total optical propagation distance.\n    \"\"\"\n    distance, normal = intersect_parametric(\n        ray,\n        cylinder,\n        cylinder_function,\n        get_cylinder_normal\n    )\n    return normal, distance\n
"},{"location":"odak/raytracing/#odak.raytracing.intersect_w_sphere","title":"intersect_w_sphere(ray, sphere)","text":"

Definition to intersect a ray with a sphere.

Parameters:

  • ray \u2013
         A ray definition.\n
  • sphere \u2013
         A sphere defined with a center in XYZ and radius of curvature.\n

Returns:

  • normal ( ndarray ) \u2013

    A ray defining surface normal at the point of intersection.

  • distance ( float ) \u2013

    Total optical propagation distance.

Source code in odak/raytracing/boundary.py
def intersect_w_sphere(ray, sphere):\n    \"\"\"\n    Definition to intersect a ray with a sphere.\n\n    Parameters\n    ----------\n    ray        : ndarray\n                 A ray definition.\n    sphere     : ndarray\n                 A sphere defined with a center in XYZ and radius of curvature.\n\n    Returns\n    ----------\n    normal     : ndarray\n                 A ray defining surface normal at the point of intersection.\n    distance   : float\n                 Total optical propagation distance.\n    \"\"\"\n    distance, normal = intersect_parametric(\n        ray,\n        sphere,\n        sphere_function,\n        get_sphere_normal\n    )\n    return normal, distance\n
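Example use (illustrative values; the intersection is found iteratively through intersect_parametric):

from odak.raytracing import create_ray_from_two_points, define_sphere, intersect_w_sphere

sphere = define_sphere([0., 0., 10.], 3.)                     # radius 3 sphere centred at z = 10
ray = create_ray_from_two_points([0., 0., 0.], [0., 0., 1.])  # ray along +Z
normal, distance = intersect_w_sphere(ray, sphere)
print(distance)  # expected to converge near 7, the front surface of the sphere
print(normal)    # surface normal ray at the intersection point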
"},{"location":"odak/raytracing/#odak.raytracing.intersect_w_surface","title":"intersect_w_surface(ray, points)","text":"

Definition to find the intersection point between a surface and a ray. For more see: http://geomalgorithms.com/a06-_intersect-2.html

Parameters:

  • ray \u2013
           A vector/ray.\n
  • points \u2013
           Set of points in X,Y and Z to define a planar surface.\n

Returns:

  • normal ( ndarray ) \u2013

    Surface normal at the point of intersection.

  • distance ( float ) \u2013

    Distance between the starting point of a ray and its intersection with a planar surface.

Source code in odak/raytracing/boundary.py
def intersect_w_surface(ray, points):\n    \"\"\"\n    Definition to find intersection point inbetween a surface and a ray. For more see: http://geomalgorithms.com/a06-_intersect-2.html\n\n    Parameters\n    ----------\n    ray          : ndarray\n                   A vector/ray.\n    points       : ndarray\n                   Set of points in X,Y and Z to define a planar surface.\n\n    Returns\n    ----------\n    normal       : ndarray\n                   Surface normal at the point of intersection.\n    distance     : float\n                   Distance in between starting point of a ray with it's intersection with a planar surface.\n    \"\"\"\n    points = np.asarray(points)\n    normal = get_triangle_normal(points)\n    if len(ray.shape) == 2:\n        ray = ray.reshape((1, 2, 3))\n    if len(points) == 2:\n        points = points.reshape((1, 3, 3))\n    if len(normal.shape) == 2:\n        normal = normal.reshape((1, 2, 3))\n    f = normal[:, 0]-ray[:, 0]\n    distance = np.dot(normal[:, 1], f.T)/np.dot(normal[:, 1], ray[:, 1].T)\n    n = np.int64(np.amax(np.array([ray.shape[0], normal.shape[0]])))\n    normal = np.zeros((n, 2, 3))\n    normal[:, 0] = ray[:, 0]+distance.T*ray[:, 1]\n    distance = np.abs(distance)\n    if normal.shape[0] == 1:\n        normal = normal.reshape((2, 3))\n        distance = distance.reshape((1))\n    if distance.shape[0] == 1 and len(distance.shape) > 1:\n        distance = distance.reshape((distance.shape[1]))\n    return normal, distance\n
"},{"location":"odak/raytracing/#odak.raytracing.intersect_w_triangle","title":"intersect_w_triangle(ray, triangle)","text":"

Definition to find the intersection point of a ray with a triangle. Returns zero for each variable if the ray doesn't intersect with the given triangle.

Parameters:

  • ray \u2013
           A vector/ray (2 x 3). It can also be a list of rays (n x 2 x 3).\n
  • triangle \u2013
           Set of points in X,Y and Z to define a planar surface. It can also be a list of triangles (m x 3 x 3).\n

Returns:

  • normal ( ndarray ) \u2013

    Surface normal at the point of intersection.

  • distance ( float ) \u2013

    Distance in between a starting point of a ray and the intersection point with a given triangle.

Source code in odak/raytracing/boundary.py
def intersect_w_triangle(ray, triangle):\n    \"\"\"\n    Definition to find intersection point of a ray with a triangle. Returns False for each variable if the ray doesn't intersect with a given triangle.\n\n    Parameters\n    ----------\n    ray          : torch.tensor\n                   A vector/ray (2 x 3). It can also be a list of rays (n x 2 x 3).\n    triangle     : torch.tensor\n                   Set of points in X,Y and Z to define a planar surface. It can also be a list of triangles (m x 3 x 3).\n\n    Returns\n    ----------\n    normal       : ndarray\n                   Surface normal at the point of intersection.\n    distance     : float\n                   Distance in between a starting point of a ray and the intersection point with a given triangle.\n    \"\"\"\n    normal, distance = intersect_w_surface(ray, triangle)\n    if is_it_on_triangle(normal[0], triangle[0], triangle[1], triangle[2]) == False:\n        return 0, 0\n    return normal, distance\n
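Example use (illustrative values):

import numpy as np
from odak.raytracing import create_ray_from_two_points, intersect_w_triangle

triangle = np.array([
                     [-10., -10., 10.],
                     [ 10., -10., 10.],
                     [  0.,  10., 10.]
                    ])
ray = create_ray_from_two_points([0., 0., 0.], [0., 0., 1.])
normal, distance = intersect_w_triangle(ray, triangle)
print(normal[0])  # intersection point, here [0., 0., 10.]
print(distance)   # here 10; the function returns 0, 0 when there is no hit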
"},{"location":"odak/raytracing/#odak.raytracing.intersection_kernel_for_parametric_surfaces","title":"intersection_kernel_for_parametric_surfaces(distance, ray, parametric_surface, surface_function)","text":"

Definition for the intersection kernel when dealing with parametric surfaces.

Parameters:

  • distance \u2013
                 Distance.\n
  • ray \u2013
                 Ray.\n
  • parametric_surface (ndarray) \u2013
                 Array that defines a parametric surface.\n
  • surface_function \u2013
                 Function to evaluate a point against a parametric surface.\n

Returns:

  • point ( ndarray ) \u2013

    Location in X,Y,Z after propagation.

  • error ( float ) \u2013

    Error.

Source code in odak/raytracing/boundary.py
def intersection_kernel_for_parametric_surfaces(distance, ray, parametric_surface, surface_function):\n    \"\"\"\n    Definition for the intersection kernel when dealing with parametric surfaces.\n\n    Parameters\n    ----------\n    distance           : float\n                         Distance.\n    ray                : ndarray\n                         Ray.\n    parametric_surface : ndarray\n                         Array that defines a parametric surface.\n    surface_function   : ndarray\n                         Function to evaluate a point against a parametric surface.\n\n    Returns\n    ----------\n    point              : ndarray\n                         Location in X,Y,Z after propagation.\n    error              : float\n                         Error.\n    \"\"\"\n    new_ray = propagate_a_ray(ray, distance)\n    if len(new_ray) == 2:\n        new_ray = new_ray.reshape((1, 2, 3))\n    point = new_ray[:, 0]\n    error = surface_function(point, parametric_surface)\n    return error, point\n
"},{"location":"odak/raytracing/#odak.raytracing.is_it_on_triangle","title":"is_it_on_triangle(pointtocheck, point0, point1, point2)","text":"

Definition to check if a given point is inside a triangle. If the given point is inside a defined triangle, this definition returns True.

Parameters:

  • pointtocheck \u2013
            Point to check.\n
  • point0 \u2013
            First point of a triangle.\n
  • point1 \u2013
            Second point of a triangle.\n
  • point2 \u2013
            Third point of a triangle.\n
Source code in odak/raytracing/primitives.py
def is_it_on_triangle(pointtocheck, point0, point1, point2):\n    \"\"\"\n    Definition to check if a given point is inside a triangle. If the given point is inside a defined triangle, this definition returns True.\n\n    Parameters\n    ----------\n    pointtocheck  : list\n                    Point to check.\n    point0        : list\n                    First point of a triangle.\n    point1        : list\n                    Second point of a triangle.\n    point2        : list\n                    Third point of a triangle.\n    \"\"\"\n    # point0, point1 and point2 are the corners of the triangle.\n    pointtocheck = np.asarray(pointtocheck).reshape(3)\n    point0 = np.asarray(point0)\n    point1 = np.asarray(point1)\n    point2 = np.asarray(point2)\n    side0 = same_side(pointtocheck, point0, point1, point2)\n    side1 = same_side(pointtocheck, point1, point0, point2)\n    side2 = same_side(pointtocheck, point2, point0, point1)\n    if side0 == True and side1 == True and side2 == True:\n        return True\n    return False\n
"},{"location":"odak/raytracing/#odak.raytracing.point_to_ray_distance","title":"point_to_ray_distance(point, ray_point_0, ray_point_1)","text":"

Definition to find a point's closest distance to a line represented by two points.

Parameters:

  • point \u2013
          Point to be tested.\n
  • ray_point_0 (ndarray) \u2013
          First point to represent a line.\n
  • ray_point_1 (ndarray) \u2013
          Second point to represent a line.\n

Returns:

  • distance ( float ) \u2013

    Calculated distance.

Source code in odak/tools/vector.py
def point_to_ray_distance(point, ray_point_0, ray_point_1):\n    \"\"\"\n    Definition to find point's closest distance to a line represented with two points.\n\n    Parameters\n    ----------\n    point       : ndarray\n                  Point to be tested.\n    ray_point_0 : ndarray\n                  First point to represent a line.\n    ray_point_1 : ndarray\n                  Second point to represent a line.\n\n    Returns\n    ----------\n    distance    : float\n                  Calculated distance.\n    \"\"\"\n    distance = np.sum(np.cross((point-ray_point_0), (point-ray_point_1))\n                      ** 2)/np.sum((ray_point_1-ray_point_0)**2)\n    return distance\n
"},{"location":"odak/raytracing/#odak.raytracing.propagate_a_ray","title":"propagate_a_ray(ray, distance)","text":"

Definition to propagate a ray at a certain given distance.

Parameters:

  • ray \u2013
         A ray.\n
  • distance \u2013
         Distance.\n

Returns:

  • new_ray ( ndarray ) \u2013

    Propagated ray.

Source code in odak/raytracing/ray.py
def propagate_a_ray(ray, distance):\n    \"\"\"\n    Definition to propagate a ray at a certain given distance.\n\n    Parameters\n    ----------\n    ray        : ndarray\n                 A ray.\n    distance   : float\n                 Distance.\n\n    Returns\n    ----------\n    new_ray    : ndarray\n                 Propagated ray.\n    \"\"\"\n    if len(ray.shape) == 2:\n        ray = ray.reshape((1, 2, 3))\n    new_ray = np.copy(ray)\n    new_ray[:, 0, 0] = distance*new_ray[:, 1, 0] + new_ray[:, 0, 0]\n    new_ray[:, 0, 1] = distance*new_ray[:, 1, 1] + new_ray[:, 0, 1]\n    new_ray[:, 0, 2] = distance*new_ray[:, 1, 2] + new_ray[:, 0, 2]\n    if new_ray.shape[0] == 1:\n        new_ray = new_ray.reshape((2, 3))\n    return new_ray\n
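Example use (illustrative values):

from odak.raytracing import create_ray_from_two_points, propagate_a_ray

ray = create_ray_from_two_points([0., 0., 0.], [0., 0., 1.])
new_ray = propagate_a_ray(ray, 5.)
print(new_ray[0])  # start point advanced along the direction cosines, here [0., 0., 5.]
print(new_ray[1])  # direction cosines are unchanged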
"},{"location":"odak/raytracing/#odak.raytracing.propagate_parametric_intersection_error","title":"propagate_parametric_intersection_error(distance, error)","text":"

Definition to propagate the error in parametric intersection to find the next distance to try.

Parameters:

  • distance \u2013
           List that contains the new and the old distance.\n
  • error \u2013
           List that contains the new and the old error.\n

Returns:

  • distance ( list ) \u2013

    New distance.

  • error ( list ) \u2013

    New error.

Source code in odak/raytracing/boundary.py
def propagate_parametric_intersection_error(distance, error):\n    \"\"\"\n    Definition to propagate the error in parametric intersection to find the next distance to try.\n\n    Parameters\n    ----------\n    distance     : list\n                   List that contains the new and the old distance.\n    error        : list\n                   List that contains the new and the old error.\n\n    Returns\n    ----------\n    distance     : list\n                   New distance.\n    error        : list\n                   New error.\n    \"\"\"\n    new_distance = distance[1]-error[1] * \\\n        (distance[1]-distance[0])/(error[1]-error[0])\n    distance[0] = distance[1]\n    distance[1] = np.abs(new_distance)\n    error[0] = error[1]\n    return distance, error\n
"},{"location":"odak/raytracing/#odak.raytracing.reflect","title":"reflect(input_ray, normal)","text":"

Definition to reflect an incoming ray from a surface defined by a surface normal. Uses the method described in G.H. Spencer and M.V.R.K. Murty, \"General Ray-Tracing Procedure\", 1961.

Parameters:

  • input_ray \u2013
           A vector/ray (2x3). It can also be a list of rays (nx2x3).\n
  • normal \u2013
           A surface normal (2x3). It can also be a list of normals (nx2x3).\n

Returns:

  • output_ray ( ndarray ) \u2013

    Array that contains starting points and cosines of a reflected ray.

Source code in odak/raytracing/boundary.py
def reflect(input_ray, normal):\n    \"\"\" \n    Definition to reflect an incoming ray from a surface defined by a surface normal. Used method described in G.H. Spencer and M.V.R.K. Murty, \"General Ray-Tracing Procedure\", 1961.\n\n    Parameters\n    ----------\n    input_ray    : ndarray\n                   A vector/ray (2x3). It can also be a list of rays (nx2x3).\n    normal       : ndarray\n                   A surface normal (2x3). It also be a list of normals (nx2x3).\n\n    Returns\n    ----------\n    output_ray   : ndarray\n                   Array that contains starting points and cosines of a reflected ray.\n    \"\"\"\n    input_ray = np.asarray(input_ray)\n    normal = np.asarray(normal)\n    if len(input_ray.shape) == 2:\n        input_ray = input_ray.reshape((1, 2, 3))\n    if len(normal.shape) == 2:\n        normal = normal.reshape((1, 2, 3))\n    mu = 1\n    div = normal[:, 1, 0]**2 + normal[:, 1, 1]**2 + normal[:, 1, 2]**2\n    a = mu * (input_ray[:, 1, 0]*normal[:, 1, 0]\n              + input_ray[:, 1, 1]*normal[:, 1, 1]\n              + input_ray[:, 1, 2]*normal[:, 1, 2]) / div\n    n = np.int64(np.amax(np.array([normal.shape[0], input_ray.shape[0]])))\n    output_ray = np.zeros((n, 2, 3))\n    output_ray[:, 0] = normal[:, 0]\n    output_ray[:, 1] = input_ray[:, 1]-2*a*normal[:, 1]\n    if output_ray.shape[0] == 1:\n        output_ray = output_ray.reshape((2, 3))\n    return output_ray\n
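Example use, chaining intersect_w_triangle and reflect (illustrative values):

import numpy as np
from odak.raytracing import create_ray_from_two_points, intersect_w_triangle, reflect

triangle = np.array([
                     [-10., -10., 10.],
                     [ 10., -10., 10.],
                     [  0.,  10., 10.]
                    ])
ray = create_ray_from_two_points([0., 0., 0.], [0., 0., 1.])
normal, distance = intersect_w_triangle(ray, triangle)
reflected_ray = reflect(ray, normal)
print(reflected_ray[0])  # reflected ray starts at the intersection point [0., 0., 10.]
print(reflected_ray[1])  # direction cosines flipped to roughly [0., 0., -1.]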
"},{"location":"odak/raytracing/#odak.raytracing.rotate_point","title":"rotate_point(point, angles=[0, 0, 0], mode='XYZ', origin=[0, 0, 0], offset=[0, 0, 0])","text":"

Definition to rotate a given point. Note that the rotation is performed with respect to the given origin, which defaults to (0, 0, 0).

Parameters:

  • point \u2013
           A point.\n
  • angles \u2013
           Rotation angles in degrees.\n
  • mode \u2013
           Rotation mode determines ordering of the rotations at each axis. There are XYZ,YXZ,ZXY and ZYX modes.\n
  • origin \u2013
           Reference point for a rotation.\n
  • offset \u2013
           Shift with the given offset.\n

Returns:

  • result ( ndarray ) \u2013

    Result of the rotation.

  • rotx ( ndarray ) \u2013

    Rotation matrix along X axis.

  • roty ( ndarray ) \u2013

    Rotation matrix along Y axis.

  • rotz ( ndarray ) \u2013

    Rotation matrix along Z axis.

Source code in odak/tools/transformation.py
def rotate_point(point, angles = [0, 0, 0], mode = 'XYZ', origin = [0, 0, 0], offset = [0, 0, 0]):\n    \"\"\"\n    Definition to rotate a given point. Note that rotation is always with respect to 0,0,0.\n\n    Parameters\n    ----------\n    point        : ndarray\n                   A point.\n    angles       : list\n                   Rotation angles in degrees. \n    mode         : str\n                   Rotation mode determines ordering of the rotations at each axis. There are XYZ,YXZ,ZXY and ZYX modes.\n    origin       : list\n                   Reference point for a rotation.\n    offset       : list\n                   Shift with the given offset.\n\n    Returns\n    ----------\n    result       : ndarray\n                   Result of the rotation\n    rotx         : ndarray\n                   Rotation matrix along X axis.\n    roty         : ndarray\n                   Rotation matrix along Y axis.\n    rotz         : ndarray\n                   Rotation matrix along Z axis.\n    \"\"\"\n    point = np.asarray(point)\n    point -= np.asarray(origin)\n    rotx = rotmatx(angles[0])\n    roty = rotmaty(angles[1])\n    rotz = rotmatz(angles[2])\n    if mode == 'XYZ':\n        result = np.dot(rotz, np.dot(roty, np.dot(rotx, point)))\n    elif mode == 'XZY':\n        result = np.dot(roty, np.dot(rotz, np.dot(rotx, point)))\n    elif mode == 'YXZ':\n        result = np.dot(rotz, np.dot(rotx, np.dot(roty, point)))\n    elif mode == 'ZXY':\n        result = np.dot(roty, np.dot(rotx, np.dot(rotz, point)))\n    elif mode == 'ZYX':\n        result = np.dot(rotx, np.dot(roty, np.dot(rotz, point)))\n    result += np.asarray(origin)\n    result += np.asarray(offset)\n    return result, rotx, roty, rotz\n
"},{"location":"odak/raytracing/#odak.raytracing.rotate_points","title":"rotate_points(points, angles=[0, 0, 0], mode='XYZ', origin=[0, 0, 0], offset=[0, 0, 0])","text":"

Definition to rotate points.

Parameters:

  • points \u2013
           Points.\n
  • angles \u2013
           Rotation angles in degrees.\n
  • mode \u2013
           Rotation mode determines ordering of the rotations at each axis. There are XYZ,YXZ,ZXY and ZYX modes.\n
  • origin \u2013
           Reference point for a rotation.\n
  • offset \u2013
           Shift with the given offset.\n

Returns:

  • result ( ndarray ) \u2013

    Result of the rotation.

Source code in odak/tools/transformation.py
def rotate_points(points, angles = [0, 0, 0], mode = 'XYZ', origin = [0, 0, 0], offset = [0, 0, 0]):\n    \"\"\"\n    Definition to rotate points.\n\n    Parameters\n    ----------\n    points       : ndarray\n                   Points.\n    angles       : list\n                   Rotation angles in degrees. \n    mode         : str\n                   Rotation mode determines ordering of the rotations at each axis. There are XYZ,YXZ,ZXY and ZYX modes.\n    origin       : list\n                   Reference point for a rotation.\n    offset       : list\n                   Shift with the given offset.\n\n    Returns\n    ----------\n    result       : ndarray\n                   Result of the rotation   \n    \"\"\"\n    points = np.asarray(points)\n    if angles[0] == 0 and angles[1] == 0 and angles[2] == 0:\n        result = np.array(offset) + points\n        return result\n    points -= np.array(origin)\n    rotx = rotmatx(angles[0])\n    roty = rotmaty(angles[1])\n    rotz = rotmatz(angles[2])\n    if mode == 'XYZ':\n        result = np.dot(rotz, np.dot(roty, np.dot(rotx, points.T))).T\n    elif mode == 'XZY':\n        result = np.dot(roty, np.dot(rotz, np.dot(rotx, points.T))).T\n    elif mode == 'YXZ':\n        result = np.dot(rotz, np.dot(rotx, np.dot(roty, points.T))).T\n    elif mode == 'ZXY':\n        result = np.dot(roty, np.dot(rotx, np.dot(rotz, points.T))).T\n    elif mode == 'ZYX':\n        result = np.dot(rotx, np.dot(roty, np.dot(rotz, points.T))).T\n    result += np.array(origin)\n    result += np.array(offset)\n    return result\n
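Example use (illustrative values; angles follow the documented degree convention):

import numpy as np
from odak.raytracing import rotate_points

points = np.array([
                   [1., 0., 0.],
                   [0., 2., 0.]
                  ])
rotated = rotate_points(points, angles = [0., 0., 180.])
print(rotated)  # a 180 degree turn about Z maps the rows to [-1., 0., 0.] and [0., -2., 0.]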
"},{"location":"odak/raytracing/#odak.raytracing.same_side","title":"same_side(p1, p2, a, b)","text":"

Definition to figure out which side a point is on with respect to a line and a point. See http://www.blackpawn.com/texts/pointinpoly/ for more. If p1 and p2 are on the same side, this definition returns True.

Parameters:

  • p1 \u2013
          Point(s) to check.\n
  • p2 \u2013
          This is the point to check against.\n
  • a \u2013
          First point that forms the line.\n
  • b \u2013
          Second point that forms the line.\n
Source code in odak/tools/vector.py
def same_side(p1, p2, a, b):\n    \"\"\"\n    Definition to figure which side a point is on with respect to a line and a point. See http://www.blackpawn.com/texts/pointinpoly/ for more. If p1 and p2 are on the sameside, this definition returns True.\n\n    Parameters\n    ----------\n    p1          : list\n                  Point(s) to check.\n    p2          : list\n                  This is the point check against.\n    a           : list\n                  First point that forms the line.\n    b           : list\n                  Second point that forms the line.\n    \"\"\"\n    ba = np.subtract(b, a)\n    p1a = np.subtract(p1, a)\n    p2a = np.subtract(p2, a)\n    cp1 = np.cross(ba, p1a)\n    cp2 = np.cross(ba, p2a)\n    test = np.dot(cp1, cp2)\n    if len(p1.shape) > 1:\n        return test >= 0\n    if test >= 0:\n        return True\n    return False\n
"},{"location":"odak/raytracing/#odak.raytracing.sphere_function","title":"sphere_function(point, sphere)","text":"

Definition of a sphere function. Evaluate a point against a sphere function.

Parameters:

  • sphere \u2013
         Sphere parameters, XYZ center and radius.\n
  • point \u2013
         Point in XYZ.\n
Returns:

  • result ( float ) \u2013

    Result of the evaluation. Zero if point is on sphere.

Source code in odak/raytracing/primitives.py
def sphere_function(point, sphere):\n    \"\"\"\n    Definition of a sphere function. Evaluate a point against a sphere function.\n\n    Parameters\n    ----------\n    sphere     : ndarray\n                 Sphere parameters, XYZ center and radius.\n    point      : ndarray\n                 Point in XYZ.\n\n    Return\n    ----------\n    result     : float\n                 Result of the evaluation. Zero if point is on sphere.\n    \"\"\"\n    point = np.asarray(point)\n    if len(point.shape) == 1:\n        point = point.reshape((1, 3))\n    result = (point[:, 0]-sphere[0])**2 + (point[:, 1]-sphere[1]\n                                           )**2 + (point[:, 2]-sphere[2])**2 - sphere[3]**2\n    return result\n
"},{"location":"odak/raytracing/#odak.raytracing.boundary.get_cylinder_normal","title":"get_cylinder_normal(point, cylinder)","text":"

Definition to get a normal of a point on a given cylinder.

Parameters:

  • point \u2013
            Point on a cylinder defined in X,Y,Z.\n

Returns:

  • normal_vector ( ndarray ) \u2013

    Normal vector.

Source code in odak/raytracing/boundary.py
def get_cylinder_normal(point, cylinder):\n    \"\"\"\n    Parameters\n    ----------\n    point         : ndarray\n                    Point on a cylinder defined in X,Y,Z.\n\n    Returns\n    ----------\n    normal_vector : ndarray\n                    Normal vector.\n    \"\"\"\n    cylinder_ray = create_ray_from_two_points(cylinder[0:3], cylinder[4:7])\n    closest_point = closest_point_to_a_ray(\n        point,\n        cylinder_ray\n    )\n    normal_vector = create_ray_from_two_points(closest_point, point)\n    return normal_vector\n
"},{"location":"odak/raytracing/#odak.raytracing.boundary.get_sphere_normal","title":"get_sphere_normal(point, sphere)","text":"

Definition to get a normal of a point on a given sphere.

Parameters:

  • point \u2013
            Point on sphere in X,Y,Z.\n
  • sphere \u2013
            Center defined in X,Y,Z and radius.\n

Returns:

  • normal_vector ( ndarray ) \u2013

    Normal vector.

Source code in odak/raytracing/boundary.py
def get_sphere_normal(point, sphere):\n    \"\"\"\n    Definition to get a normal of a point on a given sphere.\n\n    Parameters\n    ----------\n    point         : ndarray\n                    Point on sphere in X,Y,Z.\n    sphere        : ndarray\n                    Center defined in X,Y,Z and radius.\n\n    Returns\n    ----------\n    normal_vector : ndarray\n                    Normal vector.\n    \"\"\"\n    if len(point.shape) == 1:\n        point = point.reshape((1, 3))\n    normal_vector = create_ray_from_two_points(point, sphere[0:3])\n    return normal_vector\n
"},{"location":"odak/raytracing/#odak.raytracing.boundary.get_triangle_normal","title":"get_triangle_normal(triangle, triangle_center=None)","text":"

Definition to calculate surface normal of a triangle.

Parameters:

  • triangle \u2013
              Set of points in X,Y and Z to define a planar surface (3,3). It can also be a list of triangles (mx3x3).\n
  • triangle_center (ndarray, default: None ) \u2013
              Center point of the given triangle. See odak.raytracing.center_of_triangle for more. In many scenarios you can accelerate things by precomputing triangle centers.\n

Returns:

  • normal ( ndarray ) \u2013

    Surface normal at the point of intersection.

Source code in odak/raytracing/boundary.py
def get_triangle_normal(triangle, triangle_center=None):\n    \"\"\"\n    Definition to calculate surface normal of a triangle.\n\n    Parameters\n    ----------\n    triangle        : ndarray\n                      Set of points in X,Y and Z to define a planar surface (3,3). It can also be list of triangles (mx3x3).\n    triangle_center : ndarray\n                      Center point of the given triangle. See odak.raytracing.center_of_triangle for more. In many scenarios you can accelerate things by precomputing triangle centers.\n\n    Returns\n    ----------\n    normal          : ndarray\n                      Surface normal at the point of intersection.\n    \"\"\"\n    triangle = np.asarray(triangle)\n    if len(triangle.shape) == 2:\n        triangle = triangle.reshape((1, 3, 3))\n    normal = np.zeros((triangle.shape[0], 2, 3))\n    direction = np.cross(\n        triangle[:, 0]-triangle[:, 1], triangle[:, 2]-triangle[:, 1])\n    if type(triangle_center) == type(None):\n        normal[:, 0] = center_of_triangle(triangle)\n    else:\n        normal[:, 0] = triangle_center\n    normal[:, 1] = direction/np.sum(direction, axis=1)[0]\n    if normal.shape[0] == 1:\n        normal = normal.reshape((2, 3))\n    return normal\n
"},{"location":"odak/raytracing/#odak.raytracing.boundary.intersect_parametric","title":"intersect_parametric(ray, parametric_surface, surface_function, surface_normal_function, target_error=1e-08, iter_no_limit=100000)","text":"

Definition to intersect a ray with a parametric surface.

Parameters:

  • ray \u2013
                      Ray.\n
  • parametric_surface \u2013
                      Parameters of the surfaces.\n
  • surface_function \u2013
                      Function to evaluate a point against a surface.\n
  • surface_normal_function (function) \u2013
                      Function to calculate surface normal for a given point on a surface.\n
  • target_error \u2013
                      Target error that defines the precision.\n
  • iter_no_limit \u2013
                      Maximum number of iterations.\n

Returns:

  • distance ( float ) \u2013

    Propagation distance.

  • normal ( ndarray ) \u2013

    Ray that defines a surface normal for the intersection.

Source code in odak/raytracing/boundary.py
def intersect_parametric(ray, parametric_surface, surface_function, surface_normal_function, target_error=0.00000001, iter_no_limit=100000):\n    \"\"\"\n    Definition to intersect a ray with a parametric surface.\n\n    Parameters\n    ----------\n    ray                     : ndarray\n                              Ray.\n    parametric_surface      : ndarray\n                              Parameters of the surfaces.\n    surface_function        : function\n                              Function to evaluate a point against a surface.\n    surface_normal_function : function\n                              Function to calculate surface normal for a given point on a surface.\n    target_error            : float\n                              Target error that defines the precision.  \n    iter_no_limit           : int\n                              Maximum number of iterations.\n\n    Returns\n    ----------\n    distance                : float\n                              Propagation distance.\n    normal                  : ndarray\n                              Ray that defines a surface normal for the intersection.\n    \"\"\"\n    if len(ray.shape) == 2:\n        ray = ray.reshape((1, 2, 3))\n    error = [150, 100]\n    distance = [0, 0.1]\n    iter_no = 0\n    while np.abs(np.max(np.asarray(error[1]))) > target_error:\n        error[1], point = intersection_kernel_for_parametric_surfaces(\n            distance[1],\n            ray,\n            parametric_surface,\n            surface_function\n        )\n        distance, error = propagate_parametric_intersection_error(\n            distance,\n            error\n        )\n        iter_no += 1\n        if iter_no > iter_no_limit:\n            return False, False\n        if np.isnan(np.sum(point)):\n            return False, False\n    normal = surface_normal_function(\n        point,\n        parametric_surface\n    )\n    return distance[1], normal\n
"},{"location":"odak/raytracing/#odak.raytracing.boundary.intersect_w_circle","title":"intersect_w_circle(ray, circle)","text":"

Definition to find the intersection point of a ray with a circle. Returns the distance as zero if the ray does not intersect with the given circle.

Parameters:

  • ray \u2013
           A vector/ray.\n
  • circle \u2013
           A list that contains (0) Set of points in X,Y and Z to define plane of a circle, (1) circle center, and (2) circle radius.\n

Returns:

  • normal ( ndarray ) \u2013

    Surface normal at the point of intersection.

  • distance ( float ) \u2013

    Distance in between a starting point of a ray and the intersection point with a given triangle.

Source code in odak/raytracing/boundary.py
def intersect_w_circle(ray, circle):\n    \"\"\"\n    Definition to find intersection point of a ray with a circle. Returns False for each variable if the ray doesn't intersect with a given circle. Returns distance as zero if there isn't an intersection.\n\n    Parameters\n    ----------\n    ray          : ndarray\n                   A vector/ray.\n    circle       : list\n                   A list that contains (0) Set of points in X,Y and Z to define plane of a circle, (1) circle center, and (2) circle radius.\n\n    Returns\n    ----------\n    normal       : ndarray\n                   Surface normal at the point of intersection.\n    distance     : float\n                   Distance in between a starting point of a ray and the intersection point with a given triangle.\n    \"\"\"\n    normal, distance = intersect_w_surface(ray, circle[0])\n    if len(normal.shape) == 2:\n        normal = normal.reshape((1, 2, 3))\n    distance_to_center = distance_between_two_points(normal[:, 0], circle[1])\n    distance[np.nonzero(distance_to_center > circle[2])] = 0\n    if len(ray.shape) == 2:\n        normal = normal.reshape((2, 3))\n    return normal, distance\n
"},{"location":"odak/raytracing/#odak.raytracing.boundary.intersect_w_cylinder","title":"intersect_w_cylinder(ray, cylinder)","text":"

Definition to intersect a ray with a cylinder.

Parameters:

  • ray \u2013
         A ray definition.\n
  • cylinder \u2013
         A cylinder defined with a center in XYZ and radius of curvature.\n

Returns:

  • normal ( ndarray ) \u2013

    A ray defining surface normal at the point of intersection.

  • distance ( float ) \u2013

    Total optical propagation distance.

Source code in odak/raytracing/boundary.py
def intersect_w_cylinder(ray, cylinder):\n    \"\"\"\n    Definition to intersect a ray with a cylinder.\n\n    Parameters\n    ----------\n    ray        : ndarray\n                 A ray definition.\n    cylinder   : ndarray\n                 A cylinder defined with a center in XYZ and radius of curvature.\n\n    Returns\n    ----------\n    normal     : ndarray\n                 A ray defining surface normal at the point of intersection.\n    distance   : float\n                 Total optical propagation distance.\n    \"\"\"\n    distance, normal = intersect_parametric(\n        ray,\n        cylinder,\n        cylinder_function,\n        get_cylinder_normal\n    )\n    return normal, distance\n
"},{"location":"odak/raytracing/#odak.raytracing.boundary.intersect_w_sphere","title":"intersect_w_sphere(ray, sphere)","text":"

Definition to intersect a ray with a sphere.

Parameters:

  • ray \u2013
         A ray definition.\n
  • sphere \u2013
         A sphere defined with a center in XYZ and radius of curvature.\n

Returns:

  • normal ( ndarray ) \u2013

    A ray defining surface normal at the point of intersection.

  • distance ( float ) \u2013

    Total optical propagation distance.

Source code in odak/raytracing/boundary.py
def intersect_w_sphere(ray, sphere):\n    \"\"\"\n    Definition to intersect a ray with a sphere.\n\n    Parameters\n    ----------\n    ray        : ndarray\n                 A ray definition.\n    sphere     : ndarray\n                 A sphere defined with a center in XYZ and radius of curvature.\n\n    Returns\n    ----------\n    normal     : ndarray\n                 A ray defining surface normal at the point of intersection.\n    distance   : float\n                 Total optical propagation distance.\n    \"\"\"\n    distance, normal = intersect_parametric(\n        ray,\n        sphere,\n        sphere_function,\n        get_sphere_normal\n    )\n    return normal, distance\n
"},{"location":"odak/raytracing/#odak.raytracing.boundary.intersect_w_surface","title":"intersect_w_surface(ray, points)","text":"

Definition to find the intersection point between a surface and a ray. For more see: http://geomalgorithms.com/a06-_intersect-2.html

Parameters:

  • ray \u2013
           A vector/ray.\n
  • points \u2013
           Set of points in X,Y and Z to define a planar surface.\n

Returns:

  • normal ( ndarray ) \u2013

    Surface normal at the point of intersection.

  • distance ( float ) \u2013

    Distance between the starting point of a ray and its intersection with a planar surface.

Source code in odak/raytracing/boundary.py
def intersect_w_surface(ray, points):\n    \"\"\"\n    Definition to find intersection point inbetween a surface and a ray. For more see: http://geomalgorithms.com/a06-_intersect-2.html\n\n    Parameters\n    ----------\n    ray          : ndarray\n                   A vector/ray.\n    points       : ndarray\n                   Set of points in X,Y and Z to define a planar surface.\n\n    Returns\n    ----------\n    normal       : ndarray\n                   Surface normal at the point of intersection.\n    distance     : float\n                   Distance in between starting point of a ray with it's intersection with a planar surface.\n    \"\"\"\n    points = np.asarray(points)\n    normal = get_triangle_normal(points)\n    if len(ray.shape) == 2:\n        ray = ray.reshape((1, 2, 3))\n    if len(points) == 2:\n        points = points.reshape((1, 3, 3))\n    if len(normal.shape) == 2:\n        normal = normal.reshape((1, 2, 3))\n    f = normal[:, 0]-ray[:, 0]\n    distance = np.dot(normal[:, 1], f.T)/np.dot(normal[:, 1], ray[:, 1].T)\n    n = np.int64(np.amax(np.array([ray.shape[0], normal.shape[0]])))\n    normal = np.zeros((n, 2, 3))\n    normal[:, 0] = ray[:, 0]+distance.T*ray[:, 1]\n    distance = np.abs(distance)\n    if normal.shape[0] == 1:\n        normal = normal.reshape((2, 3))\n        distance = distance.reshape((1))\n    if distance.shape[0] == 1 and len(distance.shape) > 1:\n        distance = distance.reshape((distance.shape[1]))\n    return normal, distance\n
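A usage sketch, assuming define_plane and create_ray_from_two_points are importable from odak.raytracing as the anchors above indicate:

import odak.raytracing as raytracer

plane = raytracer.define_plane([0., 0., 5.])                               # planar patch centered at z = 5
ray = raytracer.create_ray_from_two_points([0., 0., 0.], [0., 0., 1.])
normal, distance = raytracer.intersect_w_surface(ray, plane)
# distance is expected to be 5 and normal[0] the intersection point [0., 0., 5.].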
"},{"location":"odak/raytracing/#odak.raytracing.boundary.intersect_w_triangle","title":"intersect_w_triangle(ray, triangle)","text":"

Definition to find the intersection point of a ray with a triangle. Returns zero for both outputs if the ray doesn't intersect with the given triangle.

Parameters:

  • ray \u2013
           A vector/ray (2 x 3). It can also be a list of rays (n x 2 x 3).\n
  • triangle \u2013
           Set of points in X,Y and Z to define a planar surface. It can also be a list of triangles (m x 3 x 3).\n

Returns:

  • normal ( ndarray ) \u2013

    Surface normal at the point of intersection.

  • distance ( float ) \u2013

    Distance between the starting point of a ray and the intersection point with a given triangle.

Source code in odak/raytracing/boundary.py
def intersect_w_triangle(ray, triangle):\n    \"\"\"\n    Definition to find intersection point of a ray with a triangle. Returns False for each variable if the ray doesn't intersect with a given triangle.\n\n    Parameters\n    ----------\n    ray          : torch.tensor\n                   A vector/ray (2 x 3). It can also be a list of rays (n x 2 x 3).\n    triangle     : torch.tensor\n                   Set of points in X,Y and Z to define a planar surface. It can also be a list of triangles (m x 3 x 3).\n\n    Returns\n    ----------\n    normal       : ndarray\n                   Surface normal at the point of intersection.\n    distance     : float\n                   Distance in between a starting point of a ray and the intersection point with a given triangle.\n    \"\"\"\n    normal, distance = intersect_w_surface(ray, triangle)\n    if is_it_on_triangle(normal[0], triangle[0], triangle[1], triangle[2]) == False:\n        return 0, 0\n    return normal, distance\n
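A usage sketch under the same import assumption; note the miss case returns plain zeros:

import numpy as np
import odak.raytracing as raytracer

triangle = np.array([[-5., -5., 5.], [5., -5., 5.], [0., 5., 5.]])         # triangle in the z = 5 plane
ray = raytracer.create_ray_from_two_points([0., 0., 0.], [0., 0., 1.])
normal, distance = raytracer.intersect_w_triangle(ray, triangle)
# distance is expected to be 5; a ray aimed outside the triangle returns (0, 0).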
"},{"location":"odak/raytracing/#odak.raytracing.boundary.intersection_kernel_for_parametric_surfaces","title":"intersection_kernel_for_parametric_surfaces(distance, ray, parametric_surface, surface_function)","text":"

Definition for the intersection kernel when dealing with parametric surfaces.

Parameters:

  • distance \u2013
                 Distance.\n
  • ray \u2013
                 Ray.\n
  • parametric_surface (ndarray) \u2013
                 Array that defines a parametric surface.\n
  • surface_function \u2013
                 Function to evaluate a point against a parametric surface.\n

Returns:

  • point ( ndarray ) \u2013

    Location in X,Y,Z after propagation.

  • error ( float ) \u2013

    Error.

Source code in odak/raytracing/boundary.py
def intersection_kernel_for_parametric_surfaces(distance, ray, parametric_surface, surface_function):\n    \"\"\"\n    Definition for the intersection kernel when dealing with parametric surfaces.\n\n    Parameters\n    ----------\n    distance           : float\n                         Distance.\n    ray                : ndarray\n                         Ray.\n    parametric_surface : ndarray\n                         Array that defines a parametric surface.\n    surface_function   : ndarray\n                         Function to evaluate a point against a parametric surface.\n\n    Returns\n    ----------\n    point              : ndarray\n                         Location in X,Y,Z after propagation.\n    error              : float\n                         Error.\n    \"\"\"\n    new_ray = propagate_a_ray(ray, distance)\n    if len(new_ray) == 2:\n        new_ray = new_ray.reshape((1, 2, 3))\n    point = new_ray[:, 0]\n    error = surface_function(point, parametric_surface)\n    return error, point\n
"},{"location":"odak/raytracing/#odak.raytracing.boundary.propagate_parametric_intersection_error","title":"propagate_parametric_intersection_error(distance, error)","text":"

Definition to propagate the error in parametric intersection to find the next distance to try.

Parameters:

  • distance \u2013
           List that contains the new and the old distance.\n
  • error \u2013
           List that contains the new and the old error.\n

Returns:

  • distance ( list ) \u2013

    New distance.

  • error ( list ) \u2013

    New error.

Source code in odak/raytracing/boundary.py
def propagate_parametric_intersection_error(distance, error):\n    \"\"\"\n    Definition to propagate the error in parametric intersection to find the next distance to try.\n\n    Parameters\n    ----------\n    distance     : list\n                   List that contains the new and the old distance.\n    error        : list\n                   List that contains the new and the old error.\n\n    Returns\n    ----------\n    distance     : list\n                   New distance.\n    error        : list\n                   New error.\n    \"\"\"\n    new_distance = distance[1]-error[1] * \\\n        (distance[1]-distance[0])/(error[1]-error[0])\n    distance[0] = distance[1]\n    distance[1] = np.abs(new_distance)\n    error[0] = error[1]\n    return distance, error\n
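This update follows the secant method: the next trial distance is distance[1] - error[1] * (distance[1] - distance[0]) / (error[1] - error[0]). A hedged numeric sketch, assuming the helper is reachable from odak.raytracing:

import odak.raytracing as raytracer

distance, error = [1.0, 2.0], [0.5, 0.2]
distance, error = raytracer.propagate_parametric_intersection_error(distance, error)
# distance becomes roughly [2.0, 2.67]; error[0] takes the previous error[1],
# ready to be overwritten by the next surface evaluation.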
"},{"location":"odak/raytracing/#odak.raytracing.boundary.reflect","title":"reflect(input_ray, normal)","text":"

Definition to reflect an incoming ray from a surface defined by a surface normal. Uses the method described in G.H. Spencer and M.V.R.K. Murty, \"General Ray-Tracing Procedure\", 1961.

Parameters:

  • input_ray \u2013
           A vector/ray (2x3). It can also be a list of rays (nx2x3).\n
  • normal \u2013
           A surface normal (2x3). It can also be a list of normals (nx2x3).\n

Returns:

  • output_ray ( ndarray ) \u2013

    Array that contains starting points and cosines of a reflected ray.

Source code in odak/raytracing/boundary.py
def reflect(input_ray, normal):\n    \"\"\" \n    Definition to reflect an incoming ray from a surface defined by a surface normal. Used method described in G.H. Spencer and M.V.R.K. Murty, \"General Ray-Tracing Procedure\", 1961.\n\n    Parameters\n    ----------\n    input_ray    : ndarray\n                   A vector/ray (2x3). It can also be a list of rays (nx2x3).\n    normal       : ndarray\n                   A surface normal (2x3). It also be a list of normals (nx2x3).\n\n    Returns\n    ----------\n    output_ray   : ndarray\n                   Array that contains starting points and cosines of a reflected ray.\n    \"\"\"\n    input_ray = np.asarray(input_ray)\n    normal = np.asarray(normal)\n    if len(input_ray.shape) == 2:\n        input_ray = input_ray.reshape((1, 2, 3))\n    if len(normal.shape) == 2:\n        normal = normal.reshape((1, 2, 3))\n    mu = 1\n    div = normal[:, 1, 0]**2 + normal[:, 1, 1]**2 + normal[:, 1, 2]**2\n    a = mu * (input_ray[:, 1, 0]*normal[:, 1, 0]\n              + input_ray[:, 1, 1]*normal[:, 1, 1]\n              + input_ray[:, 1, 2]*normal[:, 1, 2]) / div\n    n = np.int64(np.amax(np.array([normal.shape[0], input_ray.shape[0]])))\n    output_ray = np.zeros((n, 2, 3))\n    output_ray[:, 0] = normal[:, 0]\n    output_ray[:, 1] = input_ray[:, 1]-2*a*normal[:, 1]\n    if output_ray.shape[0] == 1:\n        output_ray = output_ray.reshape((2, 3))\n    return output_ray\n
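A usage sketch with a hand-built surface normal (hit point plus unit normal), assuming the odak.raytracing import path used above:

import numpy as np
import odak.raytracing as raytracer

ray = raytracer.create_ray_from_two_points([0., 0., 0.], [1., 0., 1.])    # travelling towards +X and +Z
surface_normal = np.array([[5., 0., 5.], [0., 0., -1.]])                  # hit point and unit normal facing the ray
reflected_ray = raytracer.reflect(ray, surface_normal)
# reflected_ray starts at the hit point; the Z component of its direction flips sign while X is preserved.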
"},{"location":"odak/raytracing/#odak.raytracing.primitives.bring_plane_to_origin","title":"bring_plane_to_origin(point, plane, shape=[10.0, 10.0], center=[0.0, 0.0, 0.0], angles=[0.0, 0.0, 0.0], mode='XYZ')","text":"

Definition to bring points back to reference origin with respect to a plane.

Parameters:

  • point \u2013
                 Point(s) to be tested.\n
  • shape \u2013
                 Dimensions of the rectangle along X and Y axes.\n
  • center \u2013
                 Center of the rectangle.\n
  • angles \u2013
                 Rotation angle of the rectangle.\n
  • mode \u2013
                 Rotation mode of the rectangle, for more see odak.tools.rotate_point and odak.tools.rotate_points.\n

Returns:

  • transformed_points ( ndarray ) \u2013

    Point(s) that are brought back to reference origin with respect to given plane.

Source code in odak/raytracing/primitives.py
def bring_plane_to_origin(point, plane, shape=[10., 10.], center=[0., 0., 0.], angles=[0., 0., 0.], mode='XYZ'):\n    \"\"\"\n    Definition to bring points back to reference origin with respect to a plane.\n\n    Parameters\n    ----------\n    point              : ndarray\n                         Point(s) to be tested.\n    shape              : list\n                         Dimensions of the rectangle along X and Y axes.\n    center             : list\n                         Center of the rectangle.\n    angles             : list\n                         Rotation angle of the rectangle.\n    mode               : str\n                         Rotation mode of the rectangle, for more see odak.tools.rotate_point and odak.tools.rotate_points.\n\n    Returns\n    ----------\n    transformed_points : ndarray\n                         Point(s) that are brought back to reference origin with respect to given plane.\n    \"\"\"\n    if point.shape[0] == 3:\n        point = point.reshape((1, 3))\n    reverse_mode = mode[::-1]\n    angles = [-angles[0], -angles[1], -angles[2]]\n    center = np.asarray(center).reshape((1, 3))\n    transformed_points = point-center\n    transformed_points = rotate_points(\n        transformed_points,\n        angles=angles,\n        mode=reverse_mode,\n    )\n    if transformed_points.shape[0] == 1:\n        transformed_points = transformed_points.reshape((3,))\n    return transformed_points\n
"},{"location":"odak/raytracing/#odak.raytracing.primitives.center_of_triangle","title":"center_of_triangle(triangle)","text":"

Definition to calculate center of a triangle.

Parameters:

  • triangle \u2013
            An array that contains three points defining a triangle (Mx3). It can also contain multiple triangles for parallel processing (NxMx3).\n
Source code in odak/raytracing/primitives.py
def center_of_triangle(triangle):\n    \"\"\"\n    Definition to calculate center of a triangle.\n\n    Parameters\n    ----------\n    triangle      : ndarray\n                    An array that contains three points defining a triangle (Mx3). It can also parallel process many triangles (NxMx3).\n    \"\"\"\n    if len(triangle.shape) == 2:\n        triangle = triangle.reshape((1, 3, 3))\n    center = np.mean(triangle, axis=1)\n    return center\n
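A quick numeric check, assuming numpy-style inputs and the same import path:

import numpy as np
import odak.raytracing as raytracer

triangle = np.array([[0., 0., 0.], [9., 0., 0.], [0., 9., 0.]])
center = raytracer.center_of_triangle(triangle)
# center is the mean of the corners, here [[3., 3., 0.]] (a 1 x 3 array).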
"},{"location":"odak/raytracing/#odak.raytracing.primitives.cylinder_function","title":"cylinder_function(point, cylinder)","text":"

Definition of a cylinder function. Evaluates a point against a cylinder function. Inspired by https://mathworld.wolfram.com/Point-LineDistance3-Dimensional.html

Parameters:

  • cylinder \u2013
         Cylinder parameters, XYZ center and radius.\n
  • point \u2013
         Point in XYZ.\n

Returns:

  • result ( float ) \u2013

    Result of the evaluation. Zero if the point is on the cylinder surface.

Source code in odak/raytracing/primitives.py
def cylinder_function(point, cylinder):\n    \"\"\"\n    Definition of a cylinder function. Evaluate a point against a cylinder function. Inspired from https://mathworld.wolfram.com/Point-LineDistance3-Dimensional.html\n\n    Parameters\n    ----------\n    cylinder   : ndarray\n                 Cylinder parameters, XYZ center and radius.\n    point      : ndarray\n                 Point in XYZ.\n\n    Return\n    ----------\n    result     : float\n                 Result of the evaluation. Zero if point is on sphere.\n    \"\"\"\n    point = np.asarray(point)\n    if len(point.shape) == 1:\n        point = point.reshape((1, 3))\n    distance = point_to_ray_distance(\n        point,\n        np.array([cylinder[0], cylinder[1], cylinder[2]], dtype=np.float64),\n        np.array([cylinder[4], cylinder[5], cylinder[6]], dtype=np.float64)\n    )\n    r = cylinder[3]\n    result = distance - r ** 2\n    return result\n
"},{"location":"odak/raytracing/#odak.raytracing.primitives.define_circle","title":"define_circle(center, radius, angles)","text":"

Definition to describe a circle in a single variable packed form.

Parameters:

  • center \u2013
      Center of a circle to be defined.\n
  • radius \u2013
      Radius of a circle to be defined.\n
  • angles \u2013
      Angular tilt of a circle.\n

Returns:

  • circle ( list ) \u2013

    Single variable packed form.

Source code in odak/raytracing/primitives.py
def define_circle(center, radius, angles):\n    \"\"\"\n    Definition to describe a circle in a single variable packed form.\n\n    Parameters\n    ----------\n    center  : float\n              Center of a circle to be defined.\n    radius  : float\n              Radius of a circle to be defined.\n    angles  : float\n              Angular tilt of a circle.\n\n    Returns\n    ----------\n    circle  : list\n              Single variable packed form.\n    \"\"\"\n    points = define_plane(center, angles=angles)\n    circle = [\n        points,\n        center,\n        radius\n    ]\n    return circle\n
"},{"location":"odak/raytracing/#odak.raytracing.primitives.define_cylinder","title":"define_cylinder(center, radius, rotation=[0.0, 0.0, 0.0])","text":"

Definition to define a cylinder.

Parameters:

  • center \u2013
         Center of a cylinder in X,Y,Z.\n
  • radius \u2013
         Radius of a cylinder along X axis.\n
  • rotation \u2013
         Direction angles in degrees for the orientation of a cylinder.\n

Returns:

  • cylinder ( ndarray ) \u2013

    Single variable packed form.

Source code in odak/raytracing/primitives.py
def define_cylinder(center, radius, rotation=[0., 0., 0.]):\n    \"\"\"\n    Definition to define a cylinder\n\n    Parameters\n    ----------\n    center     : ndarray\n                 Center of a cylinder in X,Y,Z.\n    radius     : float\n                 Radius of a cylinder along X axis.\n    rotation   : list\n                 Direction angles in degrees for the orientation of a cylinder.\n\n    Returns\n    ----------\n    cylinder   : ndarray\n                 Single variable packed form.\n    \"\"\"\n    cylinder_ray = create_ray_from_angles(\n        np.asarray(center), np.asarray(rotation))\n    cylinder = np.array(\n        [\n            center[0],\n            center[1],\n            center[2],\n            radius,\n            center[0]+cylinder_ray[1, 0],\n            center[1]+cylinder_ray[1, 1],\n            center[2]+cylinder_ray[1, 2]\n        ],\n        dtype=np.float64\n    )\n    return cylinder\n
"},{"location":"odak/raytracing/#odak.raytracing.primitives.define_plane","title":"define_plane(point, angles=[0.0, 0.0, 0.0])","text":"

Definition to define a plane using a center point and rotation angles.

Parameters:

  • point \u2013
           A point that is at the center of a plane.\n
  • angles \u2013
           Rotation angles in degrees.\n

Returns:

  • plane ( ndarray ) \u2013

    Points defining plane.

Source code in odak/raytracing/primitives.py
def define_plane(point, angles=[0., 0., 0.]):\n    \"\"\" \n    Definition to generate a rotation matrix along X axis.\n\n    Parameters\n    ----------\n    point        : ndarray\n                   A point that is at the center of a plane.\n    angles       : list\n                   Rotation angles in degrees.\n\n    Returns\n    ----------\n    plane        : ndarray\n                   Points defining plane.\n    \"\"\"\n    plane = np.array([\n        [10., 10., 0.],\n        [0., 10., 0.],\n        [0.,  0., 0.]\n    ], dtype=np.float64)\n    point = np.asarray(point)\n    for i in range(0, plane.shape[0]):\n        plane[i], _, _, _ = rotate_point(plane[i], angles=angles)\n        plane[i] = plane[i]+point\n    return plane\n
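A small sketch of a tilted plane, under the same import assumption:

import odak.raytracing as raytracer

plane = raytracer.define_plane([0., 0., 10.], angles = [0., 45., 0.])
# plane is a 3 x 3 array of points spanning the rotated plane, shifted to the given center.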
"},{"location":"odak/raytracing/#odak.raytracing.primitives.define_sphere","title":"define_sphere(center, radius)","text":"

Definition to define a sphere.

Parameters:

  • center \u2013
         Center of a sphere in X,Y,Z.\n
  • radius \u2013
         Radius of a sphere.\n

Returns:

  • sphere ( ndarray ) \u2013

    Single variable packed form.

Source code in odak/raytracing/primitives.py
def define_sphere(center, radius):\n    \"\"\"\n    Definition to define a sphere.\n\n    Parameters\n    ----------\n    center     : ndarray\n                 Center of a sphere in X,Y,Z.\n    radius     : float\n                 Radius of a sphere.\n\n    Returns\n    ----------\n    sphere     : ndarray\n                 Single variable packed form.\n    \"\"\"\n    sphere = np.array(\n        [center[0], center[1], center[2], radius], dtype=np.float64)\n    return sphere\n
"},{"location":"odak/raytracing/#odak.raytracing.primitives.is_it_on_triangle","title":"is_it_on_triangle(pointtocheck, point0, point1, point2)","text":"

Definition to check if a given point is inside a triangle. If the given point is inside a defined triangle, this definition returns True.

Parameters:

  • pointtocheck \u2013
            Point to check.\n
  • point0 \u2013
            First point of a triangle.\n
  • point1 \u2013
            Second point of a triangle.\n
  • point2 \u2013
            Third point of a triangle.\n
Source code in odak/raytracing/primitives.py
def is_it_on_triangle(pointtocheck, point0, point1, point2):\n    \"\"\"\n    Definition to check if a given point is inside a triangle. If the given point is inside a defined triangle, this definition returns True.\n\n    Parameters\n    ----------\n    pointtocheck  : list\n                    Point to check.\n    point0        : list\n                    First point of a triangle.\n    point1        : list\n                    Second point of a triangle.\n    point2        : list\n                    Third point of a triangle.\n    \"\"\"\n    # point0, point1 and point2 are the corners of the triangle.\n    pointtocheck = np.asarray(pointtocheck).reshape(3)\n    point0 = np.asarray(point0)\n    point1 = np.asarray(point1)\n    point2 = np.asarray(point2)\n    side0 = same_side(pointtocheck, point0, point1, point2)\n    side1 = same_side(pointtocheck, point1, point0, point2)\n    side2 = same_side(pointtocheck, point2, point0, point1)\n    if side0 == True and side1 == True and side2 == True:\n        return True\n    return False\n
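A hedged usage sketch, assuming the function is reachable from odak.raytracing:

import numpy as np
import odak.raytracing as raytracer

triangle = np.array([[0., 0., 0.], [9., 0., 0.], [0., 9., 0.]])
inside = raytracer.is_it_on_triangle([3., 3., 0.], triangle[0], triangle[1], triangle[2])
outside = raytracer.is_it_on_triangle([20., 20., 0.], triangle[0], triangle[1], triangle[2])
# inside is expected to be True and outside False.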
"},{"location":"odak/raytracing/#odak.raytracing.primitives.sphere_function","title":"sphere_function(point, sphere)","text":"

Definition of a sphere function. Evaluates a point against a sphere function.

Parameters:

  • sphere \u2013
         Sphere parameters, XYZ center and radius.\n
  • point \u2013
         Point in XYZ.\n

Returns:

  • result ( float ) \u2013

    Result of the evaluation. Zero if the point is on the sphere.

Source code in odak/raytracing/primitives.py
def sphere_function(point, sphere):\n    \"\"\"\n    Definition of a sphere function. Evaluate a point against a sphere function.\n\n    Parameters\n    ----------\n    sphere     : ndarray\n                 Sphere parameters, XYZ center and radius.\n    point      : ndarray\n                 Point in XYZ.\n\n    Return\n    ----------\n    result     : float\n                 Result of the evaluation. Zero if point is on sphere.\n    \"\"\"\n    point = np.asarray(point)\n    if len(point.shape) == 1:\n        point = point.reshape((1, 3))\n    result = (point[:, 0]-sphere[0])**2 + (point[:, 1]-sphere[1]\n                                           )**2 + (point[:, 2]-sphere[2])**2 - sphere[3]**2\n    return result\n
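A numeric sketch of the evaluation, assuming the same odak.raytracing import path:

import odak.raytracing as raytracer

sphere = raytracer.define_sphere([0., 0., 0.], 5.)
on_surface = raytracer.sphere_function([5., 0., 0.], sphere)    # evaluates to 0 (on the surface)
off_surface = raytracer.sphere_function([6., 0., 0.], sphere)   # evaluates to 11 (outside, 6^2 - 5^2)
# Negative values indicate a point inside the sphere.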
"},{"location":"odak/raytracing/#odak.raytracing.ray.calculate_intersection_of_two_rays","title":"calculate_intersection_of_two_rays(ray0, ray1)","text":"

Definition to calculate the intersection of two rays.

Parameters:

  • ray0 \u2013
         A ray.\n
  • ray1 \u2013
         A ray.\n

Returns:

  • point ( ndarray ) \u2013

    Point in X,Y,Z.

  • distances ( ndarray ) \u2013

    Distances.

Source code in odak/raytracing/ray.py
def calculate_intersection_of_two_rays(ray0, ray1):\n    \"\"\"\n    Definition to calculate the intersection of two rays.\n\n    Parameters\n    ----------\n    ray0       : ndarray\n                 A ray.\n    ray1       : ndarray\n                 A ray.\n\n    Returns\n    ----------\n    point      : ndarray\n                 Point in X,Y,Z.\n    distances  : ndarray\n                 Distances.\n    \"\"\"\n    A = np.array([\n        [float(ray0[1][0]), float(ray1[1][0])],\n        [float(ray0[1][1]), float(ray1[1][1])],\n        [float(ray0[1][2]), float(ray1[1][2])]\n    ])\n    B = np.array([\n        ray0[0][0]-ray1[0][0],\n        ray0[0][1]-ray1[0][1],\n        ray0[0][2]-ray1[0][2]\n    ])\n    distances = np.linalg.lstsq(A, B, rcond=None)[0]\n    if np.allclose(np.dot(A, distances), B) == False:\n        distances = np.array([0, 0])\n    distances = distances[np.argsort(-distances)]\n    point = propagate_a_ray(ray0, distances[0])[0]\n    return point, distances\n
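A usage sketch with two rays that actually cross, assuming odak.raytracing exposes the ray helpers:

import odak.raytracing as raytracer

ray0 = raytracer.create_ray_from_two_points([0., 0., 0.], [1., 0., 0.])    # along +X
ray1 = raytracer.create_ray_from_two_points([5., -5., 0.], [5., 0., 0.])   # along +Y
point, distances = raytracer.calculate_intersection_of_two_rays(ray0, ray1)
# point is expected to be [5., 0., 0.]; rays that do not meet yield zero distances.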
"},{"location":"odak/raytracing/#odak.raytracing.ray.create_ray","title":"create_ray(x0y0z0, abg)","text":"

Definition to create a ray.

Parameters:

  • x0y0z0 \u2013
           List that contains X,Y and Z start locations of a ray.\n
  • abg \u2013
            List that contains angles in degrees with respect to the X,Y and Z axes.\n

Returns:

  • ray ( ndarray ) \u2013

    Array that contains starting points and cosines of a created ray.

Source code in odak/raytracing/ray.py
def create_ray(x0y0z0, abg):\n    \"\"\"\n    Definition to create a ray.\n\n    Parameters\n    ----------\n    x0y0z0       : list\n                   List that contains X,Y and Z start locations of a ray.\n    abg          : list\n                   List that contaings angles in degrees with respect to the X,Y and Z axes.\n\n    Returns\n    ----------\n    ray          : ndarray\n                   Array that contains starting points and cosines of a created ray.\n    \"\"\"\n    # Due to Python 2 -> Python 3.\n    x0, y0, z0 = x0y0z0\n    alpha, beta, gamma = abg\n    # Create a vector with the given points and angles in each direction\n    point = np.array([x0, y0, z0], dtype=np.float64)\n    alpha = np.cos(np.radians(alpha))\n    beta = np.cos(np.radians(beta))\n    gamma = np.cos(np.radians(gamma))\n    # Cosines vector.\n    cosines = np.array([alpha, beta, gamma], dtype=np.float64)\n    ray = np.array([point, cosines], dtype=np.float64)\n    return ray\n
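A short example of the angle convention (the direction entries are cosines of the given angles), under the same import assumption:

import odak.raytracing as raytracer

ray = raytracer.create_ray([0., 0., 0.], [90., 90., 0.])
# ray[0] is the start point and ray[1] the direction cosines:
# cos(90), cos(90), cos(0), which is effectively [0., 0., 1.], so the ray travels along +Z.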
"},{"location":"odak/raytracing/#odak.raytracing.ray.create_ray_from_angles","title":"create_ray_from_angles(point, angles, mode='XYZ')","text":"

Definition to create a ray from a point and angles.

Parameters:

  • point \u2013
         Point in X,Y and Z.\n
  • angles \u2013
         Angles with X,Y,Z axes in degrees. All zeros point along the Z axis.\n
  • mode \u2013
         Rotation mode determines the ordering of the rotations at each axis. There are XYZ, YXZ, ZXY and ZYX modes.\n

Returns:

  • ray ( ndarray ) \u2013

    Created ray.

Source code in odak/raytracing/ray.py
def create_ray_from_angles(point, angles, mode='XYZ'):\n    \"\"\"\n    Definition to create a ray from a point and angles.\n\n    Parameters\n    ----------\n    point      : ndarray\n                 Point in X,Y and Z.\n    angles     : ndarray\n                 Angles with X,Y,Z axes in degrees. All zeros point Z axis.\n    mode       : str\n                 Rotation mode determines ordering of the rotations at each axis. There are XYZ,YXZ    ,ZXY and ZYX modes.\n\n    Returns\n    ----------\n    ray        : ndarray\n                 Created ray.\n    \"\"\"\n    if len(point.shape) == 1:\n        point = point.reshape((1, 3))\n    new_point = np.zeros(point.shape)\n    new_point[:, 2] += 5.\n    new_point = rotate_points(new_point, angles, mode=mode, offset=point[:, 0])\n    ray = create_ray_from_two_points(point, new_point)\n    if ray.shape[0] == 1:\n        ray = ray.reshape((2, 3))\n    return ray\n
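A hedged sketch; with all-zero angles the ray points along +Z from the given point:

import numpy as np
import odak.raytracing as raytracer

ray = raytracer.create_ray_from_angles(np.array([0., 0., 0.]), np.array([0., 0., 0.]))
# ray[0] is the start point and ray[1] is approximately [0., 0., 1.].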
"},{"location":"odak/raytracing/#odak.raytracing.ray.create_ray_from_two_points","title":"create_ray_from_two_points(x0y0z0, x1y1z1)","text":"

Definition to create a ray from two given points. Note that both inputs must match in shape.

Parameters:

  • x0y0z0 \u2013
            List that contains X,Y and Z start locations of a ray (3). It can also be a list of points (mx3). This is the starting point.\n
  • x1y1z1 \u2013
            List that contains X,Y and Z ending locations of a ray (3). It can also be a list of points (mx3). This is the end point.\n

Returns:

  • ray ( ndarray ) \u2013

    Array that contains starting points and cosines of a created ray.

Source code in odak/raytracing/ray.py
def create_ray_from_two_points(x0y0z0, x1y1z1):\n    \"\"\"\n    Definition to create a ray from two given points. Note that both inputs must match in shape.\n\n    Parameters\n    ----------\n    x0y0z0       : list\n                   List that contains X,Y and Z start locations of a ray (3). It can also be a list of points as well (mx3). This is the starting point.\n    x1y1z1       : list\n                   List that contains X,Y and Z ending locations of a ray (3). It can also be a list of points as well (mx3). This is the end point.\n\n    Returns\n    ----------\n    ray          : ndarray\n                   Array that contains starting points and cosines of a created ray.\n    \"\"\"\n    x0y0z0 = np.asarray(x0y0z0, dtype=np.float64)\n    x1y1z1 = np.asarray(x1y1z1, dtype=np.float64)\n    if len(x0y0z0.shape) == 1:\n        x0y0z0 = x0y0z0.reshape((1, 3))\n    if len(x1y1z1.shape) == 1:\n        x1y1z1 = x1y1z1.reshape((1, 3))\n    xdiff = x1y1z1[:, 0] - x0y0z0[:, 0]\n    ydiff = x1y1z1[:, 1] - x0y0z0[:, 1]\n    zdiff = x1y1z1[:, 2] - x0y0z0[:, 2]\n    s = np.sqrt(xdiff ** 2 + ydiff ** 2 + zdiff ** 2)\n    s[s == 0] = np.nan\n    cosines = np.zeros((xdiff.shape[0], 3))\n    cosines[:, 0] = xdiff/s\n    cosines[:, 1] = ydiff/s\n    cosines[:, 2] = zdiff/s\n    ray = np.zeros((xdiff.shape[0], 2, 3), dtype=np.float64)\n    ray[:, 0] = x0y0z0\n    ray[:, 1] = cosines\n    if ray.shape[0] == 1:\n        ray = ray.reshape((2, 3))\n    return ray\n
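A batched sketch showing the (m x 3) input form, under the same import assumption:

import numpy as np
import odak.raytracing as raytracer

starts = np.zeros((4, 3))
ends = np.zeros((4, 3))
ends[:, 2] = 10.
rays = raytracer.create_ray_from_two_points(starts, ends)
# rays has shape (4, 2, 3): four rays from the origin travelling along +Z.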
"},{"location":"odak/raytracing/#odak.raytracing.ray.find_nearest_points","title":"find_nearest_points(ray0, ray1)","text":"

Find the nearest points on given rays with respect to the other ray.

Parameters:

  • ray0 \u2013
         A ray.\n
  • ray1 \u2013
         A ray.\n

Returns:

  • c0 ( ndarray ) \u2013

    Closest point on ray0.

  • c1 ( ndarray ) \u2013

    Closest point on ray1.

Source code in odak/raytracing/ray.py
def find_nearest_points(ray0, ray1):\n    \"\"\"\n    Find the nearest points on given rays with respect to the other ray.\n\n    Parameters\n    ----------\n    ray0       : ndarray\n                 A ray.\n    ray1       : ndarray\n                 A ray.\n\n    Returns\n    ----------\n    c0         : ndarray\n                 Closest point on ray0.\n    c1         : ndarray\n                 Closest point on ray1.\n    \"\"\"\n    p0 = ray0[0].reshape(3,)\n    d0 = ray0[1].reshape(3,)\n    p1 = ray1[0].reshape(3,)\n    d1 = ray1[1].reshape(3,)\n    n = np.cross(d0, d1)\n    if np.all(n) == 0:\n        point, distances = calculate_intersection_of_two_rays(ray0, ray1)\n        c0 = c1 = point\n    else:\n        n0 = np.cross(d0, n)\n        n1 = np.cross(d1, n)\n        c0 = p0+(np.dot((p1-p0), n1)/np.dot(d0, n1))*d0\n        c1 = p1+(np.dot((p0-p1), n0)/np.dot(d1, n0))*d1\n    return c0, c1\n
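A sketch with two skew rays, assuming the same odak.raytracing import path:

import odak.raytracing as raytracer

ray0 = raytracer.create_ray_from_two_points([0., 0., 0.], [1., 1., 0.])    # in the z = 0 plane
ray1 = raytracer.create_ray_from_two_points([0., 0., 2.], [0., 1., 3.])    # skew to the first ray
c0, c1 = raytracer.find_nearest_points(ray0, ray1)
# c0 lies on the first line and c1 on the second; their separation is the line-to-line distance.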
"},{"location":"odak/raytracing/#odak.raytracing.ray.propagate_a_ray","title":"propagate_a_ray(ray, distance)","text":"

Definition to propagate a ray by a given distance.

Parameters:

  • ray \u2013
         A ray.\n
  • distance \u2013
         Distance.\n

Returns:

  • new_ray ( ndarray ) \u2013

    Propagated ray.

Source code in odak/raytracing/ray.py
def propagate_a_ray(ray, distance):\n    \"\"\"\n    Definition to propagate a ray at a certain given distance.\n\n    Parameters\n    ----------\n    ray        : ndarray\n                 A ray.\n    distance   : float\n                 Distance.\n\n    Returns\n    ----------\n    new_ray    : ndarray\n                 Propagated ray.\n    \"\"\"\n    if len(ray.shape) == 2:\n        ray = ray.reshape((1, 2, 3))\n    new_ray = np.copy(ray)\n    new_ray[:, 0, 0] = distance*new_ray[:, 1, 0] + new_ray[:, 0, 0]\n    new_ray[:, 0, 1] = distance*new_ray[:, 1, 1] + new_ray[:, 0, 1]\n    new_ray[:, 0, 2] = distance*new_ray[:, 1, 2] + new_ray[:, 0, 2]\n    if new_ray.shape[0] == 1:\n        new_ray = new_ray.reshape((2, 3))\n    return new_ray\n
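A minimal sketch, under the same import assumption:

import odak.raytracing as raytracer

ray = raytracer.create_ray_from_two_points([0., 0., 0.], [0., 0., 1.])
new_ray = raytracer.propagate_a_ray(ray, 10.)
# new_ray[0] is expected to be [0., 0., 10.] while new_ray[1] keeps the original direction cosines.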
"},{"location":"odak/tools/","title":"odak.tools","text":"

odak.tools

Provides necessary definitions for general tools used across the library.

A class to work with latex documents.

Source code in odak/tools/latex.py
class latex():\n    \"\"\"\n    A class to work with latex documents.\n    \"\"\"\n    def __init__(\n                 self,\n                 filename\n                ):\n        \"\"\"\n        Parameters\n        ----------\n        filename     : str\n                       Source filename (i.e. sample.tex).\n        \"\"\"\n        self.filename = filename\n        self.content = read_text_file(self.filename)\n        self.content_type = []\n        self.latex_dictionary = [\n                                 '\\\\documentclass',\n                                 '\\\\if',\n                                 '\\\\pdf',\n                                 '\\\\else',\n                                 '\\\\fi',\n                                 '\\\\vgtc',\n                                 '\\\\teaser',\n                                 '\\\\abstract',\n                                 '\\\\CCS',\n                                 '\\\\usepackage',\n                                 '\\\\PassOptionsToPackage',\n                                 '\\\\definecolor',\n                                 '\\\\AtBeginDocument',\n                                 '\\\\providecommand',\n                                 '\\\\setcopyright',\n                                 '\\\\copyrightyear',\n                                 '\\\\acmYear',\n                                 '\\\\citestyle',\n                                 '\\\\newcommand',\n                                 '\\\\acmDOI',\n                                 '\\\\newabbreviation',\n                                 '\\\\global',\n                                 '\\\\begin{document}',\n                                 '\\\\author',\n                                 '\\\\affiliation',\n                                 '\\\\email',\n                                 '\\\\institution',\n                                 '\\\\streetaddress',\n                                 '\\\\city',\n                                 '\\\\country',\n                                 '\\\\postcode',\n                                 '\\\\ccsdesc',\n                                 '\\\\received',\n                                 '\\\\includegraphics',\n                                 '\\\\caption',\n                                 '\\\\centering',\n                                 '\\\\label',\n                                 '\\\\maketitle',\n                                 '\\\\toprule',\n                                 '\\\\multirow',\n                                 '\\\\multicolumn',\n                                 '\\\\cmidrule',\n                                 '\\\\addlinespace',\n                                 '\\\\midrule',\n                                 '\\\\cellcolor',\n                                 '\\\\bibliography',\n                                 '}',\n                                 '\\\\title',\n                                 '</ccs2012>',\n                                 '\\\\bottomrule',\n                                 '<concept>',\n                                 '<concept',\n                                 '<ccs',\n                                 '\\\\item',\n                                 '</concept',\n                                 '\\\\begin{abstract}',\n                                 '\\\\end{abstract}',\n                                 '\\\\endinput',\n                                 '\\\\\\\\'\n                                ]\n        self.latex_begin_dictionary = [\n                                
       '\\\\begin{figure}',\n                                       '\\\\begin{figure*}',\n                                       '\\\\begin{equation}',\n                                       '\\\\begin{CCSXML}',\n                                       '\\\\begin{teaserfigure}',\n                                       '\\\\begin{table*}',\n                                       '\\\\begin{table}',\n                                       '\\\\begin{gather}',\n                                       '\\\\begin{align}',\n                                      ]\n        self.latex_end_dictionary = [\n                                     '\\\\end{figure}',\n                                     '\\\\end{figure*}',\n                                     '\\\\end{equation}',\n                                     '\\\\end{CCSXML}',\n                                     '\\\\end{teaserfigure}',\n                                     '\\\\end{table*}',\n                                     '\\\\end{table}',\n                                     '\\\\end{gather}',\n                                     '\\\\end{align}',\n                                    ]\n        self._label_lines()\n\n\n    def set_latex_dictonaries(self, begin_dictionary, end_dictionary, syntax_dictionary):\n        \"\"\"\n        Set document specific dictionaries so that the lines could be labelled in accordance.\n\n\n        Parameters\n        ----------\n        begin_dictionary     : list\n                               Pythonic list containing latex syntax for begin commands (i.e. \\\\begin{align}).\n        end_dictionary       : list\n                               Pythonic list containing latex syntax for end commands (i.e. \\\\end{table}).\n        syntax_dictionary    : list\n                               Pythonic list containing latex syntax (i.e. 
\\\\item).\n\n        \"\"\"\n        self.latex_begin_dictionary = begin_dictionary\n        self.latex_end_dictionary = end_dictionary\n        self.latex_dictionary = syntax_dictionary\n        self._label_lines\n\n\n    def _label_lines(self):\n        \"\"\"\n        Internal function for labelling lines.\n        \"\"\"\n        content_type_flag = False\n        for line_id, line in enumerate(self.content):\n            while len(line) > 0 and line[0] == ' ':\n                 line = line[1::]\n            self.content[line_id] = line\n            if len(line) == 0:\n                content_type = 'empty'\n            elif line[0] == '%':\n                content_type = 'comment'\n            else:\n                content_type = 'text'\n            for syntax in self.latex_begin_dictionary:\n                if line.find(syntax) != -1:\n                    content_type_flag = True\n                    content_type = 'latex'\n            for syntax in self.latex_dictionary:\n                if line.find(syntax) != -1:\n                    content_type = 'latex'\n            if content_type_flag == True:\n                content_type = 'latex'\n                for syntax in self.latex_end_dictionary:\n                    if line.find(syntax) != -1:\n                         content_type_flag = False\n            self.content_type.append(content_type)\n\n\n    def get_line_count(self):\n        \"\"\"\n        Definition to get the line count.\n\n\n        Returns\n        -------\n        line_count     : int\n                         Number of lines in the loaded latex document.\n        \"\"\"\n        self.line_count = len(self.content)\n        return self.line_count\n\n\n    def get_line(self, line_id = 0):\n        \"\"\"\n        Definition to get a specific line by inputting a line nunber.\n\n\n        Returns\n        ----------\n        line           : str\n                         Requested line.\n        content_type   : str\n                         Line's content type (e.g., latex, comment, text).\n        \"\"\"\n        line = self.content[line_id]\n        content_type = self.content_type[line_id]\n        return line, content_type\n

A class to work with markdown documents.

Source code in odak/tools/markdown.py
class markdown():\n    \"\"\"\n    A class to work with markdown documents.\n    \"\"\"\n    def __init__(\n                 self,\n                 filename\n                ):\n        \"\"\"\n        Parameters\n        ----------\n        filename     : str\n                       Source filename (i.e. sample.md).\n        \"\"\"\n        self.filename = filename\n        self.content = read_text_file(self.filename)\n        self.content_type = []\n        self.markdown_dictionary = [\n                                     '#',\n                                   ]\n        self.markdown_begin_dictionary = [\n                                          '```bash',\n                                          '```python',\n                                          '```',\n                                         ]\n        self.markdown_end_dictionary = [\n                                        '```',\n                                       ]\n        self._label_lines()\n\n\n    def set_dictonaries(self, begin_dictionary, end_dictionary, syntax_dictionary):\n        \"\"\"\n        Set document specific dictionaries so that the lines could be labelled in accordance.\n\n\n        Parameters\n        ----------\n        begin_dictionary     : list\n                               Pythonic list containing markdown syntax for beginning of blocks (e.g., code, html).\n        end_dictionary       : list\n                               Pythonic list containing markdown syntax for end of blocks (e.g., code, html).\n        syntax_dictionary    : list\n                               Pythonic list containing markdown syntax (i.e. \\\\item).\n\n        \"\"\"\n        self.markdown_begin_dictionary = begin_dictionary\n        self.markdown_end_dictionary = end_dictionary\n        self.markdown_dictionary = syntax_dictionary\n        self._label_lines\n\n\n    def _label_lines(self):\n        \"\"\"\n        Internal function for labelling lines.\n        \"\"\"\n        content_type_flag = False\n        for line_id, line in enumerate(self.content):\n            while len(line) > 0 and line[0] == ' ':\n                 line = line[1::]\n            self.content[line_id] = line\n            if len(line) == 0:\n                content_type = 'empty'\n            elif line[0] == '%':\n                content_type = 'comment'\n            else:\n                content_type = 'text'\n            for syntax in self.markdown_begin_dictionary:\n                if line.find(syntax) != -1:\n                    content_type_flag = True\n                    content_type = 'markdown'\n            for syntax in self.markdown_dictionary:\n                if line.find(syntax) != -1:\n                    content_type = 'markdown'\n            if content_type_flag == True:\n                content_type = 'markdown'\n                for syntax in self.markdown_end_dictionary:\n                    if line.find(syntax) != -1:\n                         content_type_flag = False\n            self.content_type.append(content_type)\n\n\n    def get_line_count(self):\n        \"\"\"\n        Definition to get the line count.\n\n\n        Returns\n        -------\n        line_count     : int\n                         Number of lines in the loaded markdown document.\n        \"\"\"\n        self.line_count = len(self.content)\n        return self.line_count\n\n\n    def get_line(self, line_id = 0):\n        \"\"\"\n        Definition to get a specific line by inputting a line nunber.\n\n\n        Returns\n        ----------\n   
     line           : str\n                         Requested line.\n        content_type   : str\n                         Line's content type (e.g., markdown, comment, text).\n        \"\"\"\n        line = self.content[line_id]\n        content_type = self.content_type[line_id]\n        return line, content_type\n
"},{"location":"odak/tools/#odak.tools.latex","title":"latex","text":"

A class to work with latex documents.

Source code in odak/tools/latex.py
class latex():\n    \"\"\"\n    A class to work with latex documents.\n    \"\"\"\n    def __init__(\n                 self,\n                 filename\n                ):\n        \"\"\"\n        Parameters\n        ----------\n        filename     : str\n                       Source filename (i.e. sample.tex).\n        \"\"\"\n        self.filename = filename\n        self.content = read_text_file(self.filename)\n        self.content_type = []\n        self.latex_dictionary = [\n                                 '\\\\documentclass',\n                                 '\\\\if',\n                                 '\\\\pdf',\n                                 '\\\\else',\n                                 '\\\\fi',\n                                 '\\\\vgtc',\n                                 '\\\\teaser',\n                                 '\\\\abstract',\n                                 '\\\\CCS',\n                                 '\\\\usepackage',\n                                 '\\\\PassOptionsToPackage',\n                                 '\\\\definecolor',\n                                 '\\\\AtBeginDocument',\n                                 '\\\\providecommand',\n                                 '\\\\setcopyright',\n                                 '\\\\copyrightyear',\n                                 '\\\\acmYear',\n                                 '\\\\citestyle',\n                                 '\\\\newcommand',\n                                 '\\\\acmDOI',\n                                 '\\\\newabbreviation',\n                                 '\\\\global',\n                                 '\\\\begin{document}',\n                                 '\\\\author',\n                                 '\\\\affiliation',\n                                 '\\\\email',\n                                 '\\\\institution',\n                                 '\\\\streetaddress',\n                                 '\\\\city',\n                                 '\\\\country',\n                                 '\\\\postcode',\n                                 '\\\\ccsdesc',\n                                 '\\\\received',\n                                 '\\\\includegraphics',\n                                 '\\\\caption',\n                                 '\\\\centering',\n                                 '\\\\label',\n                                 '\\\\maketitle',\n                                 '\\\\toprule',\n                                 '\\\\multirow',\n                                 '\\\\multicolumn',\n                                 '\\\\cmidrule',\n                                 '\\\\addlinespace',\n                                 '\\\\midrule',\n                                 '\\\\cellcolor',\n                                 '\\\\bibliography',\n                                 '}',\n                                 '\\\\title',\n                                 '</ccs2012>',\n                                 '\\\\bottomrule',\n                                 '<concept>',\n                                 '<concept',\n                                 '<ccs',\n                                 '\\\\item',\n                                 '</concept',\n                                 '\\\\begin{abstract}',\n                                 '\\\\end{abstract}',\n                                 '\\\\endinput',\n                                 '\\\\\\\\'\n                                ]\n        self.latex_begin_dictionary = [\n                                
       '\\\\begin{figure}',\n                                       '\\\\begin{figure*}',\n                                       '\\\\begin{equation}',\n                                       '\\\\begin{CCSXML}',\n                                       '\\\\begin{teaserfigure}',\n                                       '\\\\begin{table*}',\n                                       '\\\\begin{table}',\n                                       '\\\\begin{gather}',\n                                       '\\\\begin{align}',\n                                      ]\n        self.latex_end_dictionary = [\n                                     '\\\\end{figure}',\n                                     '\\\\end{figure*}',\n                                     '\\\\end{equation}',\n                                     '\\\\end{CCSXML}',\n                                     '\\\\end{teaserfigure}',\n                                     '\\\\end{table*}',\n                                     '\\\\end{table}',\n                                     '\\\\end{gather}',\n                                     '\\\\end{align}',\n                                    ]\n        self._label_lines()\n\n\n    def set_latex_dictonaries(self, begin_dictionary, end_dictionary, syntax_dictionary):\n        \"\"\"\n        Set document specific dictionaries so that the lines could be labelled in accordance.\n\n\n        Parameters\n        ----------\n        begin_dictionary     : list\n                               Pythonic list containing latex syntax for begin commands (i.e. \\\\begin{align}).\n        end_dictionary       : list\n                               Pythonic list containing latex syntax for end commands (i.e. \\\\end{table}).\n        syntax_dictionary    : list\n                               Pythonic list containing latex syntax (i.e. 
\\\\item).\n\n        \"\"\"\n        self.latex_begin_dictionary = begin_dictionary\n        self.latex_end_dictionary = end_dictionary\n        self.latex_dictionary = syntax_dictionary\n        self._label_lines\n\n\n    def _label_lines(self):\n        \"\"\"\n        Internal function for labelling lines.\n        \"\"\"\n        content_type_flag = False\n        for line_id, line in enumerate(self.content):\n            while len(line) > 0 and line[0] == ' ':\n                 line = line[1::]\n            self.content[line_id] = line\n            if len(line) == 0:\n                content_type = 'empty'\n            elif line[0] == '%':\n                content_type = 'comment'\n            else:\n                content_type = 'text'\n            for syntax in self.latex_begin_dictionary:\n                if line.find(syntax) != -1:\n                    content_type_flag = True\n                    content_type = 'latex'\n            for syntax in self.latex_dictionary:\n                if line.find(syntax) != -1:\n                    content_type = 'latex'\n            if content_type_flag == True:\n                content_type = 'latex'\n                for syntax in self.latex_end_dictionary:\n                    if line.find(syntax) != -1:\n                         content_type_flag = False\n            self.content_type.append(content_type)\n\n\n    def get_line_count(self):\n        \"\"\"\n        Definition to get the line count.\n\n\n        Returns\n        -------\n        line_count     : int\n                         Number of lines in the loaded latex document.\n        \"\"\"\n        self.line_count = len(self.content)\n        return self.line_count\n\n\n    def get_line(self, line_id = 0):\n        \"\"\"\n        Definition to get a specific line by inputting a line nunber.\n\n\n        Returns\n        ----------\n        line           : str\n                         Requested line.\n        content_type   : str\n                         Line's content type (e.g., latex, comment, text).\n        \"\"\"\n        line = self.content[line_id]\n        content_type = self.content_type[line_id]\n        return line, content_type\n
"},{"location":"odak/tools/#odak.tools.latex.__init__","title":"__init__(filename)","text":"

Parameters:

  • filename \u2013
           Source filename (e.g., sample.tex).\n
Source code in odak/tools/latex.py
def __init__(\n             self,\n             filename\n            ):\n    \"\"\"\n    Parameters\n    ----------\n    filename     : str\n                   Source filename (i.e. sample.tex).\n    \"\"\"\n    self.filename = filename\n    self.content = read_text_file(self.filename)\n    self.content_type = []\n    self.latex_dictionary = [\n                             '\\\\documentclass',\n                             '\\\\if',\n                             '\\\\pdf',\n                             '\\\\else',\n                             '\\\\fi',\n                             '\\\\vgtc',\n                             '\\\\teaser',\n                             '\\\\abstract',\n                             '\\\\CCS',\n                             '\\\\usepackage',\n                             '\\\\PassOptionsToPackage',\n                             '\\\\definecolor',\n                             '\\\\AtBeginDocument',\n                             '\\\\providecommand',\n                             '\\\\setcopyright',\n                             '\\\\copyrightyear',\n                             '\\\\acmYear',\n                             '\\\\citestyle',\n                             '\\\\newcommand',\n                             '\\\\acmDOI',\n                             '\\\\newabbreviation',\n                             '\\\\global',\n                             '\\\\begin{document}',\n                             '\\\\author',\n                             '\\\\affiliation',\n                             '\\\\email',\n                             '\\\\institution',\n                             '\\\\streetaddress',\n                             '\\\\city',\n                             '\\\\country',\n                             '\\\\postcode',\n                             '\\\\ccsdesc',\n                             '\\\\received',\n                             '\\\\includegraphics',\n                             '\\\\caption',\n                             '\\\\centering',\n                             '\\\\label',\n                             '\\\\maketitle',\n                             '\\\\toprule',\n                             '\\\\multirow',\n                             '\\\\multicolumn',\n                             '\\\\cmidrule',\n                             '\\\\addlinespace',\n                             '\\\\midrule',\n                             '\\\\cellcolor',\n                             '\\\\bibliography',\n                             '}',\n                             '\\\\title',\n                             '</ccs2012>',\n                             '\\\\bottomrule',\n                             '<concept>',\n                             '<concept',\n                             '<ccs',\n                             '\\\\item',\n                             '</concept',\n                             '\\\\begin{abstract}',\n                             '\\\\end{abstract}',\n                             '\\\\endinput',\n                             '\\\\\\\\'\n                            ]\n    self.latex_begin_dictionary = [\n                                   '\\\\begin{figure}',\n                                   '\\\\begin{figure*}',\n                                   '\\\\begin{equation}',\n                                   '\\\\begin{CCSXML}',\n                                   '\\\\begin{teaserfigure}',\n                                   '\\\\begin{table*}',\n                                   '\\\\begin{table}',\n        
                           '\\\\begin{gather}',\n                                   '\\\\begin{align}',\n                                  ]\n    self.latex_end_dictionary = [\n                                 '\\\\end{figure}',\n                                 '\\\\end{figure*}',\n                                 '\\\\end{equation}',\n                                 '\\\\end{CCSXML}',\n                                 '\\\\end{teaserfigure}',\n                                 '\\\\end{table*}',\n                                 '\\\\end{table}',\n                                 '\\\\end{gather}',\n                                 '\\\\end{align}',\n                                ]\n    self._label_lines()\n
"},{"location":"odak/tools/#odak.tools.latex.get_line","title":"get_line(line_id=0)","text":"

Definition to get a specific line by inputting a line number.

Returns:

  • line ( str ) \u2013

    Requested line.

  • content_type ( str ) \u2013

    Line's content type (e.g., latex, comment, text).

Source code in odak/tools/latex.py
def get_line(self, line_id = 0):\n    \"\"\"\n    Definition to get a specific line by inputting a line nunber.\n\n\n    Returns\n    ----------\n    line           : str\n                     Requested line.\n    content_type   : str\n                     Line's content type (e.g., latex, comment, text).\n    \"\"\"\n    line = self.content[line_id]\n    content_type = self.content_type[line_id]\n    return line, content_type\n
"},{"location":"odak/tools/#odak.tools.latex.get_line_count","title":"get_line_count()","text":"

Definition to get the line count.

Returns:

  • line_count ( int ) \u2013

    Number of lines in the loaded latex document.

Source code in odak/tools/latex.py
def get_line_count(self):\n    \"\"\"\n    Definition to get the line count.\n\n\n    Returns\n    -------\n    line_count     : int\n                     Number of lines in the loaded latex document.\n    \"\"\"\n    self.line_count = len(self.content)\n    return self.line_count\n
"},{"location":"odak/tools/#odak.tools.latex.set_latex_dictonaries","title":"set_latex_dictonaries(begin_dictionary, end_dictionary, syntax_dictionary)","text":"

Set document-specific dictionaries so that the lines can be labelled accordingly.

Parameters:

  • begin_dictionary \u2013
                   Pythonic list containing latex syntax for begin commands (i.e. \\begin{align}).\n
  • end_dictionary \u2013
                   Pythonic list containing latex syntax for end commands (i.e. \\end{table}).\n
  • syntax_dictionary \u2013
                   Pythonic list containing latex syntax (i.e. \\item).\n
Source code in odak/tools/latex.py
def set_latex_dictonaries(self, begin_dictionary, end_dictionary, syntax_dictionary):\n    \"\"\"\n    Set document specific dictionaries so that the lines could be labelled in accordance.\n\n\n    Parameters\n    ----------\n    begin_dictionary     : list\n                           Pythonic list containing latex syntax for begin commands (i.e. \\\\begin{align}).\n    end_dictionary       : list\n                           Pythonic list containing latex syntax for end commands (i.e. \\\\end{table}).\n    syntax_dictionary    : list\n                           Pythonic list containing latex syntax (i.e. \\\\item).\n\n    \"\"\"\n    self.latex_begin_dictionary = begin_dictionary\n    self.latex_end_dictionary = end_dictionary\n    self.latex_dictionary = syntax_dictionary\n    self._label_lines\n
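A hedged usage sketch for the latex helper; the filename below is hypothetical and the class is assumed to be importable as odak.tools.latex:

from odak.tools import latex

document = latex('sample.tex')               # hypothetical input file in the working directory
print(document.get_line_count())             # number of lines in the loaded document
line, content_type = document.get_line(0)
print(line, content_type)                    # the first line with its label ('latex', 'comment', 'text' or 'empty')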
"},{"location":"odak/tools/#odak.tools.markdown","title":"markdown","text":"

A class to work with markdown documents.

Source code in odak/tools/markdown.py
class markdown():\n    \"\"\"\n    A class to work with markdown documents.\n    \"\"\"\n    def __init__(\n                 self,\n                 filename\n                ):\n        \"\"\"\n        Parameters\n        ----------\n        filename     : str\n                       Source filename (i.e. sample.md).\n        \"\"\"\n        self.filename = filename\n        self.content = read_text_file(self.filename)\n        self.content_type = []\n        self.markdown_dictionary = [\n                                     '#',\n                                   ]\n        self.markdown_begin_dictionary = [\n                                          '```bash',\n                                          '```python',\n                                          '```',\n                                         ]\n        self.markdown_end_dictionary = [\n                                        '```',\n                                       ]\n        self._label_lines()\n\n\n    def set_dictonaries(self, begin_dictionary, end_dictionary, syntax_dictionary):\n        \"\"\"\n        Set document specific dictionaries so that the lines could be labelled in accordance.\n\n\n        Parameters\n        ----------\n        begin_dictionary     : list\n                               Pythonic list containing markdown syntax for beginning of blocks (e.g., code, html).\n        end_dictionary       : list\n                               Pythonic list containing markdown syntax for end of blocks (e.g., code, html).\n        syntax_dictionary    : list\n                               Pythonic list containing markdown syntax (i.e. \\\\item).\n\n        \"\"\"\n        self.markdown_begin_dictionary = begin_dictionary\n        self.markdown_end_dictionary = end_dictionary\n        self.markdown_dictionary = syntax_dictionary\n        self._label_lines\n\n\n    def _label_lines(self):\n        \"\"\"\n        Internal function for labelling lines.\n        \"\"\"\n        content_type_flag = False\n        for line_id, line in enumerate(self.content):\n            while len(line) > 0 and line[0] == ' ':\n                 line = line[1::]\n            self.content[line_id] = line\n            if len(line) == 0:\n                content_type = 'empty'\n            elif line[0] == '%':\n                content_type = 'comment'\n            else:\n                content_type = 'text'\n            for syntax in self.markdown_begin_dictionary:\n                if line.find(syntax) != -1:\n                    content_type_flag = True\n                    content_type = 'markdown'\n            for syntax in self.markdown_dictionary:\n                if line.find(syntax) != -1:\n                    content_type = 'markdown'\n            if content_type_flag == True:\n                content_type = 'markdown'\n                for syntax in self.markdown_end_dictionary:\n                    if line.find(syntax) != -1:\n                         content_type_flag = False\n            self.content_type.append(content_type)\n\n\n    def get_line_count(self):\n        \"\"\"\n        Definition to get the line count.\n\n\n        Returns\n        -------\n        line_count     : int\n                         Number of lines in the loaded markdown document.\n        \"\"\"\n        self.line_count = len(self.content)\n        return self.line_count\n\n\n    def get_line(self, line_id = 0):\n        \"\"\"\n        Definition to get a specific line by inputting a line nunber.\n\n\n        Returns\n        ----------\n   
     line           : str\n                         Requested line.\n        content_type   : str\n                         Line's content type (e.g., markdown, comment, text).\n        \"\"\"\n        line = self.content[line_id]\n        content_type = self.content_type[line_id]\n        return line, content_type\n
"},{"location":"odak/tools/#odak.tools.markdown.__init__","title":"__init__(filename)","text":"

Parameters:

  • filename \u2013
           Source filename (i.e. sample.md).\n
Source code in odak/tools/markdown.py
def __init__(\n             self,\n             filename\n            ):\n    \"\"\"\n    Parameters\n    ----------\n    filename     : str\n                   Source filename (i.e. sample.md).\n    \"\"\"\n    self.filename = filename\n    self.content = read_text_file(self.filename)\n    self.content_type = []\n    self.markdown_dictionary = [\n                                 '#',\n                               ]\n    self.markdown_begin_dictionary = [\n                                      '```bash',\n                                      '```python',\n                                      '```',\n                                     ]\n    self.markdown_end_dictionary = [\n                                    '```',\n                                   ]\n    self._label_lines()\n
"},{"location":"odak/tools/#odak.tools.markdown.get_line","title":"get_line(line_id=0)","text":"

Definition to get a specific line by providing a line number.

Returns:

  • line ( str ) \u2013

    Requested line.

  • content_type ( str ) \u2013

    Line's content type (e.g., markdown, comment, text).

Source code in odak/tools/markdown.py
def get_line(self, line_id = 0):\n    \"\"\"\n    Definition to get a specific line by inputting a line nunber.\n\n\n    Returns\n    ----------\n    line           : str\n                     Requested line.\n    content_type   : str\n                     Line's content type (e.g., markdown, comment, text).\n    \"\"\"\n    line = self.content[line_id]\n    content_type = self.content_type[line_id]\n    return line, content_type\n
"},{"location":"odak/tools/#odak.tools.markdown.get_line_count","title":"get_line_count()","text":"

Definition to get the line count.

Returns:

  • line_count ( int ) \u2013

    Number of lines in the loaded markdown document.

Source code in odak/tools/markdown.py
def get_line_count(self):\n    \"\"\"\n    Definition to get the line count.\n\n\n    Returns\n    -------\n    line_count     : int\n                     Number of lines in the loaded markdown document.\n    \"\"\"\n    self.line_count = len(self.content)\n    return self.line_count\n
"},{"location":"odak/tools/#odak.tools.markdown.set_dictonaries","title":"set_dictonaries(begin_dictionary, end_dictionary, syntax_dictionary)","text":"

Set document-specific dictionaries so that lines can be labelled accordingly.

Parameters:

  • begin_dictionary \u2013
                   Pythonic list containing markdown syntax for beginning of blocks (e.g., code, html).\n
  • end_dictionary \u2013
                   Pythonic list containing markdown syntax for end of blocks (e.g., code, html).\n
  • syntax_dictionary \u2013
                   Pythonic list containing markdown syntax (e.g., #).\n
Source code in odak/tools/markdown.py
def set_dictonaries(self, begin_dictionary, end_dictionary, syntax_dictionary):\n    \"\"\"\n    Set document-specific dictionaries so that lines can be labelled accordingly.\n\n\n    Parameters\n    ----------\n    begin_dictionary     : list\n                           Pythonic list containing markdown syntax for beginning of blocks (e.g., code, html).\n    end_dictionary       : list\n                           Pythonic list containing markdown syntax for end of blocks (e.g., code, html).\n    syntax_dictionary    : list\n                           Pythonic list containing markdown syntax (e.g., #).\n\n    \"\"\"\n    self.markdown_begin_dictionary = begin_dictionary\n    self.markdown_end_dictionary = end_dictionary\n    self.markdown_dictionary = syntax_dictionary\n    self._label_lines()\n
"},{"location":"odak/tools/#odak.tools.batch_of_rays","title":"batch_of_rays(entry, exit)","text":"

Definition to generate a batch of rays with given entry point(s) and exit point(s). The mapping is one to one: the nth item in your entry points list exits through the nth item in your exit points list, generating that particular ray. You can also combine nx3 points on one side with a single point on the other, but if both entry and exit contain multiple points, their counts must match.

Parameters:

  • entry \u2013
         Either a single point with size of 3 or multiple points with the size of nx3.\n
  • exit \u2013
         Either a single point with size of 3 or multiple points with the size of nx3.\n

Returns:

  • rays ( ndarray ) \u2013

    Generated batch of rays.

Source code in odak/tools/sample.py
def batch_of_rays(entry, exit):\n    \"\"\"\n    Definition to generate a batch of rays with given entry point(s) and exit point(s). Note that the mapping is one to one, meaning nth item in your entry points list will exit from nth item in your exit list and generate that particular ray. Note that you can have a combination like nx3 points for entry or exit and 1 point for entry or exit. But if you have multiple points both for entry and exit, the number of points have to be same both for entry and exit.\n\n    Parameters\n    ----------\n    entry      : ndarray\n                 Either a single point with size of 3 or multiple points with the size of nx3.\n    exit       : ndarray\n                 Either a single point with size of 3 or multiple points with the size of nx3.\n\n    Returns\n    ----------\n    rays       : ndarray\n                 Generated batch of rays.\n    \"\"\"\n    norays = np.array([0, 0])\n    if len(entry.shape) == 1:\n        entry = entry.reshape((1, 3))\n    if len(exit.shape) == 1:\n        exit = exit.reshape((1, 3))\n    norays = np.amax(np.asarray([entry.shape[0], exit.shape[0]]))\n    if norays > exit.shape[0]:\n        exit = np.repeat(exit, norays, axis=0)\n    elif norays > entry.shape[0]:\n        entry = np.repeat(entry, norays, axis=0)\n    rays = []\n    norays = int(norays)\n    for i in range(norays):\n        rays.append(\n            create_ray_from_two_points(\n                entry[i],\n                exit[i]\n            )\n        )\n    rays = np.asarray(rays)\n    return rays\n
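A brief usage sketch, assuming batch_of_rays is importable from odak.tools as the page location suggests; the entry and exit points below are made up for illustration:

import numpy as np
from odak.tools import batch_of_rays   # assumed import path

entry = np.array([0., 0., 0.])         # a single entry point
exits = np.array([[ 1., 0., 10.],
                  [ 0., 1., 10.],
                  [-1., 0., 10.]])     # three exit points
rays = batch_of_rays(entry, exits)     # the entry point is repeated to match the exits
print(rays.shape)                      # expected: (3, 2, 3)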
"},{"location":"odak/tools/#odak.tools.blur_gaussian","title":"blur_gaussian(field, kernel_length=[21, 21], nsigma=[3, 3])","text":"

A definition to blur a field using a Gaussian kernel.

Parameters:

  • field \u2013
            MxN field.\n
  • kernel_length (list, default: [21, 21] ) \u2013
            Length of the Gaussian kernel along X and Y axes.\n
  • nsigma \u2013
            Sigma of the Gaussian kernel along X and Y axes.\n

Returns:

  • blurred_field ( ndarray ) \u2013

    Blurred field.

Source code in odak/tools/matrix.py
def blur_gaussian(field, kernel_length=[21, 21], nsigma=[3, 3]):\n    \"\"\"\n    A definition to blur a field using a Gaussian kernel.\n\n    Parameters\n    ----------\n    field         : ndarray\n                    MxN field.\n    kernel_length : list\n                    Length of the Gaussian kernel along X and Y axes.\n    nsigma        : list\n                    Sigma of the Gaussian kernel along X and Y axes.\n\n    Returns\n    ----------\n    blurred_field : ndarray\n                    Blurred field.\n    \"\"\"\n    kernel = generate_2d_gaussian(kernel_length, nsigma)\n    kernel = zero_pad(kernel, field.shape)\n    blurred_field = convolve2d(field, kernel)\n    blurred_field = blurred_field/np.amax(blurred_field)\n    return blurred_field\n
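A minimal sketch, assuming blur_gaussian is importable from odak.tools; the test field is a single bright pixel used only for illustration:

import numpy as np
from odak.tools import blur_gaussian   # assumed import path

field = np.zeros((128, 128))
field[64, 64] = 1.                     # a single bright pixel
blurred = blur_gaussian(field, kernel_length = [21, 21], nsigma = [3, 3])
print(blurred.shape, blurred.max())    # same shape as the input, peak normalized to 1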
"},{"location":"odak/tools/#odak.tools.box_volume_sample","title":"box_volume_sample(no=[10, 10, 10], size=[100.0, 100.0, 100.0], center=[0.0, 0.0, 0.0], angles=[0.0, 0.0, 0.0])","text":"

Definition to generate samples in a box volume.

Parameters:

  • no \u2013
          Number of samples.\n
  • size \u2013
          Physical size of the volume.\n
  • center \u2013
          Center location of the volume.\n
  • angles \u2013
          Tilt of the volume.\n

Returns:

  • samples ( ndarray ) \u2013

    Samples generated.

Source code in odak/tools/sample.py
def box_volume_sample(no=[10, 10, 10], size=[100., 100., 100.], center=[0., 0., 0.], angles=[0., 0., 0.]):\n    \"\"\"\n    Definition to generate samples in a box volume.\n\n    Parameters\n    ----------\n    no          : list\n                  Number of samples.\n    size        : list\n                  Physical size of the volume.\n    center      : list\n                  Center location of the volume.\n    angles      : list\n                  Tilt of the volume.\n\n    Returns\n    ----------\n    samples     : ndarray\n                  Samples generated.\n    \"\"\"\n    samples = np.zeros((no[0], no[1], no[2], 3))\n    x, y, z = np.mgrid[0:no[0], 0:no[1], 0:no[2]]\n    step = [\n        size[0]/no[0],\n        size[1]/no[1],\n        size[2]/no[2]\n    ]\n    samples[:, :, :, 0] = x*step[0]+step[0]/2.-size[0]/2.\n    samples[:, :, :, 1] = y*step[1]+step[1]/2.-size[1]/2.\n    samples[:, :, :, 2] = z*step[2]+step[2]/2.-size[2]/2.\n    samples = samples.reshape(\n        (samples.shape[0]*samples.shape[1]*samples.shape[2], samples.shape[3]))\n    samples = rotate_points(samples, angles=angles, offset=center)\n    return samples\n
"},{"location":"odak/tools/#odak.tools.check_directory","title":"check_directory(directory)","text":"

Definition to check if a directory exists. If it doesn't exist, this definition will create one.

Parameters:

  • directory \u2013
            Full directory path.\n
Source code in odak/tools/file.py
def check_directory(directory):\n    \"\"\"\n    Definition to check if a directory exist. If it doesn't exist, this definition will create one.\n\n\n    Parameters\n    ----------\n    directory     : str\n                    Full directory path.\n    \"\"\"\n    if not os.path.exists(expanduser(directory)):\n        os.makedirs(expanduser(directory))\n        return False\n    return True\n
"},{"location":"odak/tools/#odak.tools.circular_sample","title":"circular_sample(no=[10, 10], radius=10.0, center=[0.0, 0.0, 0.0], angles=[0.0, 0.0, 0.0])","text":"

Definition to generate samples inside a circle over a surface.

Parameters:

  • no \u2013
          Number of samples.\n
  • radius \u2013
          Radius of the circle.\n
  • center \u2013
          Center location of the surface.\n
  • angles \u2013
          Tilt of the surface.\n

Returns:

  • samples ( ndarray ) \u2013

    Samples generated.

Source code in odak/tools/sample.py
def circular_sample(no=[10, 10], radius=10., center=[0., 0., 0.], angles=[0., 0., 0.]):\n    \"\"\"\n    Definition to generate samples inside a circle over a surface.\n\n    Parameters\n    ----------\n    no          : list\n                  Number of samples.\n    radius      : float\n                  Radius of the circle.\n    center      : list\n                  Center location of the surface.\n    angles      : list\n                  Tilt of the surface.\n\n    Returns\n    ----------\n    samples     : ndarray\n                  Samples generated.\n    \"\"\"\n    samples = np.zeros((no[0]+1, no[1]+1, 3))\n    r_angles, r = np.mgrid[0:no[0]+1, 0:no[1]+1]\n    r = r/np.amax(r)*radius\n    r_angles = r_angles/np.amax(r_angles)*np.pi*2\n    samples[:, :, 0] = r*np.cos(r_angles)\n    samples[:, :, 1] = r*np.sin(r_angles)\n    samples = samples[1:no[0]+1, 1:no[1]+1, :]\n    samples = samples.reshape(\n        (samples.shape[0]*samples.shape[1], samples.shape[2]))\n    samples = rotate_points(samples, angles=angles, offset=center)\n    return samples\n
"},{"location":"odak/tools/#odak.tools.circular_uniform_random_sample","title":"circular_uniform_random_sample(no=[10, 50], radius=10.0, center=[0.0, 0.0, 0.0], angles=[0.0, 0.0, 0.0])","text":"

Definition to generate samples inside a circle uniformly but randomly.

Parameters:

  • no \u2013
          Number of samples.\n
  • radius \u2013
          Radius of the circle.\n
  • center \u2013
          Center location of the surface.\n
  • angles \u2013
          Tilt of the surface.\n

Returns:

  • samples ( ndarray ) \u2013

    Samples generated.

Source code in odak/tools/sample.py
def circular_uniform_random_sample(no=[10, 50], radius=10., center=[0., 0., 0.], angles=[0., 0., 0.]):\n    \"\"\" \n    Definition to generate sample inside a circle uniformly but randomly.\n\n    Parameters\n    ----------\n    no          : list\n                  Number of samples.\n    radius      : float\n                  Radius of the circle.\n    center      : list\n                  Center location of the surface.\n    angles      : list\n                  Tilt of the surface.\n\n    Returns\n    ----------\n    samples     : ndarray\n                  Samples generated.\n    \"\"\"\n    samples = np.empty((0, 3))\n    rs = np.sqrt(np.random.uniform(0, 1, no[0]))\n    angs = np.random.uniform(0, 2*np.pi, no[1])\n    for i in rs:\n        for angle in angs:\n            r = radius*i\n            point = np.array(\n                [float(r*np.cos(angle)), float(r*np.sin(angle)), 0])\n            samples = np.vstack((samples, point))\n    samples = rotate_points(samples, angles=angles, offset=center)\n    return samples\n
"},{"location":"odak/tools/#odak.tools.circular_uniform_sample","title":"circular_uniform_sample(no=[10, 50], radius=10.0, center=[0.0, 0.0, 0.0], angles=[0.0, 0.0, 0.0])","text":"

Definition to generate samples inside a circle uniformly.

Parameters:

  • no \u2013
          Number of samples.\n
  • radius \u2013
          Radius of the circle.\n
  • center \u2013
          Center location of the surface.\n
  • angles \u2013
          Tilt of the surface.\n

Returns:

  • samples ( ndarray ) \u2013

    Samples generated.

Source code in odak/tools/sample.py
def circular_uniform_sample(no=[10, 50], radius=10., center=[0., 0., 0.], angles=[0., 0., 0.]):\n    \"\"\"\n    Definition to generate sample inside a circle uniformly.\n\n    Parameters\n    ----------\n    no          : list\n                  Number of samples.\n    radius      : float\n                  Radius of the circle.\n    center      : list\n                  Center location of the surface.\n    angles      : list\n                  Tilt of the surface.\n\n    Returns\n    ----------\n    samples     : ndarray\n                  Samples generated.\n    \"\"\"\n    samples = np.empty((0, 3))\n    for i in range(0, no[0]):\n        r = i/no[0]*radius\n        ang_no = no[1]*i/no[0]\n        for j in range(0, int(no[1]*i/no[0])):\n            angle = j/ang_no*2*np.pi\n            point = np.array(\n                [float(r*np.cos(angle)), float(r*np.sin(angle)), 0])\n            samples = np.vstack((samples, point))\n    samples = rotate_points(samples, angles=angles, offset=center)\n    return samples\n
"},{"location":"odak/tools/#odak.tools.closest_point_to_a_ray","title":"closest_point_to_a_ray(point, ray)","text":"

Definition to calculate the point on a ray that is closest to a given point.

Parameters:

  • point \u2013
            Given point in X,Y,Z.\n
  • ray \u2013
            Given ray.\n

Returns:

  • closest_point ( ndarray ) \u2013

    Calculated closest point.

Source code in odak/tools/vector.py
def closest_point_to_a_ray(point, ray):\n    \"\"\"\n    Definition to calculate the point on a ray that is closest to given point.\n\n    Parameters\n    ----------\n    point         : list\n                    Given point in X,Y,Z.\n    ray           : ndarray\n                    Given ray.\n\n    Returns\n    ---------\n    closest_point : ndarray\n                    Calculated closest point.\n    \"\"\"\n    from odak.raytracing import propagate_a_ray\n    if len(ray.shape) == 2:\n        ray = ray.reshape((1, 2, 3))\n    p0 = ray[:, 0]\n    p1 = propagate_a_ray(ray, 1.)\n    if len(p1.shape) == 2:\n        p1 = p1.reshape((1, 2, 3))\n    p1 = p1[:, 0]\n    p1 = p1.reshape(3)\n    p0 = p0.reshape(3)\n    point = point.reshape(3)\n    closest_distance = -np.dot((p0-point), (p1-p0))/np.sum((p1-p0)**2)\n    closest_point = propagate_a_ray(ray, closest_distance)[0]\n    return closest_point\n
"},{"location":"odak/tools/#odak.tools.convert_bytes","title":"convert_bytes(num)","text":"

A definition to convert bytes to a human-readable unit (KB, MB, GB, or similar). Inspired by https://stackoverflow.com/questions/2104080/how-can-i-check-file-size-in-python#2104083.

Parameters:

  • num \u2013
         Size in bytes\n

Returns:

  • num ( float ) \u2013

    Size in new unit.

  • x ( str ) \u2013

    New unit bytes, KB, MB, GB or TB.

Source code in odak/tools/file.py
def convert_bytes(num):\n    \"\"\"\n    A definition to convert bytes to semantic scheme (MB,GB or alike). Inspired from https://stackoverflow.com/questions/2104080/how-can-i-check-file-size-in-python#2104083.\n\n\n    Parameters\n    ----------\n    num        : float\n                 Size in bytes\n\n\n    Returns\n    ----------\n    num        : float\n                 Size in new unit.\n    x          : str\n                 New unit bytes, KB, MB, GB or TB.\n    \"\"\"\n    for x in ['bytes', 'KB', 'MB', 'GB', 'TB']:\n        if num < 1024.0:\n            return num, x\n        num /= 1024.0\n    return None, None\n
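A short sketch, assuming convert_bytes is importable from odak.tools as the page location suggests:

from odak.tools import convert_bytes   # assumed import path

size, unit = convert_bytes(3 * 1024 * 1024)   # three mebibytes expressed in bytes
print(size, unit)                             # expected: 3.0 MB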
"},{"location":"odak/tools/#odak.tools.convert_to_numpy","title":"convert_to_numpy(a)","text":"

A definition to convert Torch to Numpy.

Parameters:

  • a \u2013
         Input Torch array.\n

Returns:

  • b ( ndarray ) \u2013

    Converted array.

Source code in odak/tools/conversions.py
def convert_to_numpy(a):\n    \"\"\"\n    A definition to convert Torch to Numpy.\n\n    Parameters\n    ----------\n    a          : torch.Tensor\n                 Input Torch array.\n\n    Returns\n    ----------\n    b          : numpy.ndarray\n                 Converted array.\n    \"\"\"\n    b = a.to('cpu').detach().numpy()\n    return b\n
"},{"location":"odak/tools/#odak.tools.convert_to_torch","title":"convert_to_torch(a, grad=True)","text":"

A definition to convert Numpy arrays to Torch.

Parameters:

  • a \u2013
         Input Numpy array.\n
  • grad \u2013
         Set if the converted array requires gradient.\n

Returns:

  • c ( Tensor ) \u2013

    Converted array.

Source code in odak/tools/conversions.py
def convert_to_torch(a, grad=True):\n    \"\"\"\n    A definition to convert Numpy arrays to Torch.\n\n    Parameters\n    ----------\n    a          : ndarray\n                 Input Numpy array.\n    grad       : bool\n                 Set if the converted array requires gradient.\n\n    Returns\n    ----------\n    c          : torch.Tensor\n                 Converted array.\n    \"\"\"\n    b = np.copy(a)\n    c = torch.from_numpy(b)\n    c.requires_grad_(grad)\n    return c\n
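A minimal round-trip sketch together with convert_to_numpy, assuming both helpers live under odak.tools and torch is installed:

import numpy as np
from odak.tools import convert_to_torch, convert_to_numpy   # assumed import path

a = np.random.rand(4, 4)
t = convert_to_torch(a, grad = True)   # torch.Tensor tracking gradients
b = convert_to_numpy(t)                # detached and copied back to a numpy array
print(np.allclose(a, b))               # expected: True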
"},{"location":"odak/tools/#odak.tools.convolve2d","title":"convolve2d(field, kernel)","text":"

Definition to convolve a field with a kernel by multiplying in frequency space.

Parameters:

  • field \u2013
          Input field with MxN shape.\n
  • kernel \u2013
          Input kernel with MxN shape.\n

Returns:

  • new_field ( ndarray ) \u2013

    Convolved field.

Source code in odak/tools/matrix.py
def convolve2d(field, kernel):\n    \"\"\"\n    Definition to convolve a field with a kernel by multiplying in frequency space.\n\n    Parameters\n    ----------\n    field       : ndarray\n                  Input field with MxN shape.\n    kernel      : ndarray\n                  Input kernel with MxN shape.\n\n    Returns\n    ----------\n    new_field   : ndarray\n                  Convolved field.\n    \"\"\"\n    fr = np.fft.fft2(field)\n    fr2 = np.fft.fft2(np.flipud(np.fliplr(kernel)))\n    m, n = fr.shape\n    new_field = np.real(np.fft.ifft2(fr*fr2))\n    new_field = np.roll(new_field, int(-m/2+1), axis=0)\n    new_field = np.roll(new_field, int(-n/2+1), axis=1)\n    return new_field\n
"},{"location":"odak/tools/#odak.tools.copy_file","title":"copy_file(source, destination, follow_symlinks=True)","text":"

Definition to copy a file from one location to another.

Parameters:

  • source \u2013
              Source filename.\n
  • destination \u2013
              Destination filename.\n
  • follow_symlinks (bool, default: True ) \u2013
              Set to True to follow the source of symbolic links.\n
Source code in odak/tools/file.py
def copy_file(source, destination, follow_symlinks = True):\n    \"\"\"\n    Definition to copy a file from one location to another.\n\n\n\n    Parameters\n    ----------\n    source          : str\n                      Source filename.\n    destination     : str\n                      Destination filename.\n    follow_symlinks : bool\n                      Set to True to follow the source of symbolic links.\n    \"\"\"\n    return shutil.copyfile(\n                           expanduser(source),\n                           expanduser(destination),\n                           follow_symlinks = follow_symlinks\n                          )\n
"},{"location":"odak/tools/#odak.tools.create_empty_list","title":"create_empty_list(dimensions=[1, 1])","text":"

A definition to create an empty Pythonic list.

Parameters:

  • dimensions \u2013
           Dimensions of the list to be created.\n

Returns:

  • new_list ( list ) \u2013

    New empty list.

Source code in odak/tools/matrix.py
def create_empty_list(dimensions = [1, 1]):\n    \"\"\"\n    A definition to create an empty Pythonic list.\n\n    Parameters\n    ----------\n    dimensions   : list\n                   Dimensions of the list to be created.\n\n    Returns\n    -------\n    new_list     : list\n                   New empty list.\n    \"\"\"\n    new_list = 0\n    for n in reversed(dimensions):\n        new_list = [new_list] * n\n    return new_list\n
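A short sketch, assuming create_empty_list is importable from odak.tools; note that the entries are initialised to 0 and that nested rows share references, as is usual with list multiplication:

from odak.tools import create_empty_list   # assumed import path

table = create_empty_list(dimensions = [2, 3])
print(table)                               # expected: [[0, 0, 0], [0, 0, 0]]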
"},{"location":"odak/tools/#odak.tools.create_ray_from_two_points","title":"create_ray_from_two_points(x0y0z0, x1y1z1)","text":"

Definition to create a ray from two given points. Note that both inputs must match in shape.

Parameters:

  • x0y0z0 \u2013
           List that contains X,Y and Z start locations of a ray (3). It can also be a list of points as well (mx3). This is the starting point.\n
  • x1y1z1 \u2013
           List that contains X,Y and Z ending locations of a ray (3). It can also be a list of points as well (mx3). This is the end point.\n

Returns:

  • ray ( ndarray ) \u2013

    Array that contains starting points and cosines of a created ray.

Source code in odak/raytracing/ray.py
def create_ray_from_two_points(x0y0z0, x1y1z1):\n    \"\"\"\n    Definition to create a ray from two given points. Note that both inputs must match in shape.\n\n    Parameters\n    ----------\n    x0y0z0       : list\n                   List that contains X,Y and Z start locations of a ray (3). It can also be a list of points as well (mx3). This is the starting point.\n    x1y1z1       : list\n                   List that contains X,Y and Z ending locations of a ray (3). It can also be a list of points as well (mx3). This is the end point.\n\n    Returns\n    ----------\n    ray          : ndarray\n                   Array that contains starting points and cosines of a created ray.\n    \"\"\"\n    x0y0z0 = np.asarray(x0y0z0, dtype=np.float64)\n    x1y1z1 = np.asarray(x1y1z1, dtype=np.float64)\n    if len(x0y0z0.shape) == 1:\n        x0y0z0 = x0y0z0.reshape((1, 3))\n    if len(x1y1z1.shape) == 1:\n        x1y1z1 = x1y1z1.reshape((1, 3))\n    xdiff = x1y1z1[:, 0] - x0y0z0[:, 0]\n    ydiff = x1y1z1[:, 1] - x0y0z0[:, 1]\n    zdiff = x1y1z1[:, 2] - x0y0z0[:, 2]\n    s = np.sqrt(xdiff ** 2 + ydiff ** 2 + zdiff ** 2)\n    s[s == 0] = np.nan\n    cosines = np.zeros((xdiff.shape[0], 3))\n    cosines[:, 0] = xdiff/s\n    cosines[:, 1] = ydiff/s\n    cosines[:, 2] = zdiff/s\n    ray = np.zeros((xdiff.shape[0], 2, 3), dtype=np.float64)\n    ray[:, 0] = x0y0z0\n    ray[:, 1] = cosines\n    if ray.shape[0] == 1:\n        ray = ray.reshape((2, 3))\n    return ray\n
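A brief sketch; the function's source lives in odak/raytracing/ray.py, so the import below assumes it is exposed from odak.raytracing:

import numpy as np
from odak.raytracing import create_ray_from_two_points   # assumed import path

start = np.array([0., 0., 0.])
end = np.array([0., 0., 10.])
ray = create_ray_from_two_points(start, end)
print(ray[0])   # starting point: [0. 0. 0.]
print(ray[1])   # direction cosines: [0. 0. 1.]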
"},{"location":"odak/tools/#odak.tools.crop_center","title":"crop_center(field, size=None)","text":"

Definition to crop the center of a field with 2Mx2N size. The outcome is a MxN array.

Parameters:

  • field \u2013
          Input field 2Mx2N array.\n

Returns:

  • cropped ( ndarray ) \u2013

    Cropped version of the input field.

Source code in odak/tools/matrix.py
def crop_center(field, size=None):\n    \"\"\"\n    Definition to crop the center of a field with 2Mx2N size. The outcome is a MxN array.\n\n    Parameters\n    ----------\n    field       : ndarray\n                  Input field 2Mx2N array.\n\n    Returns\n    ----------\n    cropped     : ndarray\n                  Cropped version of the input field.\n    \"\"\"\n    if type(size) == type(None):\n        qx = int(np.ceil(field.shape[0])/4)\n        qy = int(np.ceil(field.shape[1])/4)\n        cropped = np.copy(field[qx:3*qx, qy:3*qy])\n    else:\n        cx = int(np.ceil(field.shape[0]/2))\n        cy = int(np.ceil(field.shape[1]/2))\n        hx = int(np.ceil(size[0]/2))\n        hy = int(np.ceil(size[1]/2))\n        cropped = np.copy(field[cx-hx:cx+hx, cy-hy:cy+hy])\n    return cropped\n
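A short sketch, assuming crop_center is importable from odak.tools; with no size argument it keeps the central quarter of the field:

import numpy as np
from odak.tools import crop_center   # assumed import path

field = np.arange(8 * 8).reshape(8, 8)
cropped = crop_center(field)         # keeps the central 4x4 region
print(cropped.shape)                 # expected: (4, 4)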
"},{"location":"odak/tools/#odak.tools.cross_product","title":"cross_product(vector1, vector2)","text":"

Definition to take the cross product of two vectors and return the resultant vector. Uses the method described at: http://en.wikipedia.org/wiki/Cross_product

Parameters:

  • vector1 \u2013
           A vector/ray.\n
  • vector2 \u2013
           A vector/ray.\n

Returns:

  • ray ( ndarray ) \u2013

    Array that contains starting points and cosines of a created ray.

Source code in odak/tools/vector.py
def cross_product(vector1, vector2):\n    \"\"\"\n    Definition to cross product two vectors and return the resultant vector. Used method described under: http://en.wikipedia.org/wiki/Cross_product\n\n    Parameters\n    ----------\n    vector1      : ndarray\n                   A vector/ray.\n    vector2      : ndarray\n                   A vector/ray.\n\n    Returns\n    ----------\n    ray          : ndarray\n                   Array that contains starting points and cosines of a created ray.\n    \"\"\"\n    angle = np.cross(vector1[1].T, vector2[1].T)\n    angle = np.asarray(angle)\n    ray = np.array([vector1[0], angle], dtype=np.float32)\n    return ray\n
"},{"location":"odak/tools/#odak.tools.distance_between_point_clouds","title":"distance_between_point_clouds(points0, points1)","text":"

A definition to find the distance from every point in one point cloud to every point in the other point cloud.

Parameters:

  • points0 \u2013
          Mx3 points.\n
  • points1 \u2013
          Nx3 points.\n

Returns:

  • distances ( ndarray ) \u2013

    MxN distances.

Source code in odak/tools/vector.py
def distance_between_point_clouds(points0, points1):\n    \"\"\"\n    A definition to find distance between every point in one cloud to other points in the other point cloud.\n    Parameters\n    ----------\n    points0     : ndarray\n                  Mx3 points.\n    points1     : ndarray\n                  Nx3 points.\n\n    Returns\n    ----------\n    distances   : ndarray\n                  MxN distances.\n    \"\"\"\n    c = points1.reshape((1, points1.shape[0], points1.shape[1]))\n    a = np.repeat(c, points0.shape[0], axis=0)\n    b = points0.reshape((points0.shape[0], 1, points0.shape[1]))\n    b = np.repeat(b, a.shape[1], axis=1)\n    distances = np.sqrt(np.sum((a-b)**2, axis=2))\n    return distances\n
"},{"location":"odak/tools/#odak.tools.distance_between_two_points","title":"distance_between_two_points(point1, point2)","text":"

Definition to calculate distance between two given points.

Parameters:

  • point1 \u2013
          First point in X,Y,Z.\n
  • point2 \u2013
          Second point in X,Y,Z.\n

Returns:

  • distance ( float ) \u2013

Distance between the two given points.

Source code in odak/tools/vector.py
def distance_between_two_points(point1, point2):\n    \"\"\"\n    Definition to calculate distance between two given points.\n\n    Parameters\n    ----------\n    point1      : list\n                  First point in X,Y,Z.\n    point2      : list\n                  Second point in X,Y,Z.\n\n    Returns\n    ----------\n    distance    : float\n                  Distance in between given two points.\n    \"\"\"\n    point1 = np.asarray(point1)\n    point2 = np.asarray(point2)\n    if len(point1.shape) == 1 and len(point2.shape) == 1:\n        distance = np.sqrt(np.sum((point1-point2)**2))\n    elif len(point1.shape) == 2 or len(point2.shape) == 2:\n        distance = np.sqrt(np.sum((point1-point2)**2, axis=1))\n    return distance\n
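A minimal sketch, assuming distance_between_two_points is importable from odak.tools:

from odak.tools import distance_between_two_points   # assumed import path

d = distance_between_two_points([0., 0., 0.], [3., 4., 0.])
print(d)                                             # expected: 5.0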
"},{"location":"odak/tools/#odak.tools.expanduser","title":"expanduser(filename)","text":"

Definition to decode filename using namespaces and shortcuts.

Parameters:

  • filename \u2013
            Filename.\n

Returns:

  • new_filename ( str ) \u2013

    Filename.

Source code in odak/tools/file.py
def expanduser(filename):\n    \"\"\"\n    Definition to decode filename using namespaces and shortcuts.\n\n\n    Parameters\n    ----------\n    filename      : str\n                    Filename.\n\n\n    Returns\n    -------\n    new_filename  : str\n                    Filename.\n    \"\"\"\n    new_filename = os.path.expanduser(filename)\n    return new_filename\n
"},{"location":"odak/tools/#odak.tools.generate_2d_gaussian","title":"generate_2d_gaussian(kernel_length=[21, 21], nsigma=[3, 3])","text":"

Generate 2D Gaussian kernel. Inspired from https://stackoverflow.com/questions/29731726/how-to-calculate-a-gaussian-kernel-matrix-efficiently-in-numpy

Parameters:

  • kernel_length (list, default: [21, 21] ) \u2013
            Length of the Gaussian kernel along X and Y axes.\n
  • nsigma \u2013
            Sigma of the Gaussian kernel along X and Y axes.\n

Returns:

  • kernel_2d ( ndarray ) \u2013

    Generated Gaussian kernel.

Source code in odak/tools/matrix.py
def generate_2d_gaussian(kernel_length=[21, 21], nsigma=[3, 3]):\n    \"\"\"\n    Generate 2D Gaussian kernel. Inspired from https://stackoverflow.com/questions/29731726/how-to-calculate-a-gaussian-kernel-matrix-efficiently-in-numpy\n\n    Parameters\n    ----------\n    kernel_length : list\n                    Length of the Gaussian kernel along X and Y axes.\n    nsigma        : list\n                    Sigma of the Gaussian kernel along X and Y axes.\n\n    Returns\n    ----------\n    kernel_2d     : ndarray\n                    Generated Gaussian kernel.\n    \"\"\"\n    x = np.linspace(-nsigma[0], nsigma[0], kernel_length[0]+1)\n    y = np.linspace(-nsigma[1], nsigma[1], kernel_length[1]+1)\n    xx, yy = np.meshgrid(x, y)\n    kernel_2d = np.exp(-0.5*(np.square(xx) /\n                       np.square(nsigma[0]) + np.square(yy)/np.square(nsigma[1])))\n    kernel_2d = kernel_2d/kernel_2d.sum()\n    return kernel_2d\n
"},{"location":"odak/tools/#odak.tools.generate_bandlimits","title":"generate_bandlimits(size=[512, 512], levels=9)","text":"

A definition to calculate octaves used in bandlimiting frequencies in the frequency domain.

Parameters:

  • size \u2013
         Size of each mask in octaves.\n

Returns:

  • masks ( ndarray ) \u2013

    Masks (Octaves).

Source code in odak/tools/matrix.py
def generate_bandlimits(size=[512, 512], levels=9):\n    \"\"\"\n    A definition to calculate octaves used in bandlimiting frequencies in the frequency domain.\n\n    Parameters\n    ----------\n    size       : list\n                 Size of each mask in octaves.\n\n    Returns\n    ----------\n    masks      : ndarray\n                 Masks (Octaves).\n    \"\"\"\n    masks = np.zeros((levels, size[0], size[1]))\n    cx = int(size[0]/2)\n    cy = int(size[1]/2)\n    for i in range(0, masks.shape[0]):\n        deltax = int((size[0])/(2**(i+1)))\n        deltay = int((size[1])/(2**(i+1)))\n        masks[\n            i,\n            cx-deltax:cx+deltax,\n            cy-deltay:cy+deltay\n        ] = 1.\n        masks[\n            i,\n            int(cx-deltax/2.):int(cx+deltax/2.),\n            int(cy-deltay/2.):int(cy+deltay/2.)\n        ] = 0.\n    masks = np.asarray(masks)\n    return masks\n
"},{"location":"odak/tools/#odak.tools.grid_sample","title":"grid_sample(no=[10, 10], size=[100.0, 100.0], center=[0.0, 0.0, 0.0], angles=[0.0, 0.0, 0.0])","text":"

Definition to generate samples over a surface.

Parameters:

  • no \u2013
          Number of samples.\n
  • size \u2013
          Physical size of the surface.\n
  • center \u2013
          Center location of the surface.\n
  • angles \u2013
          Tilt of the surface.\n

Returns:

  • samples ( ndarray ) \u2013

    Samples generated.

Source code in odak/tools/sample.py
def grid_sample(no=[10, 10], size=[100., 100.], center=[0., 0., 0.], angles=[0., 0., 0.]):\n    \"\"\"\n    Definition to generate samples over a surface.\n\n    Parameters\n    ----------\n    no          : list\n                  Number of samples.\n    size        : list\n                  Physical size of the surface.\n    center      : list\n                  Center location of the surface.\n    angles      : list\n                  Tilt of the surface.\n\n    Returns\n    ----------\n    samples     : ndarray\n                  Samples generated.\n    \"\"\"\n    samples = np.zeros((no[0], no[1], 3))\n    step = [\n        size[0]/(no[0]-1),\n        size[1]/(no[1]-1)\n    ]\n    x, y = np.mgrid[0:no[0], 0:no[1]]\n    samples[:, :, 0] = x*step[0]-size[0]/2.\n    samples[:, :, 1] = y*step[1]-size[1]/2.\n    samples = samples.reshape(\n        (samples.shape[0]*samples.shape[1], samples.shape[2]))\n    samples = rotate_points(samples, angles=angles, offset=center)\n    return samples\n
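A brief sketch, assuming grid_sample is importable from odak.tools; the surface below is a 10 x 10 plane placed 20 units along the optical axis, chosen only for illustration:

from odak.tools import grid_sample   # assumed import path

samples = grid_sample(
                      no = [5, 5],
                      size = [10., 10.],
                      center = [0., 0., 20.],
                      angles = [0., 0., 0.]
                     )
print(samples.shape)                 # expected: (25, 3)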
"},{"location":"odak/tools/#odak.tools.list_files","title":"list_files(path, key='*.*', recursive=True)","text":"

Definition to list files in a given path with a given key.

Parameters:

  • path \u2013
          Path to a folder.\n
  • key \u2013
          Key used for scanning a path.\n
  • recursive \u2013
          If set True, scan the path recursively.\n

Returns:

  • files_list ( ndarray ) \u2013

    list of files found in a given path.

Source code in odak/tools/file.py
def list_files(path, key = '*.*', recursive = True):\n    \"\"\"\n    Definition to list files in a given path with a given key.\n\n\n    Parameters\n    ----------\n    path        : str\n                  Path to a folder.\n    key         : str\n                  Key used for scanning a path.\n    recursive   : bool\n                  If set True, scan the path recursively.\n\n\n    Returns\n    ----------\n    files_list  : ndarray\n                  list of files found in a given path.\n    \"\"\"\n    if recursive == True:\n        search_result = pathlib.Path(expanduser(path)).rglob(key)\n    elif recursive == False:\n        search_result = pathlib.Path(expanduser(path)).glob(key)\n    files_list = []\n    for item in search_result:\n        files_list.append(str(item))\n    files_list = sorted(files_list)\n    return files_list\n
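A short sketch, assuming list_files is importable from odak.tools; the folder and key below are hypothetical:

from odak.tools import list_files    # assumed import path

# hypothetical folder; replace with a path that exists on your system
png_files = list_files('~/Pictures', key = '*.png', recursive = True)
for filename in png_files:
    print(filename)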
"},{"location":"odak/tools/#odak.tools.load_dictionary","title":"load_dictionary(filename)","text":"

Definition to load a dictionary (JSON) file.

Parameters:

  • filename \u2013
            Filename.\n

Returns:

  • settings ( dict ) \u2013

    Dictionary read from the file.

Source code in odak/tools/file.py
def load_dictionary(filename):\n    \"\"\"\n    Definition to load a dictionary (JSON) file.\n\n\n    Parameters\n    ----------\n    filename      : str\n                    Filename.\n\n\n    Returns\n    ----------\n    settings      : dict\n                    Dictionary read from the file.\n\n    \"\"\"\n    settings = json.load(open(expanduser(filename)))\n    return settings\n
"},{"location":"odak/tools/#odak.tools.load_image","title":"load_image(fn, normalizeby=0.0, torch_style=False)","text":"

Definition to load an image from a given location as a Numpy array.

Parameters:

  • fn \u2013
           Filename.\n
  • normalizeby \u2013
           Value to normalize images with. The default value of zero leads to no normalization.\n
  • torch_style \u2013
           If set True, it will load an image mxnx3 as 3xmxn.\n

Returns:

  • image ( ndarray ) \u2013

    Image loaded as a Numpy array.

Source code in odak/tools/file.py
def load_image(fn, normalizeby = 0., torch_style = False):\n    \"\"\" \n    Definition to load an image from a given location as a Numpy array.\n\n\n    Parameters\n    ----------\n    fn           : str\n                   Filename.\n    normalizeby  : float\n                   Value to to normalize images with. Default value of zero will lead to no normalization.\n    torch_style  : bool\n                   If set True, it will load an image mxnx3 as 3xmxn.\n\n\n    Returns\n    ----------\n    image        :  ndarray\n                    Image loaded as a Numpy array.\n\n    \"\"\"\n    image = cv2.imread(expanduser(fn), cv2.IMREAD_UNCHANGED)\n    if isinstance(image, type(None)):\n         logging.warning('Image not properly loaded. Check filename or image type.')    \n         sys.exit()\n    if len(image.shape) > 2:\n        new_image = np.copy(image)\n        new_image[:, :, 0] = image[:, :, 2]\n        new_image[:, :, 2] = image[:, :, 0]\n        image = new_image\n    if normalizeby != 0.:\n        image = image * 1. / normalizeby\n    if torch_style == True and len(image.shape) > 2:\n        image = np.moveaxis(image, -1, 0)\n    return image.astype(float)\n
"},{"location":"odak/tools/#odak.tools.nufft2","title":"nufft2(field, fx, fy, size=None, sign=1, eps=10 ** -12)","text":"

A definition to take 2D Non-Uniform Fast Fourier Transform (NUFFT).

Parameters:

  • field \u2013
          Input field.\n
  • fx \u2013
          Frequencies along x axis.\n
  • fy \u2013
          Frequencies along y axis.\n
  • size \u2013
          Size.\n
  • sign \u2013
          Sign of the exponential used in NUFFT kernel.\n
  • eps \u2013
          Accuracy of NUFFT.\n

Returns:

  • result ( ndarray ) \u2013

    Inverse NUFFT of the input field.

Source code in odak/tools/matrix.py
def nufft2(field, fx, fy, size=None, sign=1, eps=10**(-12)):\n    \"\"\"\n    A definition to take 2D Non-Uniform Fast Fourier Transform (NUFFT).\n\n    Parameters\n    ----------\n    field       : ndarray\n                  Input field.\n    fx          : ndarray\n                  Frequencies along x axis.\n    fy          : ndarray\n                  Frequencies along y axis.\n    size        : list\n                  Size.\n    sign        : float\n                  Sign of the exponential used in NUFFT kernel.\n    eps         : float\n                  Accuracy of NUFFT.\n\n    Returns\n    ----------\n    result      : ndarray\n                  Inverse NUFFT of the input field.\n    \"\"\"\n    try:\n        import finufft\n    except:\n        print('odak.tools.nufft2 requires finufft to be installed: pip install finufft')\n    image = np.copy(field).astype(np.complex128)\n    result = finufft.nufft2d2(\n        fx.flatten(), fy.flatten(), image, eps=eps, isign=sign)\n    if type(size) == type(None):\n        result = result.reshape(field.shape)\n    else:\n        result = result.reshape(size)\n    return result\n
"},{"location":"odak/tools/#odak.tools.nuifft2","title":"nuifft2(field, fx, fy, size=None, sign=1, eps=10 ** -12)","text":"

A definition to take 2D Adjoint Non-Uniform Fast Fourier Transform (NUFFT).

Parameters:

  • field \u2013
          Input field.\n
  • fx \u2013
          Frequencies along x axis.\n
  • fy \u2013
          Frequencies along y axis.\n
  • size \u2013
          Shape of the NUFFT calculated for an input field.\n
  • sign \u2013
          Sign of the exponential used in NUFFT kernel.\n
  • eps \u2013
          Accuracy of NUFFT.\n

Returns:

  • result ( ndarray ) \u2013

    NUFFT of the input field.

Source code in odak/tools/matrix.py
def nuifft2(field, fx, fy, size=None, sign=1, eps=10**(-12)):\n    \"\"\"\n    A definition to take 2D Adjoint Non-Uniform Fast Fourier Transform (NUFFT).\n\n    Parameters\n    ----------\n    field       : ndarray\n                  Input field.\n    fx          : ndarray\n                  Frequencies along x axis.\n    fy          : ndarray\n                  Frequencies along y axis.\n    size        : list or ndarray\n                  Shape of the NUFFT calculated for an input field.\n    sign        : float\n                  Sign of the exponential used in NUFFT kernel.\n    eps         : float\n                  Accuracy of NUFFT.\n\n    Returns\n    ----------\n    result      : ndarray\n                  NUFFT of the input field.\n    \"\"\"\n    try:\n        import finufft\n    except:\n        print('odak.tools.nuifft2 requires finufft to be installed: pip install finufft')\n    image = np.copy(field).astype(np.complex128)\n    if type(size) == type(None):\n        result = finufft.nufft2d1(\n            fx.flatten(),\n            fy.flatten(),\n            image.flatten(),\n            image.shape,\n            eps=eps,\n            isign=sign\n        )\n    else:\n        result = finufft.nufft2d1(\n            fx.flatten(),\n            fy.flatten(),\n            image.flatten(),\n            (size[0], size[1]),\n            eps=eps,\n            isign=sign\n        )\n    result = np.asarray(result)\n    return result\n
"},{"location":"odak/tools/#odak.tools.point_to_ray_distance","title":"point_to_ray_distance(point, ray_point_0, ray_point_1)","text":"

Definition to find a point's closest distance to a line represented by two points.

Parameters:

  • point \u2013
          Point to be tested.\n
  • ray_point_0 (ndarray) \u2013
          First point to represent a line.\n
  • ray_point_1 (ndarray) \u2013
          Second point to represent a line.\n

Returns:

  • distance ( float ) \u2013

    Calculated distance.

Source code in odak/tools/vector.py
def point_to_ray_distance(point, ray_point_0, ray_point_1):\n    \"\"\"\n    Definition to find point's closest distance to a line represented with two points.\n\n    Parameters\n    ----------\n    point       : ndarray\n                  Point to be tested.\n    ray_point_0 : ndarray\n                  First point to represent a line.\n    ray_point_1 : ndarray\n                  Second point to represent a line.\n\n    Returns\n    ----------\n    distance    : float\n                  Calculated distance.\n    \"\"\"\n    distance = np.sum(np.cross((point-ray_point_0), (point-ray_point_1))\n                      ** 2)/np.sum((ray_point_1-ray_point_0)**2)\n    return distance\n
"},{"location":"odak/tools/#odak.tools.quantize","title":"quantize(image_field, bits=4)","text":"

Definition to quantize an image field (0-255, 8 bit) to a certain bit depth.

Parameters:

  • image_field (ndarray) \u2013
          Input image field.\n
  • bits \u2013
          A value between 1 and 8 (cannot be zero).\n

Returns:

  • new_field ( ndarray ) \u2013

    Quantized image field.

Source code in odak/tools/matrix.py
def quantize(image_field, bits=4):\n    \"\"\"\n    Definitio to quantize a image field (0-255, 8 bit) to a certain bits level.\n\n    Parameters\n    ----------\n    image_field : ndarray\n                  Input image field.\n    bits        : int\n                  A value in between 0 to 8. Can not be zero.\n\n    Returns\n    ----------\n    new_field   : ndarray\n                  Quantized image field.\n    \"\"\"\n    divider = 2**(8-bits)\n    new_field = image_field/divider\n    new_field = new_field.astype(np.int64)\n    return new_field\n
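A minimal sketch, assuming quantize is importable from odak.tools; the random 8-bit-style image is made up for illustration:

import numpy as np
from odak.tools import quantize      # assumed import path

image = np.random.randint(0, 256, (4, 4))   # 8-bit style values
quantized = quantize(image, bits = 4)       # values now span 0..15
print(quantized.min(), quantized.max())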
"},{"location":"odak/tools/#odak.tools.random_sample_point_cloud","title":"random_sample_point_cloud(point_cloud, no, p=None)","text":"

Definition to pull a subset of points from a point cloud with a given probability.

Parameters:

  • point_cloud \u2013
           Point cloud array.\n
  • no \u2013
           Number of samples.\n
  • p \u2013
           Probability list in the same size as no.\n

Returns:

  • subset ( ndarray ) \u2013

    Subset of the given point cloud.

Source code in odak/tools/sample.py
def random_sample_point_cloud(point_cloud, no, p=None):\n    \"\"\"\n    Definition to pull a subset of points from a point cloud with a given probability.\n\n    Parameters\n    ----------\n    point_cloud  : ndarray\n                   Point cloud array.\n    no           : list\n                   Number of samples.\n    p            : list\n                   Probability list in the same size as no.\n\n    Returns\n    ----------\n    subset       : ndarray\n                   Subset of the given point cloud.\n    \"\"\"\n    choice = np.random.choice(point_cloud.shape[0], no, p)\n    subset = point_cloud[choice, :]\n    return subset\n
"},{"location":"odak/tools/#odak.tools.read_PLY","title":"read_PLY(fn, offset=[0, 0, 0], angles=[0.0, 0.0, 0.0], mode='XYZ')","text":"

Definition to read a PLY file and extract meshes from a given PLY file. Note that rotation is always with respect to 0,0,0.

Parameters:

  • fn \u2013
           Filename of a PLY file.\n
  • offset \u2013
           Offset in X,Y,Z.\n
  • angles \u2013
           Rotation angles in degrees.\n
  • mode \u2013
           Rotation mode determines ordering of the rotations at each axis. There are XYZ,YXZ,ZXY and ZYX modes.\n

Returns:

  • triangles ( ndarray ) \u2013

    Triangles from a given PLY file. Note that the triangles returned by this function aren't always structured in the right order and come with the size of (MxN)x3. You can use numpy's reshape to restructure them to MxNx3 if you know what you are doing.

Source code in odak/tools/asset.py
def read_PLY(fn, offset=[0, 0, 0], angles=[0., 0., 0.], mode='XYZ'):\n    \"\"\"\n    Definition to read a PLY file and extract meshes from a given PLY file. Note that rotation is always with respect to 0,0,0.\n\n    Parameters\n    ----------\n    fn           : string\n                   Filename of a PLY file.\n    offset       : ndarray\n                   Offset in X,Y,Z.\n    angles       : list\n                   Rotation angles in degrees.\n    mode         : str\n                   Rotation mode determines ordering of the rotations at each axis. There are XYZ,YXZ,ZXY and ZYX modes. \n\n    Returns\n    ----------\n    triangles    : ndarray\n                  Triangles from a given PLY file. Note that the triangles coming out of this function isn't always structured in the right order and with the size of (MxN)x3. You can use numpy's reshape to restructure it to mxnx3 if you know what you are doing.\n    \"\"\"\n    if np.__name__ != 'numpy':\n        import numpy as np_ply\n    else:\n        np_ply = np\n    with open(fn, 'rb') as f:\n        plydata = PlyData.read(f)\n    triangle_ids = np_ply.vstack(plydata['face'].data['vertex_indices'])\n    triangles = []\n    for vertex_ids in triangle_ids:\n        triangle = [\n            rotate_point(plydata['vertex'][int(vertex_ids[0])\n                                           ].tolist(), angles=angles, offset=offset)[0],\n            rotate_point(plydata['vertex'][int(vertex_ids[1])\n                                           ].tolist(), angles=angles, offset=offset)[0],\n            rotate_point(plydata['vertex'][int(vertex_ids[2])\n                                           ].tolist(), angles=angles, offset=offset)[0]\n        ]\n        triangle = np_ply.asarray(triangle)\n        triangles.append(triangle)\n    triangles = np_ply.array(triangles)\n    triangles = np.asarray(triangles, dtype=np.float32)\n    return triangles\n
"},{"location":"odak/tools/#odak.tools.read_PLY_point_cloud","title":"read_PLY_point_cloud(filename)","text":"

Definition to read a PLY file as a point cloud.

Parameters:

  • filename \u2013
           Filename of a PLY file.\n

Returns:

  • point_cloud ( ndarray ) \u2013

    An array filled with points from the PLY file.

Source code in odak/tools/asset.py
def read_PLY_point_cloud(filename):\n    \"\"\"\n    Definition to read a PLY file as a point cloud.\n\n    Parameters\n    ----------\n    filename     : str\n                   Filename of a PLY file.\n\n    Returns\n    ----------\n    point_cloud  : ndarray\n                   An array filled with poitns from the PLY file.\n    \"\"\"\n    plydata = PlyData.read(filename)\n    if np.__name__ != 'numpy':\n        import numpy as np_ply\n        point_cloud = np_ply.zeros((plydata['vertex'][:].shape[0], 3))\n        point_cloud[:, 0] = np_ply.asarray(plydata['vertex']['x'][:])\n        point_cloud[:, 1] = np_ply.asarray(plydata['vertex']['y'][:])\n        point_cloud[:, 2] = np_ply.asarray(plydata['vertex']['z'][:])\n        point_cloud = np.asarray(point_cloud)\n    else:\n        point_cloud = np.zeros((plydata['vertex'][:].shape[0], 3))\n        point_cloud[:, 0] = np.asarray(plydata['vertex']['x'][:])\n        point_cloud[:, 1] = np.asarray(plydata['vertex']['y'][:])\n        point_cloud[:, 2] = np.asarray(plydata['vertex']['z'][:])\n    return point_cloud\n
"},{"location":"odak/tools/#odak.tools.read_text_file","title":"read_text_file(filename)","text":"

Definition to read a given text file and convert it into a Pythonic list.

Parameters:

  • filename \u2013
              Source filename (i.e. test.txt).\n

Returns:

  • content ( list ) \u2013

    Pythonic string list containing the text from the file provided.

Source code in odak/tools/file.py
def read_text_file(filename):\n    \"\"\"\n    Definition to read a given text file and convert it into a Pythonic list.\n\n\n    Parameters\n    ----------\n    filename        : str\n                      Source filename (i.e. test.txt).\n\n\n    Returns\n    -------\n    content         : list\n                      Pythonic string list containing the text from the file provided.\n    \"\"\"\n    content = []\n    loaded_file = open(expanduser(filename))\n    while line := loaded_file.readline():\n        content.append(line.rstrip())\n    return content\n
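A short sketch, assuming read_text_file is importable from odak.tools; the filename is hypothetical:

from odak.tools import read_text_file   # assumed import path

lines = read_text_file('test.txt')      # hypothetical text file
print(len(lines))                       # number of lines read
print(lines[0])                         # first line, stripped of trailing whitespace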
"},{"location":"odak/tools/#odak.tools.resize_image","title":"resize_image(img, target_size)","text":"

Definition to resize a given image to a target shape.

Parameters:

  • img \u2013
            MxN image to be resized.\n        Image must be normalized (0-1).\n
  • target_size \u2013
            Target shape.\n

Returns:

  • img ( ndarray ) \u2013

    Resized image.

Source code in odak/tools/file.py
def resize_image(img, target_size):\n    \"\"\"\n    Definition to resize a given image to a target shape.\n\n\n    Parameters\n    ----------\n    img           : ndarray\n                    MxN image to be resized.\n                    Image must be normalized (0-1).\n    target_size   : list\n                    Target shape.\n\n\n    Returns\n    ----------\n    img           : ndarray\n                    Resized image.\n\n    \"\"\"\n    img = cv2.resize(img, dsize=(target_size[0], target_size[1]), interpolation=cv2.INTER_AREA)\n    return img\n
"},{"location":"odak/tools/#odak.tools.rotate_point","title":"rotate_point(point, angles=[0, 0, 0], mode='XYZ', origin=[0, 0, 0], offset=[0, 0, 0])","text":"

Definition to rotate a given point. Note that rotation is always with respect to 0,0,0.

Parameters:

  • point \u2013
           A point.\n
  • angles \u2013
           Rotation angles in degrees.\n
  • mode \u2013
           Rotation mode determines ordering of the rotations at each axis. There are XYZ,YXZ,ZXY and ZYX modes.\n
  • origin \u2013
           Reference point for a rotation.\n
  • offset \u2013
           Shift with the given offset.\n

Returns:

  • result ( ndarray ) \u2013

    Result of the rotation

  • rotx ( ndarray ) \u2013

    Rotation matrix along X axis.

  • roty ( ndarray ) \u2013

    Rotation matrix along Y axis.

  • rotz ( ndarray ) \u2013

    Rotation matrix along Z axis.

Source code in odak/tools/transformation.py
def rotate_point(point, angles = [0, 0, 0], mode = 'XYZ', origin = [0, 0, 0], offset = [0, 0, 0]):\n    \"\"\"\n    Definition to rotate a given point. Note that rotation is always with respect to 0,0,0.\n\n    Parameters\n    ----------\n    point        : ndarray\n                   A point.\n    angles       : list\n                   Rotation angles in degrees. \n    mode         : str\n                   Rotation mode determines ordering of the rotations at each axis. There are XYZ,YXZ,ZXY and ZYX modes.\n    origin       : list\n                   Reference point for a rotation.\n    offset       : list\n                   Shift with the given offset.\n\n    Returns\n    ----------\n    result       : ndarray\n                   Result of the rotation\n    rotx         : ndarray\n                   Rotation matrix along X axis.\n    roty         : ndarray\n                   Rotation matrix along Y axis.\n    rotz         : ndarray\n                   Rotation matrix along Z axis.\n    \"\"\"\n    point = np.asarray(point)\n    point -= np.asarray(origin)\n    rotx = rotmatx(angles[0])\n    roty = rotmaty(angles[1])\n    rotz = rotmatz(angles[2])\n    if mode == 'XYZ':\n        result = np.dot(rotz, np.dot(roty, np.dot(rotx, point)))\n    elif mode == 'XZY':\n        result = np.dot(roty, np.dot(rotz, np.dot(rotx, point)))\n    elif mode == 'YXZ':\n        result = np.dot(rotz, np.dot(rotx, np.dot(roty, point)))\n    elif mode == 'ZXY':\n        result = np.dot(roty, np.dot(rotx, np.dot(rotz, point)))\n    elif mode == 'ZYX':\n        result = np.dot(rotx, np.dot(roty, np.dot(rotz, point)))\n    result += np.asarray(origin)\n    result += np.asarray(offset)\n    return result, rotx, roty, rotz\n
"},{"location":"odak/tools/#odak.tools.rotate_points","title":"rotate_points(points, angles=[0, 0, 0], mode='XYZ', origin=[0, 0, 0], offset=[0, 0, 0])","text":"

Definition to rotate points.

Parameters:

  • points \u2013
           Points.\n
  • angles \u2013
           Rotation angles in degrees.\n
  • mode \u2013
           Rotation mode determines ordering of the rotations at each axis. There are XYZ,YXZ,ZXY and ZYX modes.\n
  • origin \u2013
           Reference point for a rotation.\n
  • offset \u2013
           Shift with the given offset.\n

Returns:

  • result ( ndarray ) \u2013

    Result of the rotation

Source code in odak/tools/transformation.py
def rotate_points(points, angles = [0, 0, 0], mode = 'XYZ', origin = [0, 0, 0], offset = [0, 0, 0]):\n    \"\"\"\n    Definition to rotate points.\n\n    Parameters\n    ----------\n    points       : ndarray\n                   Points.\n    angles       : list\n                   Rotation angles in degrees. \n    mode         : str\n                   Rotation mode determines ordering of the rotations at each axis. There are XYZ,YXZ,ZXY and ZYX modes.\n    origin       : list\n                   Reference point for a rotation.\n    offset       : list\n                   Shift with the given offset.\n\n    Returns\n    ----------\n    result       : ndarray\n                   Result of the rotation   \n    \"\"\"\n    points = np.asarray(points)\n    if angles[0] == 0 and angles[1] == 0 and angles[2] == 0:\n        result = np.array(offset) + points\n        return result\n    points -= np.array(origin)\n    rotx = rotmatx(angles[0])\n    roty = rotmaty(angles[1])\n    rotz = rotmatz(angles[2])\n    if mode == 'XYZ':\n        result = np.dot(rotz, np.dot(roty, np.dot(rotx, points.T))).T\n    elif mode == 'XZY':\n        result = np.dot(roty, np.dot(rotz, np.dot(rotx, points.T))).T\n    elif mode == 'YXZ':\n        result = np.dot(rotz, np.dot(rotx, np.dot(roty, points.T))).T\n    elif mode == 'ZXY':\n        result = np.dot(roty, np.dot(rotx, np.dot(rotz, points.T))).T\n    elif mode == 'ZYX':\n        result = np.dot(rotx, np.dot(roty, np.dot(rotz, points.T))).T\n    result += np.array(origin)\n    result += np.array(offset)\n    return result\n
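A minimal usage sketch for rotating several points at once (the point values are illustrative):

import numpy as np
from odak.tools import rotate_points

points = np.array([[1., 0., 0.], [0., 1., 0.]])
rotated = rotate_points(points, angles = [0., 0., 90.], origin = [0., 0., 0.], offset = [0., 0., 1.])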
"},{"location":"odak/tools/#odak.tools.rotmatx","title":"rotmatx(angle)","text":"

Definition to generate a rotation matrix along X axis.

Parameters:

  • angle \u2013
           Rotation angle in degrees.\n

Returns:

  • rotx ( ndarray ) \u2013

    Rotation matrix along X axis.

Source code in odak/tools/transformation.py
def rotmatx(angle):\n    \"\"\"\n    Definition to generate a rotation matrix along X axis.\n\n    Parameters\n    ----------\n    angle        : list\n                   Rotation angles in degrees.\n\n    Returns\n    -------\n    rotx         : ndarray\n                   Rotation matrix along X axis.\n    \"\"\"\n    angle = np.float64(angle)\n    angle = np.radians(angle)\n    rotx = np.array([\n        [1.,               0.,               0.],\n        [0.,  math.cos(angle), -math.sin(angle)],\n        [0.,  math.sin(angle),  math.cos(angle)]\n    ], dtype=np.float64)\n    return rotx\n
"},{"location":"odak/tools/#odak.tools.rotmaty","title":"rotmaty(angle)","text":"

Definition to generate a rotation matrix along Y axis.

Parameters:

  • angle \u2013
           Rotation angle in degrees.\n

Returns:

  • roty ( ndarray ) \u2013

    Rotation matrix along Y axis.

Source code in odak/tools/transformation.py
def rotmaty(angle):\n    \"\"\"\n    Definition to generate a rotation matrix along Y axis.\n\n    Parameters\n    ----------\n    angle        : list\n                   Rotation angles in degrees.\n\n    Returns\n    -------\n    roty         : ndarray\n                   Rotation matrix along Y axis.\n    \"\"\"\n    angle = np.radians(angle)\n    roty = np.array([\n        [math.cos(angle),  0., math.sin(angle)],\n        [0.,               1.,              0.],\n        [-math.sin(angle), 0., math.cos(angle)]\n    ], dtype=np.float64)\n    return roty\n
"},{"location":"odak/tools/#odak.tools.rotmatz","title":"rotmatz(angle)","text":"

Definition to generate a rotation matrix along Z axis.

Parameters:

  • angle \u2013
           Rotation angle in degrees.\n

Returns:

  • rotz ( ndarray ) \u2013

    Rotation matrix along Z axis.

Source code in odak/tools/transformation.py
def rotmatz(angle):\n    \"\"\"\n    Definition to generate a rotation matrix along Z axis.\n\n    Parameters\n    ----------\n    angle        : list\n                   Rotation angles in degrees.\n\n    Returns\n    -------\n    rotz         : ndarray\n                   Rotation matrix along Z axis.\n    \"\"\"\n    angle = np.radians(angle)\n    rotz = np.array([\n        [math.cos(angle), -math.sin(angle), 0.],\n        [math.sin(angle),  math.cos(angle), 0.],\n        [0.,               0., 1.]\n    ], dtype=np.float64)\n\n    return rotz\n
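A minimal sketch composing the three single-axis matrices into one XYZ rotation, matching the ordering used by rotate_point (the angles are illustrative):

import numpy as np
from odak.tools import rotmatx, rotmaty, rotmatz

# Apply the X rotation first, then Y, then Z.
rotation = np.dot(rotmatz(30.), np.dot(rotmaty(10.), rotmatx(45.)))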
"},{"location":"odak/tools/#odak.tools.same_side","title":"same_side(p1, p2, a, b)","text":"

Definition to figure out which side a point is on with respect to a line and a point. See http://www.blackpawn.com/texts/pointinpoly/ for more. If p1 and p2 are on the same side, this definition returns True.

Parameters:

  • p1 \u2013
          Point(s) to check.\n
  • p2 \u2013
          This is the point to check against.\n
  • a \u2013
          First point that forms the line.\n
  • b \u2013
          Second point that forms the line.\n
Source code in odak/tools/vector.py
def same_side(p1, p2, a, b):\n    \"\"\"\n    Definition to figure which side a point is on with respect to a line and a point. See http://www.blackpawn.com/texts/pointinpoly/ for more. If p1 and p2 are on the sameside, this definition returns True.\n\n    Parameters\n    ----------\n    p1          : list\n                  Point(s) to check.\n    p2          : list\n                  This is the point check against.\n    a           : list\n                  First point that forms the line.\n    b           : list\n                  Second point that forms the line.\n    \"\"\"\n    ba = np.subtract(b, a)\n    p1a = np.subtract(p1, a)\n    p2a = np.subtract(p2, a)\n    cp1 = np.cross(ba, p1a)\n    cp2 = np.cross(ba, p2a)\n    test = np.dot(cp1, cp2)\n    if len(p1.shape) > 1:\n        return test >= 0\n    if test >= 0:\n        return True\n    return False\n
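A minimal usage sketch (the coordinates are illustrative; both query points lie above the line from a to b):

import numpy as np
from odak.tools import same_side

a = np.array([0., 0., 0.])
b = np.array([1., 0., 0.])
p1 = np.array([0.5, 1., 0.])
p2 = np.array([0.5, 2., 0.])
print(same_side(p1, p2, a, b))   # True, both points are on the same side of the line a-b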
"},{"location":"odak/tools/#odak.tools.save_dictionary","title":"save_dictionary(settings, filename)","text":"

Definition to save a dictionary as a JSON file.

Parameters:

  • settings \u2013
            Dictionary to be saved to the file.\n
  • filename \u2013
            Filename.\n
Source code in odak/tools/file.py
def save_dictionary(settings, filename):\n    \"\"\"\n    Definition to save a dictionary as a JSON file.\n\n\n    Parameters\n    ----------\n    settings      : dict\n                    Dictionary to be saved to the file.\n    filename      : str\n                    Filename.\n    \"\"\"\n    with open(expanduser(filename), 'w', encoding='utf-8') as f:\n        json.dump(settings, f, ensure_ascii=False, indent=4)\n    return settings\n
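A minimal usage sketch (the settings values and filename are illustrative assumptions):

from odak.tools import save_dictionary

settings = {'resolution': [1080, 1920], 'wavelength': 5.15e-07}
save_dictionary(settings, 'settings.json')   # writes the dictionary as JSON; see load_dictionary to read it back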
"},{"location":"odak/tools/#odak.tools.save_image","title":"save_image(fn, img, cmin=0, cmax=255, color_depth=8)","text":"

Definition to save a Numpy array as an image.

Parameters:

  • fn \u2013
           Filename.\n
  • img \u2013
           A numpy array with NxMx3 or NxMx1 shapes.\n
  • cmin \u2013
           Minimum value that will be interpreted as 0 level in the final image.\n
  • cmax \u2013
           Maximum value that will be interpreted as 255 level in the final image.\n
  • color_depth \u2013
           Pixel color depth in bits, default is eight bits.\n

Returns:

  • bool ( bool ) \u2013

    True if successful.

Source code in odak/tools/file.py
def save_image(fn, img, cmin = 0, cmax = 255, color_depth = 8):\n    \"\"\"\n    Definition to save a Numpy array as an image.\n\n\n    Parameters\n    ----------\n    fn           : str\n                   Filename.\n    img          : ndarray\n                   A numpy array with NxMx3 or NxMx1 shapes.\n    cmin         : int\n                   Minimum value that will be interpreted as 0 level in the final image.\n    cmax         : int\n                   Maximum value that will be interpreted as 255 level in the final image.\n    color_depth  : int\n                   Pixel color depth in bits, default is eight bits.\n\n\n    Returns\n    ----------\n    bool         :  bool\n                    True if successful.\n\n    \"\"\"\n    input_img = np.copy(img).astype(np.float32)\n    cmin = float(cmin)\n    cmax = float(cmax)\n    input_img[input_img < cmin] = cmin\n    input_img[input_img > cmax] = cmax\n    input_img /= cmax\n    input_img = input_img * 1. * (2**color_depth - 1)\n    if color_depth == 8:\n        input_img = input_img.astype(np.uint8)\n    elif color_depth == 16:\n        input_img = input_img.astype(np.uint16)\n    if len(input_img.shape) > 2:\n        if input_img.shape[2] > 1:\n            cache_img = np.copy(input_img)\n            cache_img[:, :, 0] = input_img[:, :, 2]\n            cache_img[:, :, 2] = input_img[:, :, 0]\n            input_img = cache_img\n    cv2.imwrite(expanduser(fn), input_img)\n    return True\n
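A minimal usage sketch (the random image and filename are illustrative assumptions):

import numpy as np
from odak.tools import save_image

img = np.random.rand(256, 256, 3) * 255.     # values between 0 and 255
save_image('noise.png', img, cmin = 0, cmax = 255, color_depth = 8)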
"},{"location":"odak/tools/#odak.tools.shell_command","title":"shell_command(cmd, cwd='.', timeout=None, check=True)","text":"

Definition to initiate shell commands.

Parameters:

  • cmd \u2013
           Command to be executed.\n
  • cwd \u2013
           Working directory.\n
  • timeout \u2013
           Timeout if the process isn't complete in the given number of seconds.\n
  • check \u2013
           Set it to True to return the results and to enable timeout.\n

Returns:

  • proc ( Popen ) \u2013

    Generated process.

  • outs ( str ) \u2013

    Outputs of the executed command, returns None when check is set to False.

  • errs ( str ) \u2013

    Errors of the executed command, returns None when check is set to False.

Source code in odak/tools/file.py
def shell_command(cmd, cwd = '.', timeout = None, check = True):\n    \"\"\"\n    Definition to initiate shell commands.\n\n\n    Parameters\n    ----------\n    cmd          : list\n                   Command to be executed. \n    cwd          : str\n                   Working directory.\n    timeout      : int\n                   Timeout if the process isn't complete in the given number of seconds.\n    check        : bool\n                   Set it to True to return the results and to enable timeout.\n\n\n    Returns\n    ----------\n    proc         : subprocess.Popen\n                   Generated process.\n    outs         : str\n                   Outputs of the executed command, returns None when check is set to False.\n    errs         : str\n                   Errors of the executed command, returns None when check is set to False.\n\n    \"\"\"\n    for item_id in range(len(cmd)):\n        cmd[item_id] = expanduser(cmd[item_id])\n    proc = subprocess.Popen(\n                            cmd,\n                            cwd = cwd,\n                            stdout = subprocess.PIPE\n                           )\n    if check == False:\n        return proc, None, None\n    try:\n        outs, errs = proc.communicate(timeout = timeout)\n    except subprocess.TimeoutExpired:\n        proc.kill()\n        outs, errs = proc.communicate()\n    return proc, outs, errs\n
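A minimal usage sketch, assuming a POSIX environment where the ls command exists:

from odak.tools import shell_command

proc, outs, errs = shell_command(['ls', '-lh'], cwd = '.', timeout = 10, check = True)
print(outs.decode())     # standard output of the command as text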
"},{"location":"odak/tools/#odak.tools.size_of_a_file","title":"size_of_a_file(file_path)","text":"

A definition to get the size of a file with a relevant unit.

Parameters:

  • file_path \u2013
         Path of the file.\n

Returns:

  • a ( float ) \u2013

    Size of the file.

  • b ( str ) \u2013

    Unit of the size (bytes, KB, MB, GB or TB).

Source code in odak/tools/file.py
def size_of_a_file(file_path):\n    \"\"\"\n    A definition to get size of a file with a relevant unit.\n\n\n    Parameters\n    ----------\n    file_path  : float\n                 Path of the file.\n\n\n    Returns\n    ----------\n    a          : float\n                 Size of the file.\n    b          : str\n                 Unit of the size (bytes, KB, MB, GB or TB).\n    \"\"\"\n    if os.path.isfile(file_path):\n        file_info = os.stat(file_path)\n        a, b = convert_bytes(file_info.st_size)\n        return a, b\n    return None, None\n
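A minimal usage sketch ('noise.png' is a hypothetical filename):

from odak.tools import size_of_a_file

size, unit = size_of_a_file('noise.png')     # returns a float and a unit string, or (None, None) if the file is missing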
"},{"location":"odak/tools/#odak.tools.sphere_sample","title":"sphere_sample(no=[10, 10], radius=1.0, center=[0.0, 0.0, 0.0], k=[1, 2])","text":"

Definition to generate a regular sample set on the surface of a sphere using polar coordinates.

Parameters:

  • no \u2013
          Number of samples.\n
  • radius \u2013
          Radius of a sphere.\n
  • center \u2013
          Center of a sphere.\n
  • k \u2013
          Multipliers for gathering samples. If you set k=[1,2] it will draw samples from a perfect sphere.\n

Returns:

  • samples ( ndarray ) \u2013

    Samples generated.

Source code in odak/tools/sample.py
def sphere_sample(no=[10, 10], radius=1., center=[0., 0., 0.], k=[1, 2]):\n    \"\"\"\n    Definition to generate a regular sample set on the surface of a sphere using polar coordinates.\n\n    Parameters\n    ----------\n    no          : list\n                  Number of samples.\n    radius      : float\n                  Radius of a sphere.\n    center      : list\n                  Center of a sphere.\n    k           : list\n                  Multipliers for gathering samples. If you set k=[1,2] it will draw samples from a perfect sphere.\n\n    Returns\n    ----------\n    samples     : ndarray\n                  Samples generated.\n    \"\"\"\n    samples = np.zeros((no[0], no[1], 3))\n    psi, teta = np.mgrid[0:no[0], 0:no[1]]\n    psi = k[0]*np.pi/no[0]*psi\n    teta = k[1]*np.pi/no[1]*teta\n    samples[:, :, 0] = center[0]+radius*np.sin(psi)*np.cos(teta)\n    samples[:, :, 1] = center[1]+radius*np.sin(psi)*np.sin(teta)\n    samples[:, :, 2] = center[2]+radius*np.cos(psi)\n    samples = samples.reshape((no[0]*no[1], 3))\n    return samples\n
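A minimal usage sketch (the sample counts and radius are illustrative):

from odak.tools import sphere_sample

samples = sphere_sample(no = [20, 20], radius = 2., center = [0., 0., 0.], k = [1, 2])
print(samples.shape)     # (400, 3)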
"},{"location":"odak/tools/#odak.tools.sphere_sample_uniform","title":"sphere_sample_uniform(no=[10, 10], radius=1.0, center=[0.0, 0.0, 0.0], k=[1, 2])","text":"

Definition to generate a uniform sample set on the surface of a sphere using polar coordinates.

Parameters:

  • no \u2013
          Number of samples.\n
  • radius \u2013
          Radius of a sphere.\n
  • center \u2013
          Center of a sphere.\n
  • k \u2013
          Multipliers for gathering samples. If you set k=[1,2] it will draw samples from a perfect sphere.\n

Returns:

  • samples ( ndarray ) \u2013

    Samples generated.

Source code in odak/tools/sample.py
def sphere_sample_uniform(no=[10, 10], radius=1., center=[0., 0., 0.], k=[1, 2]):\n    \"\"\"\n    Definition to generate an uniform sample set on the surface of a sphere using polar coordinates.\n\n    Parameters\n    ----------\n    no          : list\n                  Number of samples.\n    radius      : float\n                  Radius of a sphere.\n    center      : list\n                  Center of a sphere.\n    k           : list\n                  Multipliers for gathering samples. If you set k=[1,2] it will draw samples from a perfect sphere.\n\n\n    Returns\n    ----------\n    samples     : ndarray\n                  Samples generated.\n\n    \"\"\"\n    samples = np.zeros((no[0], no[1], 3))\n    row = np.arange(0, no[0])\n    psi, teta = np.mgrid[0:no[0], 0:no[1]]\n    for psi_id in range(0, no[0]):\n        psi[psi_id] = np.roll(row, psi_id, axis=0)\n        teta[psi_id] = np.roll(row, -psi_id, axis=0)\n    psi = k[0]*np.pi/no[0]*psi\n    teta = k[1]*np.pi/no[1]*teta\n    samples[:, :, 0] = center[0]+radius*np.sin(psi)*np.cos(teta)\n    samples[:, :, 1] = center[1]+radius*np.sin(psi)*np.sin(teta)\n    samples[:, :, 2] = center[2]+radius*np.cos(psi)\n    samples = samples.reshape((no[0]*no[1], 3))\n    return samples\n
"},{"location":"odak/tools/#odak.tools.tilt_towards","title":"tilt_towards(location, lookat)","text":"

Definition to tilt the surface normal of a plane towards a point.

Parameters:

  • location \u2013
           Center of the plane to be tilted.\n
  • lookat \u2013
           Tilt towards this point.\n

Returns:

  • angles ( list ) \u2013

    Rotation angles in degrees.

Source code in odak/tools/transformation.py
def tilt_towards(location, lookat):\n    \"\"\"\n    Definition to tilt surface normal of a plane towards a point.\n\n    Parameters\n    ----------\n    location     : list\n                   Center of the plane to be tilted.\n    lookat       : list\n                   Tilt towards this point.\n\n    Returns\n    ----------\n    angles       : list\n                   Rotation angles in degrees.\n    \"\"\"\n    dx = location[0]-lookat[0]\n    dy = location[1]-lookat[1]\n    dz = location[2]-lookat[2]\n    dist = np.sqrt(dx**2+dy**2+dz**2)\n    phi = np.arctan2(dy, dx)\n    theta = np.arccos(dz/dist)\n    angles = [\n        0,\n        np.degrees(theta).tolist(),\n        np.degrees(phi).tolist()\n    ]\n    return angles\n
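A minimal usage sketch (the plane center and look-at point are illustrative):

from odak.tools import tilt_towards

angles = tilt_towards([0., 0., 0.], [0., 5., 10.])   # rotation angles in degrees (e.g., to feed into rotate_points)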
"},{"location":"odak/tools/#odak.tools.write_PLY","title":"write_PLY(triangles, savefn='output.ply')","text":"

Definition to generate a PLY file from given triangles.

Parameters:

  • triangles \u2013
          List of triangles with the size of Mx3x3.\n
  • savefn \u2013
          Filename for a PLY file.\n
Source code in odak/tools/asset.py
def write_PLY(triangles, savefn = 'output.ply'):\n    \"\"\"\n    Definition to generate a PLY file from given points.\n\n    Parameters\n    ----------\n    triangles   : ndarray\n                  List of triangles with the size of Mx3x3.\n    savefn      : string\n                  Filename for a PLY file.\n    \"\"\"\n    tris = []\n    pnts = []\n    color = [255, 255, 255]\n    for tri_id in range(triangles.shape[0]):\n        tris.append(\n            (\n                [3*tri_id, 3*tri_id+1, 3*tri_id+2],\n                color[0],\n                color[1],\n                color[2]\n            )\n        )\n        for i in range(0, 3):\n            pnts.append(\n                (\n                    float(triangles[tri_id][i][0]),\n                    float(triangles[tri_id][i][1]),\n                    float(triangles[tri_id][i][2])\n                )\n            )\n    tris = np.asarray(tris, dtype=[\n                          ('vertex_indices', 'i4', (3,)), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')])\n    pnts = np.asarray(pnts, dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])\n    # Save mesh.\n    el1 = PlyElement.describe(pnts, 'vertex', comments=['Vertex data'])\n    el2 = PlyElement.describe(tris, 'face', comments=['Face data'])\n    PlyData([el1, el2], text=\"True\").write(savefn)\n
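A minimal usage sketch (the random triangle vertices and output filename are illustrative assumptions):

import numpy as np
from odak.tools import write_PLY

triangles = np.random.rand(10, 3, 3)          # ten triangles, three vertices each
write_PLY(triangles, savefn = 'output.ply')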
"},{"location":"odak/tools/#odak.tools.write_PLY_from_points","title":"write_PLY_from_points(points, savefn='output.ply')","text":"

Definition to generate a PLY file from given points.

Parameters:

  • points \u2013
          List of points with the size of MxNx3.\n
  • savefn \u2013
          Filename for a PLY file.\n
Source code in odak/tools/asset.py
def write_PLY_from_points(points, savefn='output.ply'):\n    \"\"\"\n    Definition to generate a PLY file from given points.\n\n    Parameters\n    ----------\n    points      : ndarray\n                  List of points with the size of MxNx3.\n    savefn      : string\n                  Filename for a PLY file.\n\n    \"\"\"\n    if np.__name__ != 'numpy':\n        import numpy as np_ply\n    else:\n        np_ply = np\n    # Generate equation\n    samples = [points.shape[0], points.shape[1]]\n    # Generate vertices.\n    pnts = []\n    tris = []\n    for idx in range(0, samples[0]):\n        for idy in range(0, samples[1]):\n            pnt = (points[idx, idy, 0],\n                   points[idx, idy, 1], points[idx, idy, 2])\n            pnts.append(pnt)\n    color = [255, 255, 255]\n    for idx in range(0, samples[0]-1):\n        for idy in range(0, samples[1]-1):\n            tris.append(([idy+(idx+1)*samples[0], idy+idx*samples[0],\n                        idy+1+idx*samples[0]], color[0], color[1], color[2]))\n            tris.append(([idy+(idx+1)*samples[0], idy+1+idx*samples[0],\n                        idy+1+(idx+1)*samples[0]], color[0], color[1], color[2]))\n    tris = np_ply.asarray(tris, dtype=[(\n        'vertex_indices', 'i4', (3,)), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')])\n    pnts = np_ply.asarray(pnts, dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])\n    # Save mesh.\n    el1 = PlyElement.describe(pnts, 'vertex', comments=['Vertex data'])\n    el2 = PlyElement.describe(tris, 'face', comments=['Face data'])\n    PlyData([el1, el2], text=\"True\").write(savefn)\n
"},{"location":"odak/tools/#odak.tools.write_to_text_file","title":"write_to_text_file(content, filename, write_flag='w')","text":"

Definition to write a Pythonic list to a text file.

Parameters:

  • content \u2013
              Pythonic string list to be written to a file.\n
  • filename \u2013
              Destination filename (i.e. test.txt).\n
  • write_flag \u2013
              Defines the interaction with the file. \n          The default is \"w\" (overwrite any existing content).\n          For more see: https://docs.python.org/3/tutorial/inputoutput.html#reading-and-writing-files\n
Source code in odak/tools/file.py
def write_to_text_file(content, filename, write_flag = 'w'):\n    \"\"\"\n    Defininition to write a Pythonic list to a text file.\n\n\n    Parameters\n    ----------\n    content         : list\n                      Pythonic string list to be written to a file.\n    filename        : str\n                      Destination filename (i.e. test.txt).\n    write_flag      : str\n                      Defines the interaction with the file. \n                      The default is \"w\" (overwrite any existing content).\n                      For more see: https://docs.python.org/3/tutorial/inputoutput.html#reading-and-writing-files\n    \"\"\"\n    with open(expanduser(filename), write_flag) as f:\n        for line in content:\n            f.write('{}\\n'.format(line))\n    return True\n
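A minimal usage sketch writing lines out and reading them back (the filename and content are illustrative):

from odak.tools.file import write_to_text_file, read_text_file

write_to_text_file(['first line', 'second line'], 'notes.txt')
lines = read_text_file('notes.txt')           # ['first line', 'second line']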
"},{"location":"odak/tools/#odak.tools.zero_pad","title":"zero_pad(field, size=None, method='center')","text":"

Definition to zero pad an MxN array to a 2Mx2N array.

Parameters:

  • field \u2013
                Input field MxN array.\n
  • size \u2013
                Size to be zeropadded.\n
  • method \u2013
                Zeropad either by placing the content to center or to the left.\n

Returns:

  • field_zero_padded ( ndarray ) \u2013

    Zeropadded version of the input field.

Source code in odak/tools/matrix.py
def zero_pad(field, size=None, method='center'):\n    \"\"\"\n    Definition to zero pad a MxN array to 2Mx2N array.\n\n    Parameters\n    ----------\n    field             : ndarray\n                        Input field MxN array.\n    size              : list\n                        Size to be zeropadded.\n    method            : str\n                        Zeropad either by placing the content to center or to the left.\n\n    Returns\n    ----------\n    field_zero_padded : ndarray\n                        Zeropadded version of the input field.\n    \"\"\"\n    if type(size) == type(None):\n        hx = int(np.ceil(field.shape[0])/2)\n        hy = int(np.ceil(field.shape[1])/2)\n    else:\n        hx = int(np.ceil((size[0]-field.shape[0])/2))\n        hy = int(np.ceil((size[1]-field.shape[1])/2))\n    if method == 'center':\n        field_zero_padded = np.pad(\n            field, ([hx, hx], [hy, hy]), constant_values=(0, 0))\n    elif method == 'left aligned':\n        field_zero_padded = np.pad(\n            field, ([0, 2*hx], [0, 2*hy]), constant_values=(0, 0))\n    if type(size) != type(None):\n        field_zero_padded = field_zero_padded[0:size[0], 0:size[1]]\n    return field_zero_padded\n
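A minimal usage sketch (the field size and target size are illustrative):

import numpy as np
from odak.tools import zero_pad

field = np.ones((100, 100))
padded = zero_pad(field)                          # (200, 200), content placed at the center
padded_to_size = zero_pad(field, size = [256, 256])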
"},{"location":"odak/tools/#odak.tools.asset.read_PLY","title":"read_PLY(fn, offset=[0, 0, 0], angles=[0.0, 0.0, 0.0], mode='XYZ')","text":"

Definition to read a PLY file and extract meshes from a given PLY file. Note that rotation is always with respect to 0,0,0.

Parameters:

  • fn \u2013
           Filename of a PLY file.\n
  • offset \u2013
           Offset in X,Y,Z.\n
  • angles \u2013
           Rotation angles in degrees.\n
  • mode \u2013
           Rotation mode determines ordering of the rotations at each axis. There are XYZ,YXZ,ZXY and ZYX modes.\n

Returns:

  • triangles ( ndarray ) \u2013

    Triangles from a given PLY file. Note that the triangles coming out of this function aren't always structured in the right order and come with a size of (MxN)x3. You can use numpy's reshape to restructure them to MxNx3 if you know what you are doing.

Source code in odak/tools/asset.py
def read_PLY(fn, offset=[0, 0, 0], angles=[0., 0., 0.], mode='XYZ'):\n    \"\"\"\n    Definition to read a PLY file and extract meshes from a given PLY file. Note that rotation is always with respect to 0,0,0.\n\n    Parameters\n    ----------\n    fn           : string\n                   Filename of a PLY file.\n    offset       : ndarray\n                   Offset in X,Y,Z.\n    angles       : list\n                   Rotation angles in degrees.\n    mode         : str\n                   Rotation mode determines ordering of the rotations at each axis. There are XYZ,YXZ,ZXY and ZYX modes. \n\n    Returns\n    ----------\n    triangles    : ndarray\n                  Triangles from a given PLY file. Note that the triangles coming out of this function isn't always structured in the right order and with the size of (MxN)x3. You can use numpy's reshape to restructure it to mxnx3 if you know what you are doing.\n    \"\"\"\n    if np.__name__ != 'numpy':\n        import numpy as np_ply\n    else:\n        np_ply = np\n    with open(fn, 'rb') as f:\n        plydata = PlyData.read(f)\n    triangle_ids = np_ply.vstack(plydata['face'].data['vertex_indices'])\n    triangles = []\n    for vertex_ids in triangle_ids:\n        triangle = [\n            rotate_point(plydata['vertex'][int(vertex_ids[0])\n                                           ].tolist(), angles=angles, offset=offset)[0],\n            rotate_point(plydata['vertex'][int(vertex_ids[1])\n                                           ].tolist(), angles=angles, offset=offset)[0],\n            rotate_point(plydata['vertex'][int(vertex_ids[2])\n                                           ].tolist(), angles=angles, offset=offset)[0]\n        ]\n        triangle = np_ply.asarray(triangle)\n        triangles.append(triangle)\n    triangles = np_ply.array(triangles)\n    triangles = np.asarray(triangles, dtype=np.float32)\n    return triangles\n
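A minimal usage sketch ('output.ply' is an illustrative filename; any triangle mesh stored as PLY works):

from odak.tools.asset import read_PLY

triangles = read_PLY('output.ply', offset = [0., 0., 10.], angles = [0., 45., 0.], mode = 'XYZ')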
"},{"location":"odak/tools/#odak.tools.asset.read_PLY_point_cloud","title":"read_PLY_point_cloud(filename)","text":"

Definition to read a PLY file as a point cloud.

Parameters:

  • filename \u2013
           Filename of a PLY file.\n

Returns:

  • point_cloud ( ndarray ) \u2013

    An array filled with points from the PLY file.

Source code in odak/tools/asset.py
def read_PLY_point_cloud(filename):\n    \"\"\"\n    Definition to read a PLY file as a point cloud.\n\n    Parameters\n    ----------\n    filename     : str\n                   Filename of a PLY file.\n\n    Returns\n    ----------\n    point_cloud  : ndarray\n                   An array filled with poitns from the PLY file.\n    \"\"\"\n    plydata = PlyData.read(filename)\n    if np.__name__ != 'numpy':\n        import numpy as np_ply\n        point_cloud = np_ply.zeros((plydata['vertex'][:].shape[0], 3))\n        point_cloud[:, 0] = np_ply.asarray(plydata['vertex']['x'][:])\n        point_cloud[:, 1] = np_ply.asarray(plydata['vertex']['y'][:])\n        point_cloud[:, 2] = np_ply.asarray(plydata['vertex']['z'][:])\n        point_cloud = np.asarray(point_cloud)\n    else:\n        point_cloud = np.zeros((plydata['vertex'][:].shape[0], 3))\n        point_cloud[:, 0] = np.asarray(plydata['vertex']['x'][:])\n        point_cloud[:, 1] = np.asarray(plydata['vertex']['y'][:])\n        point_cloud[:, 2] = np.asarray(plydata['vertex']['z'][:])\n    return point_cloud\n
"},{"location":"odak/tools/#odak.tools.asset.write_PLY","title":"write_PLY(triangles, savefn='output.ply')","text":"

Definition to generate a PLY file from given triangles.

Parameters:

  • triangles \u2013
          List of triangles with the size of Mx3x3.\n
  • savefn \u2013
          Filename for a PLY file.\n
Source code in odak/tools/asset.py
def write_PLY(triangles, savefn = 'output.ply'):\n    \"\"\"\n    Definition to generate a PLY file from given points.\n\n    Parameters\n    ----------\n    triangles   : ndarray\n                  List of triangles with the size of Mx3x3.\n    savefn      : string\n                  Filename for a PLY file.\n    \"\"\"\n    tris = []\n    pnts = []\n    color = [255, 255, 255]\n    for tri_id in range(triangles.shape[0]):\n        tris.append(\n            (\n                [3*tri_id, 3*tri_id+1, 3*tri_id+2],\n                color[0],\n                color[1],\n                color[2]\n            )\n        )\n        for i in range(0, 3):\n            pnts.append(\n                (\n                    float(triangles[tri_id][i][0]),\n                    float(triangles[tri_id][i][1]),\n                    float(triangles[tri_id][i][2])\n                )\n            )\n    tris = np.asarray(tris, dtype=[\n                          ('vertex_indices', 'i4', (3,)), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')])\n    pnts = np.asarray(pnts, dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])\n    # Save mesh.\n    el1 = PlyElement.describe(pnts, 'vertex', comments=['Vertex data'])\n    el2 = PlyElement.describe(tris, 'face', comments=['Face data'])\n    PlyData([el1, el2], text=\"True\").write(savefn)\n
"},{"location":"odak/tools/#odak.tools.asset.write_PLY_from_points","title":"write_PLY_from_points(points, savefn='output.ply')","text":"

Definition to generate a PLY file from given points.

Parameters:

  • points \u2013
          List of points with the size of MxNx3.\n
  • savefn \u2013
          Filename for a PLY file.\n
Source code in odak/tools/asset.py
def write_PLY_from_points(points, savefn='output.ply'):\n    \"\"\"\n    Definition to generate a PLY file from given points.\n\n    Parameters\n    ----------\n    points      : ndarray\n                  List of points with the size of MxNx3.\n    savefn      : string\n                  Filename for a PLY file.\n\n    \"\"\"\n    if np.__name__ != 'numpy':\n        import numpy as np_ply\n    else:\n        np_ply = np\n    # Generate equation\n    samples = [points.shape[0], points.shape[1]]\n    # Generate vertices.\n    pnts = []\n    tris = []\n    for idx in range(0, samples[0]):\n        for idy in range(0, samples[1]):\n            pnt = (points[idx, idy, 0],\n                   points[idx, idy, 1], points[idx, idy, 2])\n            pnts.append(pnt)\n    color = [255, 255, 255]\n    for idx in range(0, samples[0]-1):\n        for idy in range(0, samples[1]-1):\n            tris.append(([idy+(idx+1)*samples[0], idy+idx*samples[0],\n                        idy+1+idx*samples[0]], color[0], color[1], color[2]))\n            tris.append(([idy+(idx+1)*samples[0], idy+1+idx*samples[0],\n                        idy+1+(idx+1)*samples[0]], color[0], color[1], color[2]))\n    tris = np_ply.asarray(tris, dtype=[(\n        'vertex_indices', 'i4', (3,)), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')])\n    pnts = np_ply.asarray(pnts, dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])\n    # Save mesh.\n    el1 = PlyElement.describe(pnts, 'vertex', comments=['Vertex data'])\n    el2 = PlyElement.describe(tris, 'face', comments=['Face data'])\n    PlyData([el1, el2], text=\"True\").write(savefn)\n
"},{"location":"odak/tools/#odak.tools.conversions.convert_to_numpy","title":"convert_to_numpy(a)","text":"

A definition to convert Torch to Numpy.

Parameters:

  • a \u2013
         Input Torch array.\n

Returns:

  • b ( ndarray ) \u2013

    Converted array.

Source code in odak/tools/conversions.py
def convert_to_numpy(a):\n    \"\"\"\n    A definition to convert Torch to Numpy.\n\n    Parameters\n    ----------\n    a          : torch.Tensor\n                 Input Torch array.\n\n    Returns\n    ----------\n    b          : numpy.ndarray\n                 Converted array.\n    \"\"\"\n    b = a.to('cpu').detach().numpy()\n    return b\n
"},{"location":"odak/tools/#odak.tools.conversions.convert_to_torch","title":"convert_to_torch(a, grad=True)","text":"

A definition to convert Numpy arrays to Torch.

Parameters:

  • a \u2013
         Input Numpy array.\n
  • grad \u2013
         Set if the converted array requires gradient.\n

Returns:

  • c ( Tensor ) \u2013

    Converted array.

Source code in odak/tools/conversions.py
def convert_to_torch(a, grad=True):\n    \"\"\"\n    A definition to convert Numpy arrays to Torch.\n\n    Parameters\n    ----------\n    a          : ndarray\n                 Input Numpy array.\n    grad       : bool\n                 Set if the converted array requires gradient.\n\n    Returns\n    ----------\n    c          : torch.Tensor\n                 Converted array.\n    \"\"\"\n    b = np.copy(a)\n    c = torch.from_numpy(b)\n    c.requires_grad_(grad)\n    return c\n
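A minimal round-trip sketch using both conversion helpers (the array size is illustrative):

import numpy as np
from odak.tools.conversions import convert_to_torch, convert_to_numpy

a = np.random.rand(8, 8)
t = convert_to_torch(a, grad = True)   # torch.Tensor with requires_grad enabled
b = convert_to_numpy(t)                # back to a numpy.ndarray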
"},{"location":"odak/tools/#odak.tools.file.check_directory","title":"check_directory(directory)","text":"

Definition to check if a directory exists. If it doesn't exist, this definition will create one.

Parameters:

  • directory \u2013
            Full directory path.\n
Source code in odak/tools/file.py
def check_directory(directory):\n    \"\"\"\n    Definition to check if a directory exist. If it doesn't exist, this definition will create one.\n\n\n    Parameters\n    ----------\n    directory     : str\n                    Full directory path.\n    \"\"\"\n    if not os.path.exists(expanduser(directory)):\n        os.makedirs(expanduser(directory))\n        return False\n    return True\n
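A minimal usage sketch ('output' is an illustrative directory name):

from odak.tools.file import check_directory

check_directory('output')    # creates ./output if missing; returns True when it already exists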
"},{"location":"odak/tools/#odak.tools.file.convert_bytes","title":"convert_bytes(num)","text":"

A definition to convert bytes to a semantic scheme (KB, MB, GB or the like). Inspired by https://stackoverflow.com/questions/2104080/how-can-i-check-file-size-in-python#2104083.

Parameters:

  • num \u2013
         Size in bytes\n

Returns:

  • num ( float ) \u2013

    Size in new unit.

  • x ( str ) \u2013

    New unit bytes, KB, MB, GB or TB.

Source code in odak/tools/file.py
def convert_bytes(num):\n    \"\"\"\n    A definition to convert bytes to semantic scheme (MB,GB or alike). Inspired from https://stackoverflow.com/questions/2104080/how-can-i-check-file-size-in-python#2104083.\n\n\n    Parameters\n    ----------\n    num        : float\n                 Size in bytes\n\n\n    Returns\n    ----------\n    num        : float\n                 Size in new unit.\n    x          : str\n                 New unit bytes, KB, MB, GB or TB.\n    \"\"\"\n    for x in ['bytes', 'KB', 'MB', 'GB', 'TB']:\n        if num < 1024.0:\n            return num, x\n        num /= 1024.0\n    return None, None\n
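A minimal usage sketch:

from odak.tools.file import convert_bytes

num, unit = convert_bytes(5 * 1024 * 1024)
print(num, unit)     # 5.0 MB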
"},{"location":"odak/tools/#odak.tools.file.copy_file","title":"copy_file(source, destination, follow_symlinks=True)","text":"

Definition to copy a file from one location to another.

Parameters:

  • source \u2013
              Source filename.\n
  • destination \u2013
              Destination filename.\n
  • follow_symlinks (bool, default: True ) \u2013
              Set to True to follow the source of symbolic links.\n
Source code in odak/tools/file.py
def copy_file(source, destination, follow_symlinks = True):\n    \"\"\"\n    Definition to copy a file from one location to another.\n\n\n\n    Parameters\n    ----------\n    source          : str\n                      Source filename.\n    destination     : str\n                      Destination filename.\n    follow_symlinks : bool\n                      Set to True to follow the source of symbolic links.\n    \"\"\"\n    return shutil.copyfile(\n                           expanduser(source),\n                           expanduser(destination),\n                           follow_symlinks = follow_symlinks\n                          )\n
"},{"location":"odak/tools/#odak.tools.file.expanduser","title":"expanduser(filename)","text":"

Definition to decode filename using namespaces and shortcuts.

Parameters:

  • filename \u2013
            Filename.\n

Returns:

  • new_filename ( str ) \u2013

    Filename.

Source code in odak/tools/file.py
def expanduser(filename):\n    \"\"\"\n    Definition to decode filename using namespaces and shortcuts.\n\n\n    Parameters\n    ----------\n    filename      : str\n                    Filename.\n\n\n    Returns\n    -------\n    new_filename  : str\n                    Filename.\n    \"\"\"\n    new_filename = os.path.expanduser(filename)\n    return new_filename\n
"},{"location":"odak/tools/#odak.tools.file.list_files","title":"list_files(path, key='*.*', recursive=True)","text":"

Definition to list files in a given path with a given key.

Parameters:

  • path \u2013
          Path to a folder.\n
  • key \u2013
          Key used for scanning a path.\n
  • recursive \u2013
          If set True, scan the path recursively.\n

Returns:

  • files_list ( ndarray ) \u2013

    list of files found in a given path.

Source code in odak/tools/file.py
def list_files(path, key = '*.*', recursive = True):\n    \"\"\"\n    Definition to list files in a given path with a given key.\n\n\n    Parameters\n    ----------\n    path        : str\n                  Path to a folder.\n    key         : str\n                  Key used for scanning a path.\n    recursive   : bool\n                  If set True, scan the path recursively.\n\n\n    Returns\n    ----------\n    files_list  : ndarray\n                  list of files found in a given path.\n    \"\"\"\n    if recursive == True:\n        search_result = pathlib.Path(expanduser(path)).rglob(key)\n    elif recursive == False:\n        search_result = pathlib.Path(expanduser(path)).glob(key)\n    files_list = []\n    for item in search_result:\n        files_list.append(str(item))\n    files_list = sorted(files_list)\n    return files_list\n
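A minimal usage sketch ('~/images' is an illustrative folder; the key follows glob syntax):

from odak.tools.file import list_files

png_files = list_files('~/images', key = '*.png', recursive = True)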
"},{"location":"odak/tools/#odak.tools.file.load_dictionary","title":"load_dictionary(filename)","text":"

Definition to load a dictionary (JSON) file.

Parameters:

  • filename \u2013
            Filename.\n

Returns:

  • settings ( dict ) \u2013

    Dictionary read from the file.

Source code in odak/tools/file.py
def load_dictionary(filename):\n    \"\"\"\n    Definition to load a dictionary (JSON) file.\n\n\n    Parameters\n    ----------\n    filename      : str\n                    Filename.\n\n\n    Returns\n    ----------\n    settings      : dict\n                    Dictionary read from the file.\n\n    \"\"\"\n    settings = json.load(open(expanduser(filename)))\n    return settings\n
"},{"location":"odak/tools/#odak.tools.file.load_image","title":"load_image(fn, normalizeby=0.0, torch_style=False)","text":"

Definition to load an image from a given location as a Numpy array.

Parameters:

  • fn \u2013
           Filename.\n
  • normalizeby \u2013
           Value to normalize images with. Default value of zero will lead to no normalization.\n
  • torch_style \u2013
           If set True, it will load an image mxnx3 as 3xmxn.\n

Returns:

  • image ( ndarray ) \u2013

    Image loaded as a Numpy array.

Source code in odak/tools/file.py
def load_image(fn, normalizeby = 0., torch_style = False):\n    \"\"\" \n    Definition to load an image from a given location as a Numpy array.\n\n\n    Parameters\n    ----------\n    fn           : str\n                   Filename.\n    normalizeby  : float\n                   Value to to normalize images with. Default value of zero will lead to no normalization.\n    torch_style  : bool\n                   If set True, it will load an image mxnx3 as 3xmxn.\n\n\n    Returns\n    ----------\n    image        :  ndarray\n                    Image loaded as a Numpy array.\n\n    \"\"\"\n    image = cv2.imread(expanduser(fn), cv2.IMREAD_UNCHANGED)\n    if isinstance(image, type(None)):\n         logging.warning('Image not properly loaded. Check filename or image type.')    \n         sys.exit()\n    if len(image.shape) > 2:\n        new_image = np.copy(image)\n        new_image[:, :, 0] = image[:, :, 2]\n        new_image[:, :, 2] = image[:, :, 0]\n        image = new_image\n    if normalizeby != 0.:\n        image = image * 1. / normalizeby\n    if torch_style == True and len(image.shape) > 2:\n        image = np.moveaxis(image, -1, 0)\n    return image.astype(float)\n
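A minimal usage sketch ('noise.png' is an illustrative filename; normalizeby = 255. scales 8-bit values into the 0-1 range):

from odak.tools.file import load_image

image = load_image('noise.png', normalizeby = 255., torch_style = False)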
"},{"location":"odak/tools/#odak.tools.file.read_text_file","title":"read_text_file(filename)","text":"

Definition to read a given text file and convert it into a Pythonic list.

Parameters:

  • filename \u2013
              Source filename (i.e. test.txt).\n

Returns:

  • content ( list ) \u2013

    Pythonic string list containing the text from the file provided.

Source code in odak/tools/file.py
def read_text_file(filename):\n    \"\"\"\n    Definition to read a given text file and convert it into a Pythonic list.\n\n\n    Parameters\n    ----------\n    filename        : str\n                      Source filename (i.e. test.txt).\n\n\n    Returns\n    -------\n    content         : list\n                      Pythonic string list containing the text from the file provided.\n    \"\"\"\n    content = []\n    loaded_file = open(expanduser(filename))\n    while line := loaded_file.readline():\n        content.append(line.rstrip())\n    return content\n
"},{"location":"odak/tools/#odak.tools.file.resize_image","title":"resize_image(img, target_size)","text":"

Definition to resize a given image to a target shape.

Parameters:

  • img \u2013
            MxN image to be resized.\n        Image must be normalized (0-1).\n
  • target_size \u2013
            Target shape.\n

Returns:

  • img ( ndarray ) \u2013

    Resized image.

Source code in odak/tools/file.py
def resize_image(img, target_size):\n    \"\"\"\n    Definition to resize a given image to a target shape.\n\n\n    Parameters\n    ----------\n    img           : ndarray\n                    MxN image to be resized.\n                    Image must be normalized (0-1).\n    target_size   : list\n                    Target shape.\n\n\n    Returns\n    ----------\n    img           : ndarray\n                    Resized image.\n\n    \"\"\"\n    img = cv2.resize(img, dsize=(target_size[0], target_size[1]), interpolation=cv2.INTER_AREA)\n    return img\n
"},{"location":"odak/tools/#odak.tools.file.save_dictionary","title":"save_dictionary(settings, filename)","text":"

Definition to save a dictionary as a JSON file.

Parameters:

  • settings \u2013
            Dictionary to be saved to the file.\n
  • filename \u2013
            Filename.\n
Source code in odak/tools/file.py
def save_dictionary(settings, filename):\n    \"\"\"\n    Definition to save a dictionary as a JSON file.\n\n\n    Parameters\n    ----------\n    settings      : dict\n                    Dictionary to be saved to the file.\n    filename      : str\n                    Filename.\n    \"\"\"\n    with open(expanduser(filename), 'w', encoding='utf-8') as f:\n        json.dump(settings, f, ensure_ascii=False, indent=4)\n    return settings\n
"},{"location":"odak/tools/#odak.tools.file.save_image","title":"save_image(fn, img, cmin=0, cmax=255, color_depth=8)","text":"

Definition to save a Numpy array as an image.

Parameters:

  • fn \u2013
           Filename.\n
  • img \u2013
           A numpy array with NxMx3 or NxMx1 shapes.\n
  • cmin \u2013
           Minimum value that will be interpreted as 0 level in the final image.\n
  • cmax \u2013
           Maximum value that will be interpreted as 255 level in the final image.\n
  • color_depth \u2013
           Pixel color depth in bits, default is eight bits.\n

Returns:

  • bool ( bool ) \u2013

    True if successful.

Source code in odak/tools/file.py
def save_image(fn, img, cmin = 0, cmax = 255, color_depth = 8):\n    \"\"\"\n    Definition to save a Numpy array as an image.\n\n\n    Parameters\n    ----------\n    fn           : str\n                   Filename.\n    img          : ndarray\n                   A numpy array with NxMx3 or NxMx1 shapes.\n    cmin         : int\n                   Minimum value that will be interpreted as 0 level in the final image.\n    cmax         : int\n                   Maximum value that will be interpreted as 255 level in the final image.\n    color_depth  : int\n                   Pixel color depth in bits, default is eight bits.\n\n\n    Returns\n    ----------\n    bool         :  bool\n                    True if successful.\n\n    \"\"\"\n    input_img = np.copy(img).astype(np.float32)\n    cmin = float(cmin)\n    cmax = float(cmax)\n    input_img[input_img < cmin] = cmin\n    input_img[input_img > cmax] = cmax\n    input_img /= cmax\n    input_img = input_img * 1. * (2**color_depth - 1)\n    if color_depth == 8:\n        input_img = input_img.astype(np.uint8)\n    elif color_depth == 16:\n        input_img = input_img.astype(np.uint16)\n    if len(input_img.shape) > 2:\n        if input_img.shape[2] > 1:\n            cache_img = np.copy(input_img)\n            cache_img[:, :, 0] = input_img[:, :, 2]\n            cache_img[:, :, 2] = input_img[:, :, 0]\n            input_img = cache_img\n    cv2.imwrite(expanduser(fn), input_img)\n    return True\n
"},{"location":"odak/tools/#odak.tools.file.shell_command","title":"shell_command(cmd, cwd='.', timeout=None, check=True)","text":"

Definition to initiate shell commands.

Parameters:

  • cmd \u2013
           Command to be executed.\n
  • cwd \u2013
           Working directory.\n
  • timeout \u2013
           Timeout if the process isn't complete in the given number of seconds.\n
  • check \u2013
           Set it to True to return the results and to enable timeout.\n

Returns:

  • proc ( Popen ) \u2013

    Generated process.

  • outs ( str ) \u2013

    Outputs of the executed command, returns None when check is set to False.

  • errs ( str ) \u2013

    Errors of the executed command, returns None when check is set to False.

Source code in odak/tools/file.py
def shell_command(cmd, cwd = '.', timeout = None, check = True):\n    \"\"\"\n    Definition to initiate shell commands.\n\n\n    Parameters\n    ----------\n    cmd          : list\n                   Command to be executed. \n    cwd          : str\n                   Working directory.\n    timeout      : int\n                   Timeout if the process isn't complete in the given number of seconds.\n    check        : bool\n                   Set it to True to return the results and to enable timeout.\n\n\n    Returns\n    ----------\n    proc         : subprocess.Popen\n                   Generated process.\n    outs         : str\n                   Outputs of the executed command, returns None when check is set to False.\n    errs         : str\n                   Errors of the executed command, returns None when check is set to False.\n\n    \"\"\"\n    for item_id in range(len(cmd)):\n        cmd[item_id] = expanduser(cmd[item_id])\n    proc = subprocess.Popen(\n                            cmd,\n                            cwd = cwd,\n                            stdout = subprocess.PIPE\n                           )\n    if check == False:\n        return proc, None, None\n    try:\n        outs, errs = proc.communicate(timeout = timeout)\n    except subprocess.TimeoutExpired:\n        proc.kill()\n        outs, errs = proc.communicate()\n    return proc, outs, errs\n
"},{"location":"odak/tools/#odak.tools.file.size_of_a_file","title":"size_of_a_file(file_path)","text":"

A definition to get the size of a file with a relevant unit.

Parameters:

  • file_path \u2013
         Path of the file.\n

Returns:

  • a ( float ) \u2013

    Size of the file.

  • b ( str ) \u2013

    Unit of the size (bytes, KB, MB, GB or TB).

Source code in odak/tools/file.py
def size_of_a_file(file_path):\n    \"\"\"\n    A definition to get size of a file with a relevant unit.\n\n\n    Parameters\n    ----------\n    file_path  : float\n                 Path of the file.\n\n\n    Returns\n    ----------\n    a          : float\n                 Size of the file.\n    b          : str\n                 Unit of the size (bytes, KB, MB, GB or TB).\n    \"\"\"\n    if os.path.isfile(file_path):\n        file_info = os.stat(file_path)\n        a, b = convert_bytes(file_info.st_size)\n        return a, b\n    return None, None\n
"},{"location":"odak/tools/#odak.tools.file.write_to_text_file","title":"write_to_text_file(content, filename, write_flag='w')","text":"

Definition to write a Pythonic list to a text file.

Parameters:

  • content \u2013
              Pythonic string list to be written to a file.\n
  • filename \u2013
              Destination filename (i.e. test.txt).\n
  • write_flag \u2013
              Defines the interaction with the file. \n          The default is \"w\" (overwrite any existing content).\n          For more see: https://docs.python.org/3/tutorial/inputoutput.html#reading-and-writing-files\n
Source code in odak/tools/file.py
def write_to_text_file(content, filename, write_flag = 'w'):\n    \"\"\"\n    Defininition to write a Pythonic list to a text file.\n\n\n    Parameters\n    ----------\n    content         : list\n                      Pythonic string list to be written to a file.\n    filename        : str\n                      Destination filename (i.e. test.txt).\n    write_flag      : str\n                      Defines the interaction with the file. \n                      The default is \"w\" (overwrite any existing content).\n                      For more see: https://docs.python.org/3/tutorial/inputoutput.html#reading-and-writing-files\n    \"\"\"\n    with open(expanduser(filename), write_flag) as f:\n        for line in content:\n            f.write('{}\\n'.format(line))\n    return True\n
"},{"location":"odak/tools/#odak.tools.latex.__init__","title":"__init__(filename)","text":"

Parameters:

  • filename \u2013
           Source filename (i.e. sample.tex).\n
Source code in odak/tools/latex.py
def __init__(\n             self,\n             filename\n            ):\n    \"\"\"\n    Parameters\n    ----------\n    filename     : str\n                   Source filename (i.e. sample.tex).\n    \"\"\"\n    self.filename = filename\n    self.content = read_text_file(self.filename)\n    self.content_type = []\n    self.latex_dictionary = [\n                             '\\\\documentclass',\n                             '\\\\if',\n                             '\\\\pdf',\n                             '\\\\else',\n                             '\\\\fi',\n                             '\\\\vgtc',\n                             '\\\\teaser',\n                             '\\\\abstract',\n                             '\\\\CCS',\n                             '\\\\usepackage',\n                             '\\\\PassOptionsToPackage',\n                             '\\\\definecolor',\n                             '\\\\AtBeginDocument',\n                             '\\\\providecommand',\n                             '\\\\setcopyright',\n                             '\\\\copyrightyear',\n                             '\\\\acmYear',\n                             '\\\\citestyle',\n                             '\\\\newcommand',\n                             '\\\\acmDOI',\n                             '\\\\newabbreviation',\n                             '\\\\global',\n                             '\\\\begin{document}',\n                             '\\\\author',\n                             '\\\\affiliation',\n                             '\\\\email',\n                             '\\\\institution',\n                             '\\\\streetaddress',\n                             '\\\\city',\n                             '\\\\country',\n                             '\\\\postcode',\n                             '\\\\ccsdesc',\n                             '\\\\received',\n                             '\\\\includegraphics',\n                             '\\\\caption',\n                             '\\\\centering',\n                             '\\\\label',\n                             '\\\\maketitle',\n                             '\\\\toprule',\n                             '\\\\multirow',\n                             '\\\\multicolumn',\n                             '\\\\cmidrule',\n                             '\\\\addlinespace',\n                             '\\\\midrule',\n                             '\\\\cellcolor',\n                             '\\\\bibliography',\n                             '}',\n                             '\\\\title',\n                             '</ccs2012>',\n                             '\\\\bottomrule',\n                             '<concept>',\n                             '<concept',\n                             '<ccs',\n                             '\\\\item',\n                             '</concept',\n                             '\\\\begin{abstract}',\n                             '\\\\end{abstract}',\n                             '\\\\endinput',\n                             '\\\\\\\\'\n                            ]\n    self.latex_begin_dictionary = [\n                                   '\\\\begin{figure}',\n                                   '\\\\begin{figure*}',\n                                   '\\\\begin{equation}',\n                                   '\\\\begin{CCSXML}',\n                                   '\\\\begin{teaserfigure}',\n                                   '\\\\begin{table*}',\n                                   '\\\\begin{table}',\n        
                           '\\\\begin{gather}',\n                                   '\\\\begin{align}',\n                                  ]\n    self.latex_end_dictionary = [\n                                 '\\\\end{figure}',\n                                 '\\\\end{figure*}',\n                                 '\\\\end{equation}',\n                                 '\\\\end{CCSXML}',\n                                 '\\\\end{teaserfigure}',\n                                 '\\\\end{table*}',\n                                 '\\\\end{table}',\n                                 '\\\\end{gather}',\n                                 '\\\\end{align}',\n                                ]\n    self._label_lines()\n
"},{"location":"odak/tools/#odak.tools.latex.get_line","title":"get_line(line_id=0)","text":"

Definition to get a specific line by inputting a line number.

Returns:

  • line ( str ) \u2013

    Requested line.

  • content_type ( str ) \u2013

    Line's content type (e.g., latex, comment, text).

Source code in odak/tools/latex.py
def get_line(self, line_id = 0):\n    \"\"\"\n    Definition to get a specific line by inputting a line nunber.\n\n\n    Returns\n    ----------\n    line           : str\n                     Requested line.\n    content_type   : str\n                     Line's content type (e.g., latex, comment, text).\n    \"\"\"\n    line = self.content[line_id]\n    content_type = self.content_type[line_id]\n    return line, content_type\n
"},{"location":"odak/tools/#odak.tools.latex.get_line_count","title":"get_line_count()","text":"

Definition to get the line count.

Returns:

  • line_count ( int ) \u2013

    Number of lines in the loaded latex document.

Source code in odak/tools/latex.py
def get_line_count(self):\n    \"\"\"\n    Definition to get the line count.\n\n\n    Returns\n    -------\n    line_count     : int\n                     Number of lines in the loaded latex document.\n    \"\"\"\n    self.line_count = len(self.content)\n    return self.line_count\n
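A minimal sketch of working with a LaTeX source, assuming the class is exposed as odak.tools.latex as the locations above suggest and that 'sample.tex' is an existing file:

from odak.tools import latex

document = latex('sample.tex')
number_of_lines = document.get_line_count()
line, content_type = document.get_line(0)     # first line and its label (e.g., latex, comment, text)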
"},{"location":"odak/tools/#odak.tools.latex.set_latex_dictonaries","title":"set_latex_dictonaries(begin_dictionary, end_dictionary, syntax_dictionary)","text":"

Set document-specific dictionaries so that the lines can be labelled accordingly.

Parameters:

  • begin_dictionary \u2013
                   Pythonic list containing latex syntax for begin commands (i.e. \\begin{align}).\n
  • end_dictionary \u2013
                   Pythonic list containing latex syntax for end commands (i.e. \\end{table}).\n
  • syntax_dictionary \u2013
                   Pythonic list containing latex syntax (i.e. \\item).\n
Source code in odak/tools/latex.py
def set_latex_dictonaries(self, begin_dictionary, end_dictionary, syntax_dictionary):\n    \"\"\"\n    Set document specific dictionaries so that the lines could be labelled in accordance.\n\n\n    Parameters\n    ----------\n    begin_dictionary     : list\n                           Pythonic list containing latex syntax for begin commands (i.e. \\\\begin{align}).\n    end_dictionary       : list\n                           Pythonic list containing latex syntax for end commands (i.e. \\\\end{table}).\n    syntax_dictionary    : list\n                           Pythonic list containing latex syntax (i.e. \\\\item).\n\n    \"\"\"\n    self.latex_begin_dictionary = begin_dictionary\n    self.latex_end_dictionary = end_dictionary\n    self.latex_dictionary = syntax_dictionary\n    self._label_lines\n
"},{"location":"odak/tools/#odak.tools.matrix.blur_gaussian","title":"blur_gaussian(field, kernel_length=[21, 21], nsigma=[3, 3])","text":"

A definition to blur a field using a Gaussian kernel.

Parameters:

  • field \u2013
            MxN field.\n
  • kernel_length (list, default: [21, 21] ) \u2013
            Length of the Gaussian kernel along X and Y axes.\n
  • nsigma \u2013
            Sigma of the Gaussian kernel along X and Y axes.\n

Returns:

  • blurred_field ( ndarray ) \u2013

    Blurred field.

Source code in odak/tools/matrix.py
def blur_gaussian(field, kernel_length=[21, 21], nsigma=[3, 3]):\n    \"\"\"\n    A definition to blur a field using a Gaussian kernel.\n\n    Parameters\n    ----------\n    field         : ndarray\n                    MxN field.\n    kernel_length : list\n                    Length of the Gaussian kernel along X and Y axes.\n    nsigma        : list\n                    Sigma of the Gaussian kernel along X and Y axes.\n\n    Returns\n    ----------\n    blurred_field : ndarray\n                    Blurred field.\n    \"\"\"\n    kernel = generate_2d_gaussian(kernel_length, nsigma)\n    kernel = zero_pad(kernel, field.shape)\n    blurred_field = convolve2d(field, kernel)\n    blurred_field = blurred_field/np.amax(blurred_field)\n    return blurred_field\n
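A minimal usage sketch (the field content and sizes below are illustrative; the import path follows the module path shown above):

import numpy as np
from odak.tools.matrix import blur_gaussian

field = np.zeros((512, 512))
field[256, 256] = 1.  # a single bright pixel
blurred = blur_gaussian(field, kernel_length = [21, 21], nsigma = [3, 3])
print(blurred.shape)  # (512, 512); the output is normalised to a peak of one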
"},{"location":"odak/tools/#odak.tools.matrix.convolve2d","title":"convolve2d(field, kernel)","text":"

Definition to convolve a field with a kernel by multiplying in frequency space.

Parameters:

  • field \u2013
          Input field with MxN shape.\n
  • kernel \u2013
          Input kernel with MxN shape.\n

Returns:

  • new_field ( ndarray ) \u2013

    Convolved field.

Source code in odak/tools/matrix.py
def convolve2d(field, kernel):\n    \"\"\"\n    Definition to convolve a field with a kernel by multiplying in frequency space.\n\n    Parameters\n    ----------\n    field       : ndarray\n                  Input field with MxN shape.\n    kernel      : ndarray\n                  Input kernel with MxN shape.\n\n    Returns\n    ----------\n    new_field   : ndarray\n                  Convolved field.\n    \"\"\"\n    fr = np.fft.fft2(field)\n    fr2 = np.fft.fft2(np.flipud(np.fliplr(kernel)))\n    m, n = fr.shape\n    new_field = np.real(np.fft.ifft2(fr*fr2))\n    new_field = np.roll(new_field, int(-m/2+1), axis=0)\n    new_field = np.roll(new_field, int(-n/2+1), axis=1)\n    return new_field\n
"},{"location":"odak/tools/#odak.tools.matrix.create_empty_list","title":"create_empty_list(dimensions=[1, 1])","text":"

A definition to create an empty Pythonic list.

Parameters:

  • dimensions \u2013
           Dimensions of the list to be created.\n

Returns:

  • new_list ( list ) \u2013

    New empty list.

Source code in odak/tools/matrix.py
def create_empty_list(dimensions = [1, 1]):\n    \"\"\"\n    A definition to create an empty Pythonic list.\n\n    Parameters\n    ----------\n    dimensions   : list\n                   Dimensions of the list to be created.\n\n    Returns\n    -------\n    new_list     : list\n                   New empty list.\n    \"\"\"\n    new_list = 0\n    for n in reversed(dimensions):\n        new_list = [new_list] * n\n    return new_list\n
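A minimal usage sketch (illustrative dimensions):

from odak.tools.matrix import create_empty_list

grid = create_empty_list(dimensions = [2, 3])
print(grid)  # [[0, 0, 0], [0, 0, 0]]
# Note that the rows are repeated references, so writing to one row also affects the others.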
"},{"location":"odak/tools/#odak.tools.matrix.crop_center","title":"crop_center(field, size=None)","text":"

Definition to crop the center of a field with 2Mx2N size. The outcome is an MxN array.

Parameters:

  • field \u2013
          Input field 2Mx2N array.\n

Returns:

  • cropped ( ndarray ) \u2013

    Cropped version of the input field.

Source code in odak/tools/matrix.py
def crop_center(field, size=None):\n    \"\"\"\n    Definition to crop the center of a field with 2Mx2N size. The outcome is a MxN array.\n\n    Parameters\n    ----------\n    field       : ndarray\n                  Input field 2Mx2N array.\n\n    Returns\n    ----------\n    cropped     : ndarray\n                  Cropped version of the input field.\n    \"\"\"\n    if type(size) == type(None):\n        qx = int(np.ceil(field.shape[0])/4)\n        qy = int(np.ceil(field.shape[1])/4)\n        cropped = np.copy(field[qx:3*qx, qy:3*qy])\n    else:\n        cx = int(np.ceil(field.shape[0]/2))\n        cy = int(np.ceil(field.shape[1]/2))\n        hx = int(np.ceil(size[0]/2))\n        hy = int(np.ceil(size[1]/2))\n        cropped = np.copy(field[cx-hx:cx+hx, cy-hy:cy+hy])\n    return cropped\n
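A minimal usage sketch showing both the default centre-quarter crop and an explicit size:

import numpy as np
from odak.tools.matrix import crop_center

field = np.arange(64).reshape(8, 8)
print(crop_center(field).shape)                 # (4, 4), central MxN part of a 2Mx2N field
print(crop_center(field, size = [2, 2]).shape)  # (2, 2), crop around the centre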
"},{"location":"odak/tools/#odak.tools.matrix.generate_2d_gaussian","title":"generate_2d_gaussian(kernel_length=[21, 21], nsigma=[3, 3])","text":"

Generate a 2D Gaussian kernel. Inspired by https://stackoverflow.com/questions/29731726/how-to-calculate-a-gaussian-kernel-matrix-efficiently-in-numpy

Parameters:

  • kernel_length (list, default: [21, 21] ) \u2013
            Length of the Gaussian kernel along X and Y axes.\n
  • nsigma \u2013
            Sigma of the Gaussian kernel along X and Y axes.\n

Returns:

  • kernel_2d ( ndarray ) \u2013

    Generated Gaussian kernel.

Source code in odak/tools/matrix.py
def generate_2d_gaussian(kernel_length=[21, 21], nsigma=[3, 3]):\n    \"\"\"\n    Generate 2D Gaussian kernel. Inspired from https://stackoverflow.com/questions/29731726/how-to-calculate-a-gaussian-kernel-matrix-efficiently-in-numpy\n\n    Parameters\n    ----------\n    kernel_length : list\n                    Length of the Gaussian kernel along X and Y axes.\n    nsigma        : list\n                    Sigma of the Gaussian kernel along X and Y axes.\n\n    Returns\n    ----------\n    kernel_2d     : ndarray\n                    Generated Gaussian kernel.\n    \"\"\"\n    x = np.linspace(-nsigma[0], nsigma[0], kernel_length[0]+1)\n    y = np.linspace(-nsigma[1], nsigma[1], kernel_length[1]+1)\n    xx, yy = np.meshgrid(x, y)\n    kernel_2d = np.exp(-0.5*(np.square(xx) /\n                       np.square(nsigma[0]) + np.square(yy)/np.square(nsigma[1])))\n    kernel_2d = kernel_2d/kernel_2d.sum()\n    return kernel_2d\n
"},{"location":"odak/tools/#odak.tools.matrix.generate_bandlimits","title":"generate_bandlimits(size=[512, 512], levels=9)","text":"

A definition to calculate octaves used in bandlimiting frequencies in the frequency domain.

Parameters:

  • size \u2013
         Size of each mask in octaves.\n

Returns:

  • masks ( ndarray ) \u2013

    Masks (Octaves).

Source code in odak/tools/matrix.py
def generate_bandlimits(size=[512, 512], levels=9):\n    \"\"\"\n    A definition to calculate octaves used in bandlimiting frequencies in the frequency domain.\n\n    Parameters\n    ----------\n    size       : list\n                 Size of each mask in octaves.\n\n    Returns\n    ----------\n    masks      : ndarray\n                 Masks (Octaves).\n    \"\"\"\n    masks = np.zeros((levels, size[0], size[1]))\n    cx = int(size[0]/2)\n    cy = int(size[1]/2)\n    for i in range(0, masks.shape[0]):\n        deltax = int((size[0])/(2**(i+1)))\n        deltay = int((size[1])/(2**(i+1)))\n        masks[\n            i,\n            cx-deltax:cx+deltax,\n            cy-deltay:cy+deltay\n        ] = 1.\n        masks[\n            i,\n            int(cx-deltax/2.):int(cx+deltax/2.),\n            int(cy-deltay/2.):int(cy+deltay/2.)\n        ] = 0.\n    masks = np.asarray(masks)\n    return masks\n
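A minimal usage sketch (illustrative size and level count):

from odak.tools.matrix import generate_bandlimits

masks = generate_bandlimits(size = [512, 512], levels = 9)
print(masks.shape)  # (9, 512, 512); each slice is a binary mask selecting one octave band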
"},{"location":"odak/tools/#odak.tools.matrix.nufft2","title":"nufft2(field, fx, fy, size=None, sign=1, eps=10 ** -12)","text":"

A definition to take a 2D Non-Uniform Fast Fourier Transform (NUFFT).

Parameters:

  • field \u2013
          Input field.\n
  • fx \u2013
          Frequencies along x axis.\n
  • fy \u2013
          Frequencies along y axis.\n
  • size \u2013
          Size.\n
  • sign \u2013
          Sign of the exponential used in NUFFT kernel.\n
  • eps \u2013
          Accuracy of NUFFT.\n

Returns:

  • result ( ndarray ) \u2013

    Inverse NUFFT of the input field.

Source code in odak/tools/matrix.py
def nufft2(field, fx, fy, size=None, sign=1, eps=10**(-12)):\n    \"\"\"\n    A definition to take 2D Non-Uniform Fast Fourier Transform (NUFFT).\n\n    Parameters\n    ----------\n    field       : ndarray\n                  Input field.\n    fx          : ndarray\n                  Frequencies along x axis.\n    fy          : ndarray\n                  Frequencies along y axis.\n    size        : list\n                  Size.\n    sign        : float\n                  Sign of the exponential used in NUFFT kernel.\n    eps         : float\n                  Accuracy of NUFFT.\n\n    Returns\n    ----------\n    result      : ndarray\n                  Inverse NUFFT of the input field.\n    \"\"\"\n    try:\n        import finufft\n    except:\n        print('odak.tools.nufft2 requires finufft to be installed: pip install finufft')\n    image = np.copy(field).astype(np.complex128)\n    result = finufft.nufft2d2(\n        fx.flatten(), fy.flatten(), image, eps=eps, isign=sign)\n    if type(size) == type(None):\n        result = result.reshape(field.shape)\n    else:\n        result = result.reshape(size)\n    return result\n
"},{"location":"odak/tools/#odak.tools.matrix.nuifft2","title":"nuifft2(field, fx, fy, size=None, sign=1, eps=10 ** -12)","text":"

A definition to take a 2D Adjoint Non-Uniform Fast Fourier Transform (NUFFT).

Parameters:

  • field \u2013
          Input field.\n
  • fx \u2013
          Frequencies along x axis.\n
  • fy \u2013
          Frequencies along y axis.\n
  • size \u2013
          Shape of the NUFFT calculated for an input field.\n
  • sign \u2013
          Sign of the exponential used in NUFFT kernel.\n
  • eps \u2013
          Accuracy of NUFFT.\n

Returns:

  • result ( ndarray ) \u2013

    NUFFT of the input field.

Source code in odak/tools/matrix.py
def nuifft2(field, fx, fy, size=None, sign=1, eps=10**(-12)):\n    \"\"\"\n    A definition to take 2D Adjoint Non-Uniform Fast Fourier Transform (NUFFT).\n\n    Parameters\n    ----------\n    field       : ndarray\n                  Input field.\n    fx          : ndarray\n                  Frequencies along x axis.\n    fy          : ndarray\n                  Frequencies along y axis.\n    size        : list or ndarray\n                  Shape of the NUFFT calculated for an input field.\n    sign        : float\n                  Sign of the exponential used in NUFFT kernel.\n    eps         : float\n                  Accuracy of NUFFT.\n\n    Returns\n    ----------\n    result      : ndarray\n                  NUFFT of the input field.\n    \"\"\"\n    try:\n        import finufft\n    except:\n        print('odak.tools.nuifft2 requires finufft to be installed: pip install finufft')\n    image = np.copy(field).astype(np.complex128)\n    if type(size) == type(None):\n        result = finufft.nufft2d1(\n            fx.flatten(),\n            fy.flatten(),\n            image.flatten(),\n            image.shape,\n            eps=eps,\n            isign=sign\n        )\n    else:\n        result = finufft.nufft2d1(\n            fx.flatten(),\n            fy.flatten(),\n            image.flatten(),\n            (size[0], size[1]),\n            eps=eps,\n            isign=sign\n        )\n    result = np.asarray(result)\n    return result\n
"},{"location":"odak/tools/#odak.tools.matrix.quantize","title":"quantize(image_field, bits=4)","text":"

Definition to quantize an image field (0-255, 8 bit) to a given bit depth.

Parameters:

  • image_field (ndarray) \u2013
          Input image field.\n
  • bits \u2013
          A value between 0 and 8; it cannot be zero.\n

Returns:

  • new_field ( ndarray ) \u2013

    Quantized image field.

Source code in odak/tools/matrix.py
def quantize(image_field, bits=4):\n    \"\"\"\n    Definitio to quantize a image field (0-255, 8 bit) to a certain bits level.\n\n    Parameters\n    ----------\n    image_field : ndarray\n                  Input image field.\n    bits        : int\n                  A value in between 0 to 8. Can not be zero.\n\n    Returns\n    ----------\n    new_field   : ndarray\n                  Quantized image field.\n    \"\"\"\n    divider = 2**(8-bits)\n    new_field = image_field/divider\n    new_field = new_field.astype(np.int64)\n    return new_field\n
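A minimal usage sketch (illustrative input):

import numpy as np
from odak.tools.matrix import quantize

image = np.random.randint(0, 256, (4, 4))
levels = quantize(image, bits = 2)
print(levels)  # integer values in the range 0..3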
"},{"location":"odak/tools/#odak.tools.matrix.zero_pad","title":"zero_pad(field, size=None, method='center')","text":"

Definition to zero pad an MxN array to a 2Mx2N array.

Parameters:

  • field \u2013
                Input field MxN array.\n
  • size \u2013
                Size to be zeropadded.\n
  • method \u2013
                Zeropad either by placing the content to center or to the left.\n

Returns:

  • field_zero_padded ( ndarray ) \u2013

    Zeropadded version of the input field.

Source code in odak/tools/matrix.py
def zero_pad(field, size=None, method='center'):\n    \"\"\"\n    Definition to zero pad a MxN array to 2Mx2N array.\n\n    Parameters\n    ----------\n    field             : ndarray\n                        Input field MxN array.\n    size              : list\n                        Size to be zeropadded.\n    method            : str\n                        Zeropad either by placing the content to center or to the left.\n\n    Returns\n    ----------\n    field_zero_padded : ndarray\n                        Zeropadded version of the input field.\n    \"\"\"\n    if type(size) == type(None):\n        hx = int(np.ceil(field.shape[0])/2)\n        hy = int(np.ceil(field.shape[1])/2)\n    else:\n        hx = int(np.ceil((size[0]-field.shape[0])/2))\n        hy = int(np.ceil((size[1]-field.shape[1])/2))\n    if method == 'center':\n        field_zero_padded = np.pad(\n            field, ([hx, hx], [hy, hy]), constant_values=(0, 0))\n    elif method == 'left aligned':\n        field_zero_padded = np.pad(\n            field, ([0, 2*hx], [0, 2*hy]), constant_values=(0, 0))\n    if type(size) != type(None):\n        field_zero_padded = field_zero_padded[0:size[0], 0:size[1]]\n    return field_zero_padded\n
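A minimal usage sketch covering the default doubling behaviour and an explicit target size:

import numpy as np
from odak.tools.matrix import zero_pad

field = np.ones((4, 4))
print(zero_pad(field).shape)                   # (8, 8), field centred in zeros
print(zero_pad(field, size = [10, 10]).shape)  # (10, 10)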
"},{"location":"odak/tools/#odak.tools.markdown.__init__","title":"__init__(filename)","text":"

Parameters:

  • filename \u2013
           Source filename (i.e. sample.md).\n
Source code in odak/tools/markdown.py
def __init__(\n             self,\n             filename\n            ):\n    \"\"\"\n    Parameters\n    ----------\n    filename     : str\n                   Source filename (i.e. sample.md).\n    \"\"\"\n    self.filename = filename\n    self.content = read_text_file(self.filename)\n    self.content_type = []\n    self.markdown_dictionary = [\n                                 '#',\n                               ]\n    self.markdown_begin_dictionary = [\n                                      '```bash',\n                                      '```python',\n                                      '```',\n                                     ]\n    self.markdown_end_dictionary = [\n                                    '```',\n                                   ]\n    self._label_lines()\n
"},{"location":"odak/tools/#odak.tools.markdown.get_line","title":"get_line(line_id=0)","text":"

Definition to get a specific line by inputting a line number.

Returns:

  • line ( str ) \u2013

    Requested line.

  • content_type ( str ) \u2013

    Line's content type (e.g., markdown, comment, text).

Source code in odak/tools/markdown.py
def get_line(self, line_id = 0):\n    \"\"\"\n    Definition to get a specific line by inputting a line nunber.\n\n\n    Returns\n    ----------\n    line           : str\n                     Requested line.\n    content_type   : str\n                     Line's content type (e.g., markdown, comment, text).\n    \"\"\"\n    line = self.content[line_id]\n    content_type = self.content_type[line_id]\n    return line, content_type\n
"},{"location":"odak/tools/#odak.tools.markdown.get_line_count","title":"get_line_count()","text":"

Definition to get the line count.

Returns:

  • line_count ( int ) \u2013

    Number of lines in the loaded markdown document.

Source code in odak/tools/markdown.py
def get_line_count(self):\n    \"\"\"\n    Definition to get the line count.\n\n\n    Returns\n    -------\n    line_count     : int\n                     Number of lines in the loaded markdown document.\n    \"\"\"\n    self.line_count = len(self.content)\n    return self.line_count\n
"},{"location":"odak/tools/#odak.tools.markdown.set_dictonaries","title":"set_dictonaries(begin_dictionary, end_dictionary, syntax_dictionary)","text":"

Set document-specific dictionaries so that the lines can be labelled accordingly.

Parameters:

  • begin_dictionary \u2013
                   Pythonic list containing markdown syntax for beginning of blocks (e.g., code, html).\n
  • end_dictionary \u2013
                   Pythonic list containing markdown syntax for end of blocks (e.g., code, html).\n
  • syntax_dictionary \u2013
                   Pythonic list containing markdown syntax (i.e. \\item).\n
Source code in odak/tools/markdown.py
def set_dictonaries(self, begin_dictionary, end_dictionary, syntax_dictionary):\n    \"\"\"\n    Set document specific dictionaries so that the lines could be labelled in accordance.\n\n\n    Parameters\n    ----------\n    begin_dictionary     : list\n                           Pythonic list containing markdown syntax for beginning of blocks (e.g., code, html).\n    end_dictionary       : list\n                           Pythonic list containing markdown syntax for end of blocks (e.g., code, html).\n    syntax_dictionary    : list\n                           Pythonic list containing markdown syntax (i.e. \\\\item).\n\n    \"\"\"\n    self.markdown_begin_dictionary = begin_dictionary\n    self.markdown_end_dictionary = end_dictionary\n    self.markdown_dictionary = syntax_dictionary\n    self._label_lines\n
"},{"location":"odak/tools/#odak.tools.sample.batch_of_rays","title":"batch_of_rays(entry, exit)","text":"

Definition to generate a batch of rays with given entry point(s) and exit point(s). Note that the mapping is one to one, meaning the nth item in your entry points list will exit through the nth item in your exit list to generate that particular ray. You can combine nx3 points for entry or exit with a single point for the other, but if you provide multiple points for both entry and exit, the number of points has to be the same for both.

Parameters:

  • entry \u2013
         Either a single point with size of 3 or multiple points with the size of nx3.\n
  • exit \u2013
         Either a single point with size of 3 or multiple points with the size of nx3.\n

Returns:

  • rays ( ndarray ) \u2013

    Generated batch of rays.

Source code in odak/tools/sample.py
def batch_of_rays(entry, exit):\n    \"\"\"\n    Definition to generate a batch of rays with given entry point(s) and exit point(s). Note that the mapping is one to one, meaning nth item in your entry points list will exit from nth item in your exit list and generate that particular ray. Note that you can have a combination like nx3 points for entry or exit and 1 point for entry or exit. But if you have multiple points both for entry and exit, the number of points have to be same both for entry and exit.\n\n    Parameters\n    ----------\n    entry      : ndarray\n                 Either a single point with size of 3 or multiple points with the size of nx3.\n    exit       : ndarray\n                 Either a single point with size of 3 or multiple points with the size of nx3.\n\n    Returns\n    ----------\n    rays       : ndarray\n                 Generated batch of rays.\n    \"\"\"\n    norays = np.array([0, 0])\n    if len(entry.shape) == 1:\n        entry = entry.reshape((1, 3))\n    if len(exit.shape) == 1:\n        exit = exit.reshape((1, 3))\n    norays = np.amax(np.asarray([entry.shape[0], exit.shape[0]]))\n    if norays > exit.shape[0]:\n        exit = np.repeat(exit, norays, axis=0)\n    elif norays > entry.shape[0]:\n        entry = np.repeat(entry, norays, axis=0)\n    rays = []\n    norays = int(norays)\n    for i in range(norays):\n        rays.append(\n            create_ray_from_two_points(\n                entry[i],\n                exit[i]\n            )\n        )\n    rays = np.asarray(rays)\n    return rays\n
"},{"location":"odak/tools/#odak.tools.sample.box_volume_sample","title":"box_volume_sample(no=[10, 10, 10], size=[100.0, 100.0, 100.0], center=[0.0, 0.0, 0.0], angles=[0.0, 0.0, 0.0])","text":"

Definition to generate samples in a box volume.

Parameters:

  • no \u2013
          Number of samples.\n
  • size \u2013
          Physical size of the volume.\n
  • center \u2013
          Center location of the volume.\n
  • angles \u2013
          Tilt of the volume.\n

Returns:

  • samples ( ndarray ) \u2013

    Samples generated.

Source code in odak/tools/sample.py
def box_volume_sample(no=[10, 10, 10], size=[100., 100., 100.], center=[0., 0., 0.], angles=[0., 0., 0.]):\n    \"\"\"\n    Definition to generate samples in a box volume.\n\n    Parameters\n    ----------\n    no          : list\n                  Number of samples.\n    size        : list\n                  Physical size of the volume.\n    center      : list\n                  Center location of the volume.\n    angles      : list\n                  Tilt of the volume.\n\n    Returns\n    ----------\n    samples     : ndarray\n                  Samples generated.\n    \"\"\"\n    samples = np.zeros((no[0], no[1], no[2], 3))\n    x, y, z = np.mgrid[0:no[0], 0:no[1], 0:no[2]]\n    step = [\n        size[0]/no[0],\n        size[1]/no[1],\n        size[2]/no[2]\n    ]\n    samples[:, :, :, 0] = x*step[0]+step[0]/2.-size[0]/2.\n    samples[:, :, :, 1] = y*step[1]+step[1]/2.-size[1]/2.\n    samples[:, :, :, 2] = z*step[2]+step[2]/2.-size[2]/2.\n    samples = samples.reshape(\n        (samples.shape[0]*samples.shape[1]*samples.shape[2], samples.shape[3]))\n    samples = rotate_points(samples, angles=angles, offset=center)\n    return samples\n
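A minimal usage sketch (illustrative volume parameters):

from odak.tools.sample import box_volume_sample

samples = box_volume_sample(
                            no = [5, 5, 5],
                            size = [10., 10., 10.],
                            center = [0., 0., 5.]
                           )
print(samples.shape)  # (125, 3), one XYZ sample per grid cell of the box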
"},{"location":"odak/tools/#odak.tools.sample.circular_sample","title":"circular_sample(no=[10, 10], radius=10.0, center=[0.0, 0.0, 0.0], angles=[0.0, 0.0, 0.0])","text":"

Definition to generate samples inside a circle over a surface.

Parameters:

  • no \u2013
          Number of samples.\n
  • radius \u2013
          Radius of the circle.\n
  • center \u2013
          Center location of the surface.\n
  • angles \u2013
          Tilt of the surface.\n

Returns:

  • samples ( ndarray ) \u2013

    Samples generated.

Source code in odak/tools/sample.py
def circular_sample(no=[10, 10], radius=10., center=[0., 0., 0.], angles=[0., 0., 0.]):\n    \"\"\"\n    Definition to generate samples inside a circle over a surface.\n\n    Parameters\n    ----------\n    no          : list\n                  Number of samples.\n    radius      : float\n                  Radius of the circle.\n    center      : list\n                  Center location of the surface.\n    angles      : list\n                  Tilt of the surface.\n\n    Returns\n    ----------\n    samples     : ndarray\n                  Samples generated.\n    \"\"\"\n    samples = np.zeros((no[0]+1, no[1]+1, 3))\n    r_angles, r = np.mgrid[0:no[0]+1, 0:no[1]+1]\n    r = r/np.amax(r)*radius\n    r_angles = r_angles/np.amax(r_angles)*np.pi*2\n    samples[:, :, 0] = r*np.cos(r_angles)\n    samples[:, :, 1] = r*np.sin(r_angles)\n    samples = samples[1:no[0]+1, 1:no[1]+1, :]\n    samples = samples.reshape(\n        (samples.shape[0]*samples.shape[1], samples.shape[2]))\n    samples = rotate_points(samples, angles=angles, offset=center)\n    return samples\n
"},{"location":"odak/tools/#odak.tools.sample.circular_uniform_random_sample","title":"circular_uniform_random_sample(no=[10, 50], radius=10.0, center=[0.0, 0.0, 0.0], angles=[0.0, 0.0, 0.0])","text":"

Definition to generate samples inside a circle uniformly but randomly.

Parameters:

  • no \u2013
          Number of samples.\n
  • radius \u2013
          Radius of the circle.\n
  • center \u2013
          Center location of the surface.\n
  • angles \u2013
          Tilt of the surface.\n

Returns:

  • samples ( ndarray ) \u2013

    Samples generated.

Source code in odak/tools/sample.py
def circular_uniform_random_sample(no=[10, 50], radius=10., center=[0., 0., 0.], angles=[0., 0., 0.]):\n    \"\"\" \n    Definition to generate sample inside a circle uniformly but randomly.\n\n    Parameters\n    ----------\n    no          : list\n                  Number of samples.\n    radius      : float\n                  Radius of the circle.\n    center      : list\n                  Center location of the surface.\n    angles      : list\n                  Tilt of the surface.\n\n    Returns\n    ----------\n    samples     : ndarray\n                  Samples generated.\n    \"\"\"\n    samples = np.empty((0, 3))\n    rs = np.sqrt(np.random.uniform(0, 1, no[0]))\n    angs = np.random.uniform(0, 2*np.pi, no[1])\n    for i in rs:\n        for angle in angs:\n            r = radius*i\n            point = np.array(\n                [float(r*np.cos(angle)), float(r*np.sin(angle)), 0])\n            samples = np.vstack((samples, point))\n    samples = rotate_points(samples, angles=angles, offset=center)\n    return samples\n
"},{"location":"odak/tools/#odak.tools.sample.circular_uniform_sample","title":"circular_uniform_sample(no=[10, 50], radius=10.0, center=[0.0, 0.0, 0.0], angles=[0.0, 0.0, 0.0])","text":"

Definition to generate samples inside a circle uniformly.

Parameters:

  • no \u2013
          Number of samples.\n
  • radius \u2013
          Radius of the circle.\n
  • center \u2013
          Center location of the surface.\n
  • angles \u2013
          Tilt of the surface.\n

Returns:

  • samples ( ndarray ) \u2013

    Samples generated.

Source code in odak/tools/sample.py
def circular_uniform_sample(no=[10, 50], radius=10., center=[0., 0., 0.], angles=[0., 0., 0.]):\n    \"\"\"\n    Definition to generate sample inside a circle uniformly.\n\n    Parameters\n    ----------\n    no          : list\n                  Number of samples.\n    radius      : float\n                  Radius of the circle.\n    center      : list\n                  Center location of the surface.\n    angles      : list\n                  Tilt of the surface.\n\n    Returns\n    ----------\n    samples     : ndarray\n                  Samples generated.\n    \"\"\"\n    samples = np.empty((0, 3))\n    for i in range(0, no[0]):\n        r = i/no[0]*radius\n        ang_no = no[1]*i/no[0]\n        for j in range(0, int(no[1]*i/no[0])):\n            angle = j/ang_no*2*np.pi\n            point = np.array(\n                [float(r*np.cos(angle)), float(r*np.sin(angle)), 0])\n            samples = np.vstack((samples, point))\n    samples = rotate_points(samples, angles=angles, offset=center)\n    return samples\n
"},{"location":"odak/tools/#odak.tools.sample.grid_sample","title":"grid_sample(no=[10, 10], size=[100.0, 100.0], center=[0.0, 0.0, 0.0], angles=[0.0, 0.0, 0.0])","text":"

Definition to generate samples over a surface.

Parameters:

  • no \u2013
          Number of samples.\n
  • size \u2013
          Physical size of the surface.\n
  • center \u2013
          Center location of the surface.\n
  • angles \u2013
          Tilt of the surface.\n

Returns:

  • samples ( ndarray ) \u2013

    Samples generated.

Source code in odak/tools/sample.py
def grid_sample(no=[10, 10], size=[100., 100.], center=[0., 0., 0.], angles=[0., 0., 0.]):\n    \"\"\"\n    Definition to generate samples over a surface.\n\n    Parameters\n    ----------\n    no          : list\n                  Number of samples.\n    size        : list\n                  Physical size of the surface.\n    center      : list\n                  Center location of the surface.\n    angles      : list\n                  Tilt of the surface.\n\n    Returns\n    ----------\n    samples     : ndarray\n                  Samples generated.\n    \"\"\"\n    samples = np.zeros((no[0], no[1], 3))\n    step = [\n        size[0]/(no[0]-1),\n        size[1]/(no[1]-1)\n    ]\n    x, y = np.mgrid[0:no[0], 0:no[1]]\n    samples[:, :, 0] = x*step[0]-size[0]/2.\n    samples[:, :, 1] = y*step[1]-size[1]/2.\n    samples = samples.reshape(\n        (samples.shape[0]*samples.shape[1], samples.shape[2]))\n    samples = rotate_points(samples, angles=angles, offset=center)\n    return samples\n
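A minimal usage sketch (illustrative surface parameters):

from odak.tools.sample import grid_sample

samples = grid_sample(
                      no = [5, 5],
                      size = [20., 20.],
                      center = [0., 0., 10.],
                      angles = [0., 0., 0.]
                     )
print(samples.shape)  # (25, 3), points on a plane centred at z = 10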
"},{"location":"odak/tools/#odak.tools.sample.random_sample_point_cloud","title":"random_sample_point_cloud(point_cloud, no, p=None)","text":"

Definition to pull a subset of points from a point cloud with a given probability.

Parameters:

  • point_cloud \u2013
           Point cloud array.\n
  • no \u2013
           Number of samples.\n
  • p \u2013
           Probability list of the same size as no.\n

Returns:

  • subset ( ndarray ) \u2013

    Subset of the given point cloud.

Source code in odak/tools/sample.py
def random_sample_point_cloud(point_cloud, no, p=None):\n    \"\"\"\n    Definition to pull a subset of points from a point cloud with a given probability.\n\n    Parameters\n    ----------\n    point_cloud  : ndarray\n                   Point cloud array.\n    no           : list\n                   Number of samples.\n    p            : list\n                   Probability list in the same size as no.\n\n    Returns\n    ----------\n    subset       : ndarray\n                   Subset of the given point cloud.\n    \"\"\"\n    choice = np.random.choice(point_cloud.shape[0], no, p)\n    subset = point_cloud[choice, :]\n    return subset\n
"},{"location":"odak/tools/#odak.tools.sample.sphere_sample","title":"sphere_sample(no=[10, 10], radius=1.0, center=[0.0, 0.0, 0.0], k=[1, 2])","text":"

Definition to generate a regular sample set on the surface of a sphere using polar coordinates.

Parameters:

  • no \u2013
          Number of samples.\n
  • radius \u2013
          Radius of a sphere.\n
  • center \u2013
          Center of a sphere.\n
  • k \u2013
          Multipliers for gathering samples. If you set k=[1,2] it will draw samples from a perfect sphere.\n

Returns:

  • samples ( ndarray ) \u2013

    Samples generated.

Source code in odak/tools/sample.py
def sphere_sample(no=[10, 10], radius=1., center=[0., 0., 0.], k=[1, 2]):\n    \"\"\"\n    Definition to generate a regular sample set on the surface of a sphere using polar coordinates.\n\n    Parameters\n    ----------\n    no          : list\n                  Number of samples.\n    radius      : float\n                  Radius of a sphere.\n    center      : list\n                  Center of a sphere.\n    k           : list\n                  Multipliers for gathering samples. If you set k=[1,2] it will draw samples from a perfect sphere.\n\n    Returns\n    ----------\n    samples     : ndarray\n                  Samples generated.\n    \"\"\"\n    samples = np.zeros((no[0], no[1], 3))\n    psi, teta = np.mgrid[0:no[0], 0:no[1]]\n    psi = k[0]*np.pi/no[0]*psi\n    teta = k[1]*np.pi/no[1]*teta\n    samples[:, :, 0] = center[0]+radius*np.sin(psi)*np.cos(teta)\n    samples[:, :, 1] = center[0]+radius*np.sin(psi)*np.sin(teta)\n    samples[:, :, 2] = center[0]+radius*np.cos(psi)\n    samples = samples.reshape((no[0]*no[1], 3))\n    return samples\n
"},{"location":"odak/tools/#odak.tools.sample.sphere_sample_uniform","title":"sphere_sample_uniform(no=[10, 10], radius=1.0, center=[0.0, 0.0, 0.0], k=[1, 2])","text":"

Definition to generate a uniform sample set on the surface of a sphere using polar coordinates.

Parameters:

  • no \u2013
          Number of samples.\n
  • radius \u2013
          Radius of a sphere.\n
  • center \u2013
          Center of a sphere.\n
  • k \u2013
          Multipliers for gathering samples. If you set k=[1,2] it will draw samples from a perfect sphere.\n

Returns:

  • samples ( ndarray ) \u2013

    Samples generated.

Source code in odak/tools/sample.py
def sphere_sample_uniform(no=[10, 10], radius=1., center=[0., 0., 0.], k=[1, 2]):\n    \"\"\"\n    Definition to generate an uniform sample set on the surface of a sphere using polar coordinates.\n\n    Parameters\n    ----------\n    no          : list\n                  Number of samples.\n    radius      : float\n                  Radius of a sphere.\n    center      : list\n                  Center of a sphere.\n    k           : list\n                  Multipliers for gathering samples. If you set k=[1,2] it will draw samples from a perfect sphere.\n\n\n    Returns\n    ----------\n    samples     : ndarray\n                  Samples generated.\n\n    \"\"\"\n    samples = np.zeros((no[0], no[1], 3))\n    row = np.arange(0, no[0])\n    psi, teta = np.mgrid[0:no[0], 0:no[1]]\n    for psi_id in range(0, no[0]):\n        psi[psi_id] = np.roll(row, psi_id, axis=0)\n        teta[psi_id] = np.roll(row, -psi_id, axis=0)\n    psi = k[0]*np.pi/no[0]*psi\n    teta = k[1]*np.pi/no[1]*teta\n    samples[:, :, 0] = center[0]+radius*np.sin(psi)*np.cos(teta)\n    samples[:, :, 1] = center[1]+radius*np.sin(psi)*np.sin(teta)\n    samples[:, :, 2] = center[2]+radius*np.cos(psi)\n    samples = samples.reshape((no[0]*no[1], 3))\n    return samples\n
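A minimal usage sketch (illustrative sample counts; no[0] and no[1] are kept equal here):

from odak.tools.sample import sphere_sample_uniform

samples = sphere_sample_uniform(no = [10, 10], radius = 1., center = [0., 0., 0.])
print(samples.shape)  # (100, 3), points distributed over a unit sphere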
"},{"location":"odak/tools/#odak.tools.vector.closest_point_to_a_ray","title":"closest_point_to_a_ray(point, ray)","text":"

Definition to calculate the point on a ray that is closest to a given point.

Parameters:

  • point \u2013
            Given point in X,Y,Z.\n
  • ray \u2013
            Given ray.\n

Returns:

  • closest_point ( ndarray ) \u2013

    Calculated closest point.

Source code in odak/tools/vector.py
def closest_point_to_a_ray(point, ray):\n    \"\"\"\n    Definition to calculate the point on a ray that is closest to given point.\n\n    Parameters\n    ----------\n    point         : list\n                    Given point in X,Y,Z.\n    ray           : ndarray\n                    Given ray.\n\n    Returns\n    ---------\n    closest_point : ndarray\n                    Calculated closest point.\n    \"\"\"\n    from odak.raytracing import propagate_a_ray\n    if len(ray.shape) == 2:\n        ray = ray.reshape((1, 2, 3))\n    p0 = ray[:, 0]\n    p1 = propagate_a_ray(ray, 1.)\n    if len(p1.shape) == 2:\n        p1 = p1.reshape((1, 2, 3))\n    p1 = p1[:, 0]\n    p1 = p1.reshape(3)\n    p0 = p0.reshape(3)\n    point = point.reshape(3)\n    closest_distance = -np.dot((p0-point), (p1-p0))/np.sum((p1-p0)**2)\n    closest_point = propagate_a_ray(ray, closest_distance)[0]\n    return closest_point\n
"},{"location":"odak/tools/#odak.tools.vector.cross_product","title":"cross_product(vector1, vector2)","text":"

Definition to take the cross product of two vectors and return the resultant vector. Uses the method described at: http://en.wikipedia.org/wiki/Cross_product

Parameters:

  • vector1 \u2013
           A vector/ray.\n
  • vector2 \u2013
           A vector/ray.\n

Returns:

  • ray ( ndarray ) \u2013

    Array that contains starting points and cosines of a created ray.

Source code in odak/tools/vector.py
def cross_product(vector1, vector2):\n    \"\"\"\n    Definition to cross product two vectors and return the resultant vector. Used method described under: http://en.wikipedia.org/wiki/Cross_product\n\n    Parameters\n    ----------\n    vector1      : ndarray\n                   A vector/ray.\n    vector2      : ndarray\n                   A vector/ray.\n\n    Returns\n    ----------\n    ray          : ndarray\n                   Array that contains starting points and cosines of a created ray.\n    \"\"\"\n    angle = np.cross(vector1[1].T, vector2[1].T)\n    angle = np.asarray(angle)\n    ray = np.array([vector1[0], angle], dtype=np.float32)\n    return ray\n
"},{"location":"odak/tools/#odak.tools.vector.distance_between_point_clouds","title":"distance_between_point_clouds(points0, points1)","text":"

A definition to find the distance from every point in one point cloud to every point in the other point cloud.

Parameters:

  • points0 \u2013
          Mx3 points.\n
  • points1 \u2013
          Nx3 points.\n

Returns:

  • distances ( ndarray ) \u2013

    MxN distances.

Source code in odak/tools/vector.py
def distance_between_point_clouds(points0, points1):\n    \"\"\"\n    A definition to find distance between every point in one cloud to other points in the other point cloud.\n    Parameters\n    ----------\n    points0     : ndarray\n                  Mx3 points.\n    points1     : ndarray\n                  Nx3 points.\n\n    Returns\n    ----------\n    distances   : ndarray\n                  MxN distances.\n    \"\"\"\n    c = points1.reshape((1, points1.shape[0], points1.shape[1]))\n    a = np.repeat(c, points0.shape[0], axis=0)\n    b = points0.reshape((points0.shape[0], 1, points0.shape[1]))\n    b = np.repeat(b, a.shape[1], axis=1)\n    distances = np.sqrt(np.sum((a-b)**2, axis=2))\n    return distances\n
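A minimal usage sketch (random clouds for illustration):

import numpy as np
from odak.tools.vector import distance_between_point_clouds

points0 = np.random.rand(5, 3)
points1 = np.random.rand(7, 3)
distances = distance_between_point_clouds(points0, points1)
print(distances.shape)  # (5, 7) pairwise Euclidean distances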
"},{"location":"odak/tools/#odak.tools.vector.distance_between_two_points","title":"distance_between_two_points(point1, point2)","text":"

Definition to calculate distance between two given points.

Parameters:

  • point1 \u2013
          First point in X,Y,Z.\n
  • point2 \u2013
          Second point in X,Y,Z.\n

Returns:

  • distance ( float ) \u2013

    Distance in between given two points.

Source code in odak/tools/vector.py
def distance_between_two_points(point1, point2):\n    \"\"\"\n    Definition to calculate distance between two given points.\n\n    Parameters\n    ----------\n    point1      : list\n                  First point in X,Y,Z.\n    point2      : list\n                  Second point in X,Y,Z.\n\n    Returns\n    ----------\n    distance    : float\n                  Distance in between given two points.\n    \"\"\"\n    point1 = np.asarray(point1)\n    point2 = np.asarray(point2)\n    if len(point1.shape) == 1 and len(point2.shape) == 1:\n        distance = np.sqrt(np.sum((point1-point2)**2))\n    elif len(point1.shape) == 2 or len(point2.shape) == 2:\n        distance = np.sqrt(np.sum((point1-point2)**2, axis=1))\n    return distance\n
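A minimal usage sketch:

from odak.tools.vector import distance_between_two_points

print(distance_between_two_points([0., 0., 0.], [3., 4., 0.]))  # 5.0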
"},{"location":"odak/tools/#odak.tools.vector.point_to_ray_distance","title":"point_to_ray_distance(point, ray_point_0, ray_point_1)","text":"

Definition to find a point's closest distance to a line represented by two points.

Parameters:

  • point \u2013
          Point to be tested.\n
  • ray_point_0 (ndarray) \u2013
          First point to represent a line.\n
  • ray_point_1 (ndarray) \u2013
          Second point to represent a line.\n

Returns:

  • distance ( float ) \u2013

    Calculated distance.

Source code in odak/tools/vector.py
def point_to_ray_distance(point, ray_point_0, ray_point_1):\n    \"\"\"\n    Definition to find point's closest distance to a line represented with two points.\n\n    Parameters\n    ----------\n    point       : ndarray\n                  Point to be tested.\n    ray_point_0 : ndarray\n                  First point to represent a line.\n    ray_point_1 : ndarray\n                  Second point to represent a line.\n\n    Returns\n    ----------\n    distance    : float\n                  Calculated distance.\n    \"\"\"\n    distance = np.sum(np.cross((point-ray_point_0), (point-ray_point_1))\n                      ** 2)/np.sum((ray_point_1-ray_point_0)**2)\n    return distance\n
"},{"location":"odak/tools/#odak.tools.vector.same_side","title":"same_side(p1, p2, a, b)","text":"

Definition to figure out which side of a line a point is on with respect to a reference point. See http://www.blackpawn.com/texts/pointinpoly/ for more. If p1 and p2 are on the same side, this definition returns True.

Parameters:

  • p1 \u2013
          Point(s) to check.\n
  • p2 \u2013
          This is the point to check against.\n
  • a \u2013
          First point that forms the line.\n
  • b \u2013
          Second point that forms the line.\n
Source code in odak/tools/vector.py
def same_side(p1, p2, a, b):\n    \"\"\"\n    Definition to figure which side a point is on with respect to a line and a point. See http://www.blackpawn.com/texts/pointinpoly/ for more. If p1 and p2 are on the sameside, this definition returns True.\n\n    Parameters\n    ----------\n    p1          : list\n                  Point(s) to check.\n    p2          : list\n                  This is the point check against.\n    a           : list\n                  First point that forms the line.\n    b           : list\n                  Second point that forms the line.\n    \"\"\"\n    ba = np.subtract(b, a)\n    p1a = np.subtract(p1, a)\n    p2a = np.subtract(p2, a)\n    cp1 = np.cross(ba, p1a)\n    cp2 = np.cross(ba, p2a)\n    test = np.dot(cp1, cp2)\n    if len(p1.shape) > 1:\n        return test >= 0\n    if test >= 0:\n        return True\n    return False\n
"},{"location":"odak/tools/#odak.tools.transformation.rotate_point","title":"rotate_point(point, angles=[0, 0, 0], mode='XYZ', origin=[0, 0, 0], offset=[0, 0, 0])","text":"

Definition to rotate a given point. Note that rotation is always with respect to 0,0,0.

Parameters:

  • point \u2013
           A point.\n
  • angles \u2013
           Rotation angles in degrees.\n
  • mode \u2013
           Rotation mode determines the ordering of the rotations for each axis. Supported modes are XYZ, XZY, YXZ, ZXY and ZYX.\n
  • origin \u2013
           Reference point for a rotation.\n
  • offset \u2013
           Shift with the given offset.\n

Returns:

  • result ( ndarray ) \u2013

    Result of the rotation

  • rotx ( ndarray ) \u2013

    Rotation matrix along X axis.

  • roty ( ndarray ) \u2013

    Rotation matrix along Y axis.

  • rotz ( ndarray ) \u2013

    Rotation matrix along Z axis.

Source code in odak/tools/transformation.py
def rotate_point(point, angles = [0, 0, 0], mode = 'XYZ', origin = [0, 0, 0], offset = [0, 0, 0]):\n    \"\"\"\n    Definition to rotate a given point. Note that rotation is always with respect to 0,0,0.\n\n    Parameters\n    ----------\n    point        : ndarray\n                   A point.\n    angles       : list\n                   Rotation angles in degrees. \n    mode         : str\n                   Rotation mode determines ordering of the rotations at each axis. There are XYZ,YXZ,ZXY and ZYX modes.\n    origin       : list\n                   Reference point for a rotation.\n    offset       : list\n                   Shift with the given offset.\n\n    Returns\n    ----------\n    result       : ndarray\n                   Result of the rotation\n    rotx         : ndarray\n                   Rotation matrix along X axis.\n    roty         : ndarray\n                   Rotation matrix along Y axis.\n    rotz         : ndarray\n                   Rotation matrix along Z axis.\n    \"\"\"\n    point = np.asarray(point)\n    point -= np.asarray(origin)\n    rotx = rotmatx(angles[0])\n    roty = rotmaty(angles[1])\n    rotz = rotmatz(angles[2])\n    if mode == 'XYZ':\n        result = np.dot(rotz, np.dot(roty, np.dot(rotx, point)))\n    elif mode == 'XZY':\n        result = np.dot(roty, np.dot(rotz, np.dot(rotx, point)))\n    elif mode == 'YXZ':\n        result = np.dot(rotz, np.dot(rotx, np.dot(roty, point)))\n    elif mode == 'ZXY':\n        result = np.dot(roty, np.dot(rotx, np.dot(rotz, point)))\n    elif mode == 'ZYX':\n        result = np.dot(rotx, np.dot(roty, np.dot(rotz, point)))\n    result += np.asarray(origin)\n    result += np.asarray(offset)\n    return result, rotx, roty, rotz\n
"},{"location":"odak/tools/#odak.tools.transformation.rotate_points","title":"rotate_points(points, angles=[0, 0, 0], mode='XYZ', origin=[0, 0, 0], offset=[0, 0, 0])","text":"

Definition to rotate points.

Parameters:

  • points \u2013
           Points.\n
  • angles \u2013
           Rotation angles in degrees.\n
  • mode \u2013
           Rotation mode determines the ordering of the rotations for each axis. Supported modes are XYZ, XZY, YXZ, ZXY and ZYX.\n
  • origin \u2013
           Reference point for a rotation.\n
  • offset \u2013
           Shift with the given offset.\n

Returns:

  • result ( ndarray ) \u2013

    Result of the rotation

Source code in odak/tools/transformation.py
def rotate_points(points, angles = [0, 0, 0], mode = 'XYZ', origin = [0, 0, 0], offset = [0, 0, 0]):\n    \"\"\"\n    Definition to rotate points.\n\n    Parameters\n    ----------\n    points       : ndarray\n                   Points.\n    angles       : list\n                   Rotation angles in degrees. \n    mode         : str\n                   Rotation mode determines ordering of the rotations at each axis. There are XYZ,YXZ,ZXY and ZYX modes.\n    origin       : list\n                   Reference point for a rotation.\n    offset       : list\n                   Shift with the given offset.\n\n    Returns\n    ----------\n    result       : ndarray\n                   Result of the rotation   \n    \"\"\"\n    points = np.asarray(points)\n    if angles[0] == 0 and angles[1] == 0 and angles[2] == 0:\n        result = np.array(offset) + points\n        return result\n    points -= np.array(origin)\n    rotx = rotmatx(angles[0])\n    roty = rotmaty(angles[1])\n    rotz = rotmatz(angles[2])\n    if mode == 'XYZ':\n        result = np.dot(rotz, np.dot(roty, np.dot(rotx, points.T))).T\n    elif mode == 'XZY':\n        result = np.dot(roty, np.dot(rotz, np.dot(rotx, points.T))).T\n    elif mode == 'YXZ':\n        result = np.dot(rotz, np.dot(rotx, np.dot(roty, points.T))).T\n    elif mode == 'ZXY':\n        result = np.dot(roty, np.dot(rotx, np.dot(rotz, points.T))).T\n    elif mode == 'ZYX':\n        result = np.dot(rotx, np.dot(roty, np.dot(rotz, points.T))).T\n    result += np.array(origin)\n    result += np.array(offset)\n    return result\n
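A minimal usage sketch, rotating two points by 90 degrees around the Z axis:

import numpy as np
from odak.tools.transformation import rotate_points

points = np.array([[1., 0., 0.], [0., 1., 0.]])
rotated = rotate_points(points, angles = [0., 0., 90.], mode = 'XYZ')
print(np.round(rotated, 3))  # approximately [[0., 1., 0.], [-1., 0., 0.]]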
"},{"location":"odak/tools/#odak.tools.transformation.rotmatx","title":"rotmatx(angle)","text":"

Definition to generate a rotation matrix along X axis.

Parameters:

  • angle \u2013
           Rotation angles in degrees.\n

Returns:

  • rotx ( ndarray ) \u2013

    Rotation matrix along X axis.

Source code in odak/tools/transformation.py
def rotmatx(angle):\n    \"\"\"\n    Definition to generate a rotation matrix along X axis.\n\n    Parameters\n    ----------\n    angle        : list\n                   Rotation angles in degrees.\n\n    Returns\n    -------\n    rotx         : ndarray\n                   Rotation matrix along X axis.\n    \"\"\"\n    angle = np.float64(angle)\n    angle = np.radians(angle)\n    rotx = np.array([\n        [1.,               0.,               0.],\n        [0.,  math.cos(angle), -math.sin(angle)],\n        [0.,  math.sin(angle),  math.cos(angle)]\n    ], dtype=np.float64)\n    return rotx\n
"},{"location":"odak/tools/#odak.tools.transformation.rotmaty","title":"rotmaty(angle)","text":"

Definition to generate a rotation matrix along Y axis.

Parameters:

  • angle \u2013
           Rotation angles in degrees.\n

Returns:

  • roty ( ndarray ) \u2013

    Rotation matrix along Y axis.

Source code in odak/tools/transformation.py
def rotmaty(angle):\n    \"\"\"\n    Definition to generate a rotation matrix along Y axis.\n\n    Parameters\n    ----------\n    angle        : list\n                   Rotation angles in degrees.\n\n    Returns\n    -------\n    roty         : ndarray\n                   Rotation matrix along Y axis.\n    \"\"\"\n    angle = np.radians(angle)\n    roty = np.array([\n        [math.cos(angle),  0., math.sin(angle)],\n        [0.,               1.,              0.],\n        [-math.sin(angle), 0., math.cos(angle)]\n    ], dtype=np.float64)\n    return roty\n
"},{"location":"odak/tools/#odak.tools.transformation.rotmatz","title":"rotmatz(angle)","text":"

Definition to generate a rotation matrix along Z axis.

Parameters:

  • angle \u2013
           Rotation angles in degrees.\n

Returns:

  • rotz ( ndarray ) \u2013

    Rotation matrix along Z axis.

Source code in odak/tools/transformation.py
def rotmatz(angle):\n    \"\"\"\n    Definition to generate a rotation matrix along Z axis.\n\n    Parameters\n    ----------\n    angle        : list\n                   Rotation angles in degrees.\n\n    Returns\n    -------\n    rotz         : ndarray\n                   Rotation matrix along Z axis.\n    \"\"\"\n    angle = np.radians(angle)\n    rotz = np.array([\n        [math.cos(angle), -math.sin(angle), 0.],\n        [math.sin(angle),  math.cos(angle), 0.],\n        [0.,               0., 1.]\n    ], dtype=np.float64)\n\n    return rotz\n
"},{"location":"odak/tools/#odak.tools.transformation.tilt_towards","title":"tilt_towards(location, lookat)","text":"

Definition to tilt surface normal of a plane towards a point.

Parameters:

  • location \u2013
           Center of the plane to be tilted.\n
  • lookat \u2013
           Tilt towards this point.\n

Returns:

  • angles ( list ) \u2013

    Rotation angles in degrees.

Source code in odak/tools/transformation.py
def tilt_towards(location, lookat):\n    \"\"\"\n    Definition to tilt surface normal of a plane towards a point.\n\n    Parameters\n    ----------\n    location     : list\n                   Center of the plane to be tilted.\n    lookat       : list\n                   Tilt towards this point.\n\n    Returns\n    ----------\n    angles       : list\n                   Rotation angles in degrees.\n    \"\"\"\n    dx = location[0]-lookat[0]\n    dy = location[1]-lookat[1]\n    dz = location[2]-lookat[2]\n    dist = np.sqrt(dx**2+dy**2+dz**2)\n    phi = np.arctan2(dy, dx)\n    theta = np.arccos(dz/dist)\n    angles = [\n        0,\n        np.degrees(theta).tolist(),\n        np.degrees(phi).tolist()\n    ]\n    return angles\n
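A minimal usage sketch, tilting a plane located away from the origin so that its surface normal points at the origin:

from odak.tools.transformation import tilt_towards

angles = tilt_towards(location = [0., 10., 10.], lookat = [0., 0., 0.])
print(angles)  # approximately [0, 45.0, 90.0], rotation angles in degrees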
"},{"location":"odak/wave/","title":"odak.wave","text":"

odak.wave

Provides necessary definitions for merging geometric optics with wave theory, as well as classical approaches in wave theory. See "Introduction to Fourier Optics" by Joseph Goodman for the theoretical explanation.

"},{"location":"odak/wave/#odak.wave.adaptive_sampling_angular_spectrum","title":"adaptive_sampling_angular_spectrum(field, k, distance, dx, wavelength)","text":"

A definition to calculate adaptive-sampling angular spectrum based beam propagation. For more, see Zhang, Wenhui, Hao Zhang, and Guofan Jin. "Adaptive-sampling angular spectrum method with full utilization of space-bandwidth product." Optics Letters 45.16 (2020): 4416-4419.

Parameters:

  • field \u2013
               Complex field (MxN).\n
  • k \u2013
               Wave number of a wave, see odak.wave.wavenumber for more.\n
  • distance \u2013
               Propagation distance.\n
  • dx \u2013
               Size of one single pixel in the field grid (in meters).\n
  • wavelength \u2013
               Wavelength of the electric field.\n

Returns:

  • result ( complex ) \u2013

    Final complex field (MxN).

Source code in odak/wave/classical.py
def adaptive_sampling_angular_spectrum(field, k, distance, dx, wavelength):\n    \"\"\"\n    A definition to calculate adaptive sampling angular spectrum based beam propagation. For more Zhang, Wenhui, Hao Zhang, and Guofan Jin. \"Adaptive-sampling angular spectrum method with full utilization of space-bandwidth product.\" Optics Letters 45.16 (2020): 4416-4419.\n\n    Parameters\n    ----------\n    field            : np.complex\n                       Complex field (MxN).\n    k                : odak.wave.wavenumber\n                       Wave number of a wave, see odak.wave.wavenumber for more.\n    distance         : float\n                       Propagation distance.\n    dx               : float\n                       Size of one single pixel in the field grid (in meters).\n    wavelength       : float\n                       Wavelength of the electric field.\n\n    Returns\n    -------\n    result           : np.complex\n                       Final complex field (MxN).\n    \"\"\"\n    iflag = -1\n    eps = 10**(-12)\n    nv, nu = field.shape\n    l = nu*dx\n    x = np.linspace(-l/2, l/2, nu)\n    y = np.linspace(-l/2, l/2, nv)\n    X, Y = np.meshgrid(x, y)\n    fx = np.linspace(-1./2./dx, 1./2./dx, nu)\n    fy = np.linspace(-1./2./dx, 1./2./dx, nv)\n    FX, FY = np.meshgrid(fx, fy)\n    forig = 1./2./dx\n    fc2 = 1./2*(nu/wavelength/np.abs(distance))**0.5\n    ss = np.abs(fc2)/forig\n    zc = nu*dx**2/wavelength\n    K = nu/2/np.amax(np.abs(fx))\n    m = 2\n    nnu2 = m*nu\n    nnv2 = m*nv\n    fxn = np.linspace(-1./2./dx, 1./2./dx, nnu2)\n    fyn = np.linspace(-1./2./dx, 1./2./dx, nnv2)\n    if np.abs(distance) > zc*2:\n        fxn = fxn*ss\n        fyn = fyn*ss\n    FXN, FYN = np.meshgrid(fxn, fyn)\n    Hn = np.exp(1j*k*distance*(1-(FXN*wavelength)**2-(FYN*wavelength)**2)**0.5)\n    FX = FXN/np.amax(FXN)*np.pi\n    FY = FYN/np.amax(FYN)*np.pi\n    t_2 = nufft2(field, FX*ss, FY*ss, size=[nnv2, nnu2], sign=iflag, eps=eps)\n    FX = FX/np.amax(FX)*np.pi\n    FY = FY/np.amax(FY)*np.pi\n    result = nuifft2(Hn*t_2, FX*ss, FY*ss, size=[nv, nu], sign=-iflag, eps=eps)\n    return result\n
"},{"location":"odak/wave/#odak.wave.add_phase","title":"add_phase(field, new_phase)","text":"

Definition for adding a phase to a given complex field.

Parameters:

  • field \u2013
           Complex field.\n
  • new_phase \u2013
           Complex phase.\n

Returns:

  • new_field ( complex64 ) \u2013

    Complex field.

Source code in odak/wave/__init__.py
def add_phase(field, new_phase):\n    \"\"\"\n    Definition for adding a phase to a given complex field.\n\n    Parameters\n    ----------\n    field        : np.complex64\n                   Complex field.\n    new_phase    : np.complex64\n                   Complex phase.\n\n    Returns\n    -------\n    new_field    : np.complex64\n                   Complex field.\n    \"\"\"\n    phase = calculate_phase(field)\n    amplitude = calculate_amplitude(field)\n    new_field = amplitude*np.cos(phase+new_phase) + \\\n        1j*amplitude*np.sin(phase+new_phase)\n    return new_field\n
"},{"location":"odak/wave/#odak.wave.add_random_phase","title":"add_random_phase(field)","text":"

Definition for adding a random phase to a given complex field.

Parameters:

  • field \u2013
           Complex field.\n

Returns:

  • new_field ( complex64 ) \u2013

    Complex field.

Source code in odak/wave/__init__.py
def add_random_phase(field):\n    \"\"\"\n    Definition for adding a random phase to a given complex field.\n\n    Parameters\n    ----------\n    field        : np.complex64\n                   Complex field.\n\n    Returns\n    -------\n    new_field    : np.complex64\n                   Complex field.\n    \"\"\"\n    random_phase = np.pi*np.random.random(field.shape)\n    new_field = add_phase(field, random_phase)\n    return new_field\n
"},{"location":"odak/wave/#odak.wave.adjust_phase_only_slm_range","title":"adjust_phase_only_slm_range(native_range, working_wavelength, native_wavelength)","text":"

Definition for calculating the phase range of a Spatial Light Modulator (SLM) for a given wavelength. Here you provide the maximum angle, as the lower bound is typically zero. If the lower bound isn't zero, you can use this very same definition to calculate the lower angular bound as well.

Parameters:

  • native_range \u2013
                 Native range of the phase only SLM in radians (i.e. two pi).\n
  • working_wavelength (float) \u2013
                 Wavelength of the illumination source or some working wavelength.\n
  • native_wavelength \u2013
                 Wavelength which the SLM is designed for.\n

Returns:

  • new_range ( float ) \u2013

    Calculated phase range in radians.

Source code in odak/wave/__init__.py
def adjust_phase_only_slm_range(native_range, working_wavelength, native_wavelength):\n    \"\"\"\n    Definition for calculating the phase range of the Spatial Light Modulator (SLM) for a given wavelength. Here you prove maximum angle as the lower bound is typically zero. If the lower bound isn't zero in angles, you can use this very same definition for calculating lower angular bound as well.\n\n    Parameters\n    ----------\n    native_range       : float\n                         Native range of the phase only SLM in radians (i.e. two pi).\n    working_wavelength : float\n                         Wavelength of the illumination source or some working wavelength.\n    native_wavelength  : float\n                         Wavelength which the SLM is designed for.\n\n    Returns\n    -------\n    new_range          : float\n                         Calculated phase range in radians.\n    \"\"\"\n    new_range = native_range/working_wavelength*native_wavelength\n    return new_range\n
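A minimal usage sketch (illustrative wavelengths):

import numpy as np
from odak.wave import adjust_phase_only_slm_range

new_range = adjust_phase_only_slm_range(
                                        native_range = 2 * np.pi,
                                        working_wavelength = 515e-9,
                                        native_wavelength = 633e-9
                                       )
print(new_range)  # about 7.72 radians, larger than 2*pi since 515 nm is below the design wavelength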
"},{"location":"odak/wave/#odak.wave.angular_spectrum","title":"angular_spectrum(field, k, distance, dx, wavelength)","text":"

A definition to calculate angular spectrum based beam propagation.

Parameters:

  • field \u2013
               Complex field (MxN).\n
  • k \u2013
               Wave number of a wave, see odak.wave.wavenumber for more.\n
  • distance \u2013
               Propagation distance.\n
  • dx \u2013
               Size of one single pixel in the field grid (in meters).\n
  • wavelength \u2013
               Wavelength of the electric field.\n

Returns:

  • result ( complex ) \u2013

    Final complex field (MxN).

Source code in odak/wave/classical.py
def angular_spectrum(field, k, distance, dx, wavelength):\n    \"\"\"\n    A definition to calculate angular spectrum based beam propagation.\n\n    Parameters\n    ----------\n    field            : np.complex\n                       Complex field (MxN).\n    k                : odak.wave.wavenumber\n                       Wave number of a wave, see odak.wave.wavenumber for more.\n    distance         : float\n                       Propagation distance.\n    dx               : float\n                       Size of one single pixel in the field grid (in meters).\n    wavelength       : float\n                       Wavelength of the electric field.\n\n    Returns\n    -------\n    result           : np.complex\n                       Final complex field (MxN).\n    \"\"\"\n    nv, nu = field.shape\n    x = np.linspace(-nu/2*dx, nu/2*dx, nu)\n    y = np.linspace(-nv/2*dx, nv/2*dx, nv)\n    X, Y = np.meshgrid(x, y)\n    Z = X**2+Y**2\n    h = 1./(1j*wavelength*distance)*np.exp(1j*k*(distance+Z/2/distance))\n    h = np.fft.fft2(np.fft.fftshift(h))*dx**2\n    U1 = np.fft.fft2(np.fft.fftshift(field))\n    U2 = h*U1\n    result = np.fft.ifftshift(np.fft.ifft2(U2))\n    return result\n
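A minimal usage sketch with illustrative values (a square aperture propagated by 10 cm); units are kept consistent in meters and the function names follow the entries on this page:

import numpy as np
from odak.wave import wavenumber, angular_spectrum

wavelength = 532e-9                     # meters
dx = 8e-6                               # pixel pitch in meters
k = wavenumber(wavelength)              # k = 2*pi/wavelength, unit-agnostic
field = np.zeros((512, 512), dtype=np.complex64)
field[236:276, 236:276] = 1.            # simple square aperture
result = angular_spectrum(field, k, 0.1, dx, wavelength)
intensity = np.abs(result)**2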
"},{"location":"odak/wave/#odak.wave.band_extended_angular_spectrum","title":"band_extended_angular_spectrum(field, k, distance, dx, wavelength)","text":"

A definition to calculate band-extended angular spectrum based beam propagation. For more, see Zhang, Wenhui, Hao Zhang, and Guofan Jin. \"Band-extended angular spectrum method for accurate diffraction calculation in a wide propagation range.\" Optics Letters 45.6 (2020): 1543-1546.

Parameters:

  • field \u2013
               Complex field (MxN).\n
  • k \u2013
               Wave number of a wave, see odak.wave.wavenumber for more.\n
  • distance \u2013
               Propagation distance.\n
  • dx \u2013
               Size of one single pixel in the field grid (in meters).\n
  • wavelength \u2013
               Wavelength of the electric field.\n

Returns:

  • result ( complex ) \u2013

    Final complex field (MxN).

Source code in odak/wave/classical.py
def band_extended_angular_spectrum(field, k, distance, dx, wavelength):\n    \"\"\"\n    A definition to calculate bandextended angular spectrum based beam propagation. For more Zhang, Wenhui, Hao Zhang, and Guofan Jin. \"Band-extended angular spectrum method for accurate diffraction calculation in a wide propagation range.\" Optics Letters 45.6 (2020): 1543-1546.\n\n    Parameters\n    ----------\n    field            : np.complex\n                       Complex field (MxN).\n    k                : odak.wave.wavenumber\n                       Wave number of a wave, see odak.wave.wavenumber for more.\n    distance         : float\n                       Propagation distance.\n    dx               : float\n                       Size of one single pixel in the field grid (in meters).\n    wavelength       : float\n                       Wavelength of the electric field.\n\n    Returns\n    -------\n    result           : np.complex\n                       Final complex field (MxN).\n    \"\"\"\n    iflag = -1\n    eps = 10**(-12)\n    nv, nu = field.shape\n    l = nu*dx\n    x = np.linspace(-l/2, l/2, nu)\n    y = np.linspace(-l/2, l/2, nv)\n    X, Y = np.meshgrid(x, y)\n    Z = X**2+Y**2\n    fx = np.linspace(-1./2./dx, 1./2./dx, nu)\n    fy = np.linspace(-1./2./dx, 1./2./dx, nv)\n    FX, FY = np.meshgrid(fx, fy)\n    K = nu/2/np.amax(fx)\n    fcn = 1./2*(nu/wavelength/np.abs(distance))**0.5\n    ss = np.abs(fcn)/np.amax(np.abs(fx))\n    zc = nu*dx**2/wavelength\n    if np.abs(distance) < zc:\n        fxn = fx\n        fyn = fy\n    else:\n        fxn = fx*ss\n        fyn = fy*ss\n    FXN, FYN = np.meshgrid(fxn, fyn)\n    Hn = np.exp(1j*k*distance*(1-(FXN*wavelength)**2-(FYN*wavelength)**2)**0.5)\n    X = X/np.amax(X)*np.pi\n    Y = Y/np.amax(Y)*np.pi\n    t_asmNUFT = nufft2(field, X*ss, Y*ss, sign=iflag, eps=eps)\n    result = nuifft2(Hn*t_asmNUFT, X*ss, Y*ss, sign=-iflag, eps=eps)\n    return result\n
"},{"location":"odak/wave/#odak.wave.band_limited_angular_spectrum","title":"band_limited_angular_spectrum(field, k, distance, dx, wavelength)","text":"

A definition to calculate band-limited angular spectrum based beam propagation. For more, see Matsushima, Kyoji, and Tomoyoshi Shimobaba. \"Band-limited angular spectrum method for numerical simulation of free-space propagation in far and near fields.\" Optics Express 17.22 (2009): 19662-19673.

Parameters:

  • field \u2013
               Complex field (MxN).\n
  • k \u2013
               Wave number of a wave, see odak.wave.wavenumber for more.\n
  • distance \u2013
               Propagation distance.\n
  • dx \u2013
               Size of one single pixel in the field grid (in meters).\n
  • wavelength \u2013
               Wavelength of the electric field.\n

Returns:

  • result ( complex ) \u2013

    Final complex field (MxN).

Source code in odak/wave/classical.py
def band_limited_angular_spectrum(field, k, distance, dx, wavelength):\n    \"\"\"\n    A definition to calculate bandlimited angular spectrum based beam propagation. For more Matsushima, Kyoji, and Tomoyoshi Shimobaba. \"Band-limited angular spectrum method for numerical simulation of free-space propagation in far and near fields.\" Optics express 17.22 (2009): 19662-19673.\n\n    Parameters\n    ----------\n    field            : np.complex\n                       Complex field (MxN).\n    k                : odak.wave.wavenumber\n                       Wave number of a wave, see odak.wave.wavenumber for more.\n    distance         : float\n                       Propagation distance.\n    dx               : float\n                       Size of one single pixel in the field grid (in meters).\n    wavelength       : float\n                       Wavelength of the electric field.\n\n    Returns\n    -------\n    result           : np.complex\n                       Final complex field (MxN).\n    \"\"\"\n    nv, nu = field.shape\n    x = np.linspace(-nu/2*dx, nu/2*dx, nu)\n    y = np.linspace(-nv/2*dx, nv/2*dx, nv)\n    X, Y = np.meshgrid(x, y)\n    Z = X**2+Y**2\n    h = 1./(1j*wavelength*distance)*np.exp(1j*k*(distance+Z/2/distance))\n    h = np.fft.fft2(np.fft.fftshift(h))*dx**2\n    flimx = np.ceil(1/(((2*distance*(1./(nu)))**2+1)**0.5*wavelength))\n    flimy = np.ceil(1/(((2*distance*(1./(nv)))**2+1)**0.5*wavelength))\n    mask = np.zeros((nu, nv), dtype=np.complex64)\n    mask = (np.abs(X) < flimx) & (np.abs(Y) < flimy)\n    mask = set_amplitude(h, mask)\n    U1 = np.fft.fft2(np.fft.fftshift(field))\n    U2 = mask*U1\n    result = np.fft.ifftshift(np.fft.ifft2(U2))\n    return result\n
"},{"location":"odak/wave/#odak.wave.calculate_intensity","title":"calculate_intensity(field)","text":"

Definition to calculate the intensity of one or more given electric fields.

Parameters:

  • field \u2013
           Electric fields or an electric field.\n

Returns:

  • intensity ( float ) \u2013

    Intensity or intensities of electric field(s).

Source code in odak/wave/__init__.py
def calculate_intensity(field):\n    \"\"\"\n    Definition to calculate intensity of a single or multiple given electric field(s).\n\n    Parameters\n    ----------\n    field        : ndarray.complex or complex\n                   Electric fields or an electric field.\n\n    Returns\n    -------\n    intensity    : float\n                   Intensity or intensities of electric field(s).\n    \"\"\"\n    intensity = np.abs(field)**2\n    return intensity\n
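A minimal usage sketch; the intensity is simply the squared magnitude of the field:

import numpy as np
from odak.wave import calculate_intensity

field = np.array([1 + 1j, 2 + 0j], dtype=np.complex64)
print(calculate_intensity(field))  # [2. 4.]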
"},{"location":"odak/wave/#odak.wave.distance_between_two_points","title":"distance_between_two_points(point1, point2)","text":"

Definition to calculate distance between two given points.

Parameters:

  • point1 \u2013
          First point in X,Y,Z.\n
  • point2 \u2013
          Second point in X,Y,Z.\n

Returns:

  • distance ( float ) \u2013

    Distance in between given two points.

Source code in odak/tools/vector.py
def distance_between_two_points(point1, point2):\n    \"\"\"\n    Definition to calculate distance between two given points.\n\n    Parameters\n    ----------\n    point1      : list\n                  First point in X,Y,Z.\n    point2      : list\n                  Second point in X,Y,Z.\n\n    Returns\n    ----------\n    distance    : float\n                  Distance in between given two points.\n    \"\"\"\n    point1 = np.asarray(point1)\n    point2 = np.asarray(point2)\n    if len(point1.shape) == 1 and len(point2.shape) == 1:\n        distance = np.sqrt(np.sum((point1-point2)**2))\n    elif len(point1.shape) == 2 or len(point2.shape) == 2:\n        distance = np.sqrt(np.sum((point1-point2)**2, axis=1))\n    return distance\n
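A minimal usage sketch; the inputs are X,Y,Z points, and arrays of points are also accepted:

from odak.wave import distance_between_two_points

print(distance_between_two_points([0., 0., 0.], [3., 4., 0.]))  # 5.0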
"},{"location":"odak/wave/#odak.wave.double_convergence","title":"double_convergence(nx, ny, k, r, dx)","text":"

A definition to generate an initial phase for a Gerchberg-Saxton method. For more details, consult Sun, Peng, et al. \"Holographic near-eye display system based on double-convergence light Gerchberg-Saxton algorithm.\" Optics Express 26.8 (2018): 10140-10151.

Parameters:

  • nx \u2013
         Size of the output along X.\n
  • ny \u2013
         Size of the output along Y.\n
  • k \u2013
         See odak.wave.wavenumber for more.\n
  • r \u2013
         The distance between location of a light source and an image plane.\n
  • dx \u2013
         Pixel pitch.\n

Returns:

  • function ( ndarray ) \u2013

    Generated phase pattern for a Gerchberg-Saxton method.

Source code in odak/wave/lens.py
def double_convergence(nx, ny, k, r, dx):\n    \"\"\"\n    A definition to generate initial phase for a Gerchberg-Saxton method. For more details consult Sun, Peng, et al. \"Holographic near-eye display system based on double-convergence light Gerchberg-Saxton algorithm.\" Optics express 26.8 (2018): 10140-10151.\n\n    Parameters\n    ----------\n    nx         : int\n                 Size of the output along X.\n    ny         : int\n                 Size of the output along Y.\n    k          : odak.wave.wavenumber\n                 See odak.wave.wavenumber for more.\n    r          : float\n                 The distance between location of a light source and an image plane.\n    dx         : float\n                 Pixel pitch.\n\n    Returns\n    -------\n    function   : ndarray\n                 Generated phase pattern for a Gerchberg-Saxton method.\n    \"\"\"\n    size = [ny, nx]\n    x = np.linspace(-size[0]*dx/2, size[0]*dx/2, size[0])\n    y = np.linspace(-size[1]*dx/2, size[1]*dx/2, size[1])\n    X, Y = np.meshgrid(x, y)\n    Z = X**2+Y**2\n    w = np.exp(1j*k*Z/r)\n    return w\n
"},{"location":"odak/wave/#odak.wave.electric_field_per_plane_wave","title":"electric_field_per_plane_wave(amplitude, opd, k, phase=0, w=0, t=0)","text":"

Definition to return the state of a plane wave at a particular distance and time.

Parameters:

  • amplitude \u2013
           Amplitude of a wave.\n
  • opd \u2013
           Optical path difference in mm.\n
  • k \u2013
           Wave number of a wave, see odak.wave.parameters.wavenumber for more.\n
  • phase \u2013
           Initial phase of a wave.\n
  • w \u2013
           Rotation speed of a wave, see odak.wave.parameters.rotationspeed for more.\n
  • t \u2013
           Time in seconds.\n

Returns:

  • field ( complex ) \u2013

    A complex number that provides the resultant field in the complex form A*e^(j(wt+phi)).

Source code in odak/wave/vector.py
def electric_field_per_plane_wave(amplitude, opd, k, phase=0, w=0, t=0):\n    \"\"\"\n    Definition to return state of a plane wave at a particular distance and time.\n\n    Parameters\n    ----------\n    amplitude    : float\n                   Amplitude of a wave.\n    opd          : float\n                   Optical path difference in mm.\n    k            : float\n                   Wave number of a wave, see odak.wave.parameters.wavenumber for more.\n    phase        : float\n                   Initial phase of a wave.\n    w            : float\n                   Rotation speed of a wave, see odak.wave.parameters.rotationspeed for more.\n    t            : float\n                   Time in seconds.\n\n    Returns\n    -------\n    field        : complex\n                   A complex number that provides the resultant field in the complex form A*e^(j(wt+phi)).\n    \"\"\"\n    field = amplitude*np.exp(1j*(-w*t+opd*k+phase))/opd**2\n    return field\n
"},{"location":"odak/wave/#odak.wave.fraunhofer","title":"fraunhofer(field, k, distance, dx, wavelength)","text":"

A definition to calculate Fraunhofer based beam propagation.

Parameters:

  • field \u2013
               Complex field (MxN).\n
  • k \u2013
               Wave number of a wave, see odak.wave.wavenumber for more.\n
  • distance \u2013
               Propagation distance.\n
  • dx \u2013
               Size of one single pixel in the field grid (in meters).\n
  • wavelength \u2013
               Wavelength of the electric field.\n

Returns:

  • result ( complex ) \u2013

    Final complex field (MxN).

Source code in odak/wave/classical.py
def fraunhofer(field, k, distance, dx, wavelength):\n    \"\"\"\n    A definition to calculate Fraunhofer based beam propagation.\n\n    Parameters\n    ----------\n    field            : np.complex\n                       Complex field (MxN).\n    k                : odak.wave.wavenumber\n                       Wave number of a wave, see odak.wave.wavenumber for more.\n    distance         : float\n                       Propagation distance.\n    dx               : float\n                       Size of one single pixel in the field grid (in meters).\n    wavelength       : float\n                       Wavelength of the electric field.\n\n    Returns\n    -------\n    result           : np.complex\n                       Final complex field (MxN).\n    \"\"\"\n    nv, nu = field.shape\n    l = nu*dx\n    l2 = wavelength*distance/dx\n    dx2 = wavelength*distance/l\n    fx = np.linspace(-l2/2., l2/2., nu)\n    fy = np.linspace(-l2/2., l2/2., nv)\n    FX, FY = np.meshgrid(fx, fy)\n    FZ = FX**2+FY**2\n    c = np.exp(1j*k*distance)/(1j*wavelength*distance) * \\\n        np.exp(1j*k/(2*distance)*FZ)\n    result = c*np.fft.ifftshift(np.fft.fft2(np.fft.fftshift(field)))*dx**2\n    return result\n
"},{"location":"odak/wave/#odak.wave.fraunhofer_equal_size_adjust","title":"fraunhofer_equal_size_adjust(field, distance, dx, wavelength)","text":"

A definition to match the physical size of the original field with the propagated field.

Parameters:

  • field \u2013
               Complex field (MxN).\n
  • distance \u2013
               Propagation distance.\n
  • dx \u2013
               Size of one single pixel in the field grid (in meters).\n
  • wavelength \u2013
               Wavelength of the electric field.\n

Returns:

  • new_field ( complex ) \u2013

    Final complex field (MxN).

Source code in odak/wave/classical.py
def fraunhofer_equal_size_adjust(field, distance, dx, wavelength):\n    \"\"\"\n    A definition to match the physical size of the original field with the propagated field.\n\n    Parameters\n    ----------\n    field            : np.complex\n                       Complex field (MxN).\n    distance         : float\n                       Propagation distance.\n    dx               : float\n                       Size of one single pixel in the field grid (in meters).\n    wavelength       : float\n                       Wavelength of the electric field.\n\n    Returns\n    -------\n    new_field        : np.complex\n                       Final complex field (MxN).\n    \"\"\"\n    nv, nu = field.shape\n    l1 = nu*dx\n    l2 = wavelength*distance/dx\n    m = l1/l2\n    px = int(m*nu)\n    py = int(m*nv)\n    nx = int(field.shape[0]/2-px/2)\n    ny = int(field.shape[1]/2-py/2)\n    new_field = np.copy(field[nx:nx+px, ny:ny+py])\n    return new_field\n
"},{"location":"odak/wave/#odak.wave.fraunhofer_inverse","title":"fraunhofer_inverse(field, k, distance, dx, wavelength)","text":"

A definition to calculate Inverse Fraunhofer based beam propagation.

Parameters:

  • field \u2013
               Complex field (MxN).\n
  • k \u2013
               Wave number of a wave, see odak.wave.wavenumber for more.\n
  • distance \u2013
               Propagation distance.\n
  • dx \u2013
               Size of one single pixel in the field grid (in meters).\n
  • wavelength \u2013
               Wavelength of the electric field.\n

Returns:

  • result ( complex ) \u2013

    Final complex field (MxN).

Source code in odak/wave/classical.py
def fraunhofer_inverse(field, k, distance, dx, wavelength):\n    \"\"\"\n    A definition to calculate Inverse Fraunhofer based beam propagation.\n\n    Parameters\n    ----------\n    field            : np.complex\n                       Complex field (MxN).\n    k                : odak.wave.wavenumber\n                       Wave number of a wave, see odak.wave.wavenumber for more.\n    distance         : float\n                       Propagation distance.\n    dx               : float\n                       Size of one single pixel in the field grid (in meters).\n    wavelength       : float\n                       Wavelength of the electric field.\n\n    Returns\n    -------\n    result           : np.complex\n                       Final complex field (MxN).\n    \"\"\"\n    distance = np.abs(distance)\n    nv, nu = field.shape\n    l = nu*dx\n    l2 = wavelength*distance/dx\n    dx2 = wavelength*distance/l\n    fx = np.linspace(-l2/2., l2/2., nu)\n    fy = np.linspace(-l2/2., l2/2., nv)\n    FX, FY = np.meshgrid(fx, fy)\n    FZ = FX**2+FY**2\n    c = np.exp(1j*k*distance)/(1j*wavelength*distance) * \\\n        np.exp(1j*k/(2*distance)*FZ)\n    result = np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(field/dx**2/c)))\n    return result\n
"},{"location":"odak/wave/#odak.wave.generate_complex_field","title":"generate_complex_field(amplitude, phase)","text":"

Definition to generate a complex field with a given amplitude and phase.

Parameters:

  • amplitude \u2013
                Amplitude of the field.\n
  • phase \u2013
                Phase of the field.\n

Returns:

  • field ( ndarray ) \u2013

    Complex field.

Source code in odak/wave/__init__.py
def generate_complex_field(amplitude, phase):\n    \"\"\"\n    Definition to generate a complex field with a given amplitude and phase.\n\n    Parameters\n    ----------\n    amplitude         : ndarray\n                        Amplitude of the field.\n    phase             : ndarray\n                        Phase of the field.\n\n    Returns\n    -------\n    field             : ndarray\n                        Complex field.\n    \"\"\"\n    field = amplitude*np.cos(phase)+1j*amplitude*np.sin(phase)\n    return field\n
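A minimal usage sketch building a unit-amplitude field with a flat phase (illustrative shapes):

import numpy as np
from odak.wave import generate_complex_field

amplitude = np.ones((256, 256))
phase = np.zeros((256, 256))
field = generate_complex_field(amplitude, phase)  # equivalent to amplitude * exp(1j * phase)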
"},{"location":"odak/wave/#odak.wave.gerchberg_saxton","title":"gerchberg_saxton(field, n_iterations, distance, dx, wavelength, slm_range=6.28, propagation_type='IR Fresnel', initial_phase=None)","text":"

Definition to compute a hologram using an iterative method called the Gerchberg-Saxton phase retrieval algorithm. For more on the method, see: Gerchberg, Ralph W. \"A practical algorithm for the determination of phase from image and diffraction plane pictures.\" Optik 35 (1972): 237-246.

Parameters:

  • field \u2013
               Complex field (MxN).\n
  • distance \u2013
               Propagation distance.\n
  • dx \u2013
               Size of one single pixel in the field grid (in meters).\n
  • wavelength \u2013
               Wavelength of the electric field.\n
  • slm_range \u2013
               Typically this is equal to two pi. See odak.wave.adjust_phase_only_slm_range() for more.\n
  • propagation_type (str, default: 'IR Fresnel' ) \u2013
               Type of the propagation (IR Fresnel, TR Fresnel, Fraunhofer).\n
  • initial_phase \u2013
               Phase to be added to the initial value.\n

Returns:

  • hologram ( complex ) \u2013

    Calculated complex hologram.

  • reconstruction ( complex ) \u2013

    Calculated reconstruction using calculated hologram.

Source code in odak/wave/classical.py
def gerchberg_saxton(field, n_iterations, distance, dx, wavelength, slm_range=6.28, propagation_type='IR Fresnel', initial_phase=None):\n    \"\"\"\n    Definition to compute a hologram using an iterative method called Gerchberg-Saxton phase retrieval algorithm. For more on the method, see: Gerchberg, Ralph W. \"A practical algorithm for the determination of phase from image and diffraction plane pictures.\" Optik 35 (1972): 237-246.\n\n    Parameters\n    ----------\n    field            : np.complex64\n                       Complex field (MxN).\n    distance         : float\n                       Propagation distance.\n    dx               : float\n                       Size of one single pixel in the field grid (in meters).\n    wavelength       : float\n                       Wavelength of the electric field.\n    slm_range        : float\n                       Typically this is equal to two pi. See odak.wave.adjust_phase_only_slm_range() for more.\n    propagation_type : str\n                       Type of the propagation (IR Fresnel, TR Fresnel, Fraunhofer).\n    initial_phase    : np.complex64\n                       Phase to be added to the initial value.\n\n    Returns\n    -------\n    hologram         : np.complex\n                       Calculated complex hologram.\n    reconstruction   : np.complex\n                       Calculated reconstruction using calculated hologram. \n    \"\"\"\n    k = wavenumber(wavelength)\n    target = calculate_amplitude(field)\n    hologram = generate_complex_field(np.ones(field.shape), 0)\n    hologram = zero_pad(hologram)\n    if type(initial_phase) == type(None):\n        hologram = add_random_phase(hologram)\n    else:\n        initial_phase = zero_pad(initial_phase)\n        hologram = add_phase(hologram, initial_phase)\n    center = [int(hologram.shape[0]/2.), int(hologram.shape[1]/2.)]\n    orig_shape = [int(field.shape[0]/2.), int(field.shape[1]/2.)]\n    for i in tqdm(range(n_iterations), leave=False):\n        reconstruction = propagate_beam(\n            hologram, k, distance, dx, wavelength, propagation_type)\n        new_target = calculate_amplitude(reconstruction)\n        new_target[\n            center[0]-orig_shape[0]:center[0]+orig_shape[0],\n            center[1]-orig_shape[1]:center[1]+orig_shape[1]\n        ] = target\n        reconstruction = generate_complex_field(\n            new_target, calculate_phase(reconstruction))\n        hologram = propagate_beam(\n            reconstruction, k, -distance, dx, wavelength, propagation_type)\n        hologram = generate_complex_field(1, calculate_phase(hologram))\n        hologram = hologram[\n            center[0]-orig_shape[0]:center[0]+orig_shape[0],\n            center[1]-orig_shape[1]:center[1]+orig_shape[1],\n        ]\n        hologram = zero_pad(hologram)\n    reconstruction = propagate_beam(\n        hologram, k, distance, dx, wavelength, propagation_type)\n    hologram = hologram[\n        center[0]-orig_shape[0]:center[0]+orig_shape[0],\n        center[1]-orig_shape[1]:center[1]+orig_shape[1]\n    ]\n    reconstruction = reconstruction[\n        center[0]-orig_shape[0]:center[0]+orig_shape[0],\n        center[1]-orig_shape[1]:center[1]+orig_shape[1]\n    ]\n    return hologram, reconstruction\n
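A minimal usage sketch with illustrative values; the propagation type below is simply one of the options listed under odak.wave.propagate_beam, not a recommendation:

import numpy as np
from odak.wave import gerchberg_saxton

target = np.zeros((256, 256), dtype=np.complex64)
target[96:160, 96:160] = 1.              # amplitude target
hologram, reconstruction = gerchberg_saxton(
    target,
    n_iterations=10,
    distance=0.15,
    dx=8e-6,
    wavelength=532e-9,
    propagation_type='Transfer Function Fresnel'
)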
"},{"location":"odak/wave/#odak.wave.gerchberg_saxton_3d","title":"gerchberg_saxton_3d(fields, n_iterations, distances, dx, wavelength, slm_range=6.28, propagation_type='IR Fresnel', initial_phase=None, target_type='no constraint', coefficients=None)","text":"

Definition to compute a multi-plane hologram using an iterative method called the Gerchberg-Saxton phase retrieval algorithm. For more on the method, see: Zhou, Pengcheng, et al. \"30.4: Multi\u2010plane holographic display with a uniform 3D Gerchberg\u2010Saxton algorithm.\" SID Symposium Digest of Technical Papers. Vol. 46. No. 1. 2015.

Parameters:

  • fields \u2013
               Complex fields (MxN).\n
  • distances \u2013
               Propagation distances.\n
  • dx \u2013
               Size of one single pixel in the field grid (in meters).\n
  • wavelength \u2013
               Wavelength of the electric field.\n
  • slm_range \u2013
               Typically this is equal to two pi. See odak.wave.adjust_phase_only_slm_range() for more.\n
  • propagation_type (str, default: 'IR Fresnel' ) \u2013
               Type of the propagation (IR Fresnel, TR Fresnel, Fraunhofer).\n
  • initial_phase \u2013
               Phase to be added to the initial value.\n
  • target_type \u2013
               Target type. `No constraint` uses the input target as is. `Double constraint` follows the idea in this paper, which claims to suppress speckle: Chang, Chenliang, et al. \"Speckle-suppressed phase-only holographic three-dimensional display based on double-constraint Gerchberg\u2013Saxton algorithm.\" Applied Optics 54.23 (2015): 6994-7001.\n

Returns:

  • hologram ( complex ) \u2013

    Calculated complex hologram.

Source code in odak/wave/classical.py
def gerchberg_saxton_3d(fields, n_iterations, distances, dx, wavelength, slm_range=6.28, propagation_type='IR Fresnel', initial_phase=None, target_type='no constraint', coefficients=None):\n    \"\"\"\n    Definition to compute a multi plane hologram using an iterative method called Gerchberg-Saxton phase retrieval algorithm. For more on the method, see: Zhou, Pengcheng, et al. \"30.4: Multi\u2010plane holographic display with a uniform 3D Gerchberg\u2010Saxton algorithm.\" SID Symposium Digest of Technical Papers. Vol. 46. No. 1. 2015.\n\n    Parameters\n    ----------\n    fields           : np.complex64\n                       Complex fields (MxN).\n    distances        : list\n                       Propagation distances.\n    dx               : float\n                       Size of one single pixel in the field grid (in meters).\n    wavelength       : float\n                       Wavelength of the electric field.\n    slm_range        : float\n                       Typically this is equal to two pi. See odak.wave.adjust_phase_only_slm_range() for more.\n    propagation_type : str\n                       Type of the propagation (IR Fresnel, TR Fresnel, Fraunhofer).\n    initial_phase    : np.complex64\n                       Phase to be added to the initial value.\n    target_type      : str\n                       Target type. `No constraint` targets the input target as is. `Double constraint` follows the idea in this paper, which claims to suppress speckle: Chang, Chenliang, et al. \"Speckle-suppressed phase-only holographic three-dimensional display based on double-constraint Gerchberg\u2013Saxton algorithm.\" Applied optics 54.23 (2015): 6994-7001. \n\n    Returns\n    -------\n    hologram         : np.complex\n                       Calculated complex hologram.\n    \"\"\"\n    k = wavenumber(wavelength)\n    targets = calculate_amplitude(np.asarray(fields)).astype(np.float64)\n    hologram = generate_complex_field(np.ones(targets[0].shape), 0)\n    hologram = zero_pad(hologram)\n    if type(initial_phase) == type(None):\n        hologram = add_random_phase(hologram)\n    else:\n        initial_phase = zero_pad(initial_phase)\n        hologram = add_phase(hologram, initial_phase)\n    center = [int(hologram.shape[0]/2.), int(hologram.shape[1]/2.)]\n    orig_shape = [int(fields[0].shape[0]/2.), int(fields[0].shape[1]/2.)]\n    holograms = np.zeros(\n        (len(distances), hologram.shape[0], hologram.shape[1]), dtype=np.complex64)\n    for i in tqdm(range(n_iterations), leave=False):\n        for distance_id in tqdm(range(len(distances)), leave=False):\n            distance = distances[distance_id]\n            reconstruction = propagate_beam(\n                hologram, k, distance, dx, wavelength, propagation_type)\n            if target_type == 'double constraint':\n                if type(coefficients) == type(None):\n                    raise Exception(\n                        \"Provide coeeficients of alpha,beta and gamma for double constraint.\")\n                alpha = coefficients[0]\n                beta = coefficients[1]\n                gamma = coefficients[2]\n                target_current = 2*alpha * \\\n                    np.copy(targets[distance_id])-beta * \\\n                    calculate_amplitude(reconstruction)\n                target_current[target_current == 0] = gamma * \\\n                    np.abs(reconstruction[target_current == 0])\n            elif target_type == 'no constraint':\n                target_current = np.abs(targets[distance_id])\n  
          new_target = calculate_amplitude(reconstruction)\n            new_target[\n                center[0]-orig_shape[0]:center[0]+orig_shape[0],\n                center[1]-orig_shape[1]:center[1]+orig_shape[1]\n            ] = target_current\n            reconstruction = generate_complex_field(\n                new_target, calculate_phase(reconstruction))\n            hologram_layer = propagate_beam(\n                reconstruction, k, -distance, dx, wavelength, propagation_type)\n            hologram_layer = generate_complex_field(\n                1., calculate_phase(hologram_layer))\n            hologram_layer = hologram_layer[\n                center[0]-orig_shape[0]:center[0]+orig_shape[0],\n                center[1]-orig_shape[1]:center[1]+orig_shape[1]\n            ]\n            hologram_layer = zero_pad(hologram_layer)\n            holograms[distance_id] = hologram_layer\n        hologram = np.sum(holograms, axis=0)\n    hologram = hologram[\n        center[0]-orig_shape[0]:center[0]+orig_shape[0],\n        center[1]-orig_shape[1]:center[1]+orig_shape[1]\n    ]\n    return hologram\n
"},{"location":"odak/wave/#odak.wave.impulse_response_fresnel","title":"impulse_response_fresnel(field, k, distance, dx, wavelength)","text":"

A definition to calculate impulse response based Fresnel approximation for beam propagation.

Parameters:

  • field \u2013
               Complex field (MxN).\n
  • k \u2013
               Wave number of a wave, see odak.wave.wavenumber for more.\n
  • distance \u2013
               Propagation distance.\n
  • dx \u2013
               Size of one single pixel in the field grid (in meters).\n
  • wavelength \u2013
               Wavelength of the electric field.\n

Returns:

  • result ( complex ) \u2013

    Final complex field (MxN).

Source code in odak/wave/classical.py
def impulse_response_fresnel(field, k, distance, dx, wavelength):\n    \"\"\"\n    A definition to calculate impulse response based Fresnel approximation for beam propagation.\n\n    Parameters\n    ----------\n    field            : np.complex\n                       Complex field (MxN).\n    k                : odak.wave.wavenumber\n                       Wave number of a wave, see odak.wave.wavenumber for more.\n    distance         : float\n                       Propagation distance.\n    dx               : float\n                       Size of one single pixel in the field grid (in meters).\n    wavelength       : float\n                       Wavelength of the electric field.\n\n    Returns\n    -------\n    result           : np.complex\n                       Final complex field (MxN).\n\n    \"\"\"\n    nv, nu = field.shape\n    x = np.linspace(-nu / 2 * dx, nu / 2 * dx, nu)\n    y = np.linspace(-nv / 2 * dx, nv / 2 * dx, nv)\n    X, Y = np.meshgrid(x, y)\n    h = 1. / (1j * wavelength * distance) * np.exp(1j * k / (2 * distance) * (X ** 2 + Y ** 2))\n    H = np.fft.fft2(np.fft.fftshift(h))\n    U1 = np.fft.fft2(np.fft.fftshift(field))\n    U2 = H * U1\n    result = np.fft.ifftshift(np.fft.ifft2(U2))\n    result = np.roll(result, shift = (1, 1), axis = (0, 1))\n    return result\n
"},{"location":"odak/wave/#odak.wave.linear_grating","title":"linear_grating(nx, ny, every=2, add=3.14, axis='x')","text":"

A definition to generate a linear grating.

Parameters:

  • nx \u2013
         Size of the output along X.\n
  • ny \u2013
         Size of the output along Y.\n
  • every \u2013
         Add the add value at every given number.\n
  • add \u2013
         Angle to be added.\n
  • axis \u2013
          Axis, either X, Y, or both.\n

Returns:

  • field ( ndarray ) \u2013

    Linear grating term.

Source code in odak/wave/lens.py
def linear_grating(nx, ny, every=2, add=3.14, axis='x'):\n    \"\"\"\n    A definition to generate a linear grating.\n\n    Parameters\n    ----------\n    nx         : int\n                 Size of the output along X.\n    ny         : int\n                 Size of the output along Y.\n    every      : int\n                 Add the add value at every given number.\n    add        : float\n                 Angle to be added.\n    axis       : string\n                 Axis eiter X,Y or both.\n\n    Returns\n    -------\n    field      : ndarray\n                 Linear grating term.\n    \"\"\"\n    grating = np.zeros((nx, ny), dtype=np.complex64)\n    if axis == 'x':\n        grating[::every, :] = np.exp(1j*add)\n    if axis == 'y':\n        grating[:, ::every] = np.exp(1j*add)\n    if axis == 'xy':\n        checker = np.indices((nx, ny)).sum(axis=0) % every\n        checker += 1\n        checker = checker % 2\n        grating = np.exp(1j*checker*add)\n    return grating\n
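A minimal usage sketch with illustrative values; the returned grating can be multiplied with a complex field of the same shape:

import numpy as np
from odak.wave import linear_grating

grating = linear_grating(512, 512, every=2, add=np.pi, axis='x')
# field = field * grating   # apply to an existing 512x512 complex field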
"},{"location":"odak/wave/#odak.wave.nufft2","title":"nufft2(field, fx, fy, size=None, sign=1, eps=10 ** -12)","text":"

A definition to take 2D Non-Uniform Fast Fourier Transform (NUFFT).

Parameters:

  • field \u2013
          Input field.\n
  • fx \u2013
          Frequencies along x axis.\n
  • fy \u2013
          Frequencies along y axis.\n
  • size \u2013
          Size.\n
  • sign \u2013
          Sign of the exponential used in NUFFT kernel.\n
  • eps \u2013
          Accuracy of NUFFT.\n

Returns:

  • result ( ndarray ) \u2013

    Inverse NUFFT of the input field.

Source code in odak/tools/matrix.py
def nufft2(field, fx, fy, size=None, sign=1, eps=10**(-12)):\n    \"\"\"\n    A definition to take 2D Non-Uniform Fast Fourier Transform (NUFFT).\n\n    Parameters\n    ----------\n    field       : ndarray\n                  Input field.\n    fx          : ndarray\n                  Frequencies along x axis.\n    fy          : ndarray\n                  Frequencies along y axis.\n    size        : list\n                  Size.\n    sign        : float\n                  Sign of the exponential used in NUFFT kernel.\n    eps         : float\n                  Accuracy of NUFFT.\n\n    Returns\n    ----------\n    result      : ndarray\n                  Inverse NUFFT of the input field.\n    \"\"\"\n    try:\n        import finufft\n    except:\n        print('odak.tools.nufft2 requires finufft to be installed: pip install finufft')\n    image = np.copy(field).astype(np.complex128)\n    result = finufft.nufft2d2(\n        fx.flatten(), fy.flatten(), image, eps=eps, isign=sign)\n    if type(size) == type(None):\n        result = result.reshape(field.shape)\n    else:\n        result = result.reshape(size)\n    return result\n
"},{"location":"odak/wave/#odak.wave.nuifft2","title":"nuifft2(field, fx, fy, size=None, sign=1, eps=10 ** -12)","text":"

A definition to take 2D Adjoint Non-Uniform Fast Fourier Transform (NUFFT).

Parameters:

  • field \u2013
          Input field.\n
  • fx \u2013
          Frequencies along x axis.\n
  • fy \u2013
          Frequencies along y axis.\n
  • size \u2013
          Shape of the NUFFT calculated for an input field.\n
  • sign \u2013
          Sign of the exponential used in NUFFT kernel.\n
  • eps \u2013
          Accuracy of NUFFT.\n

Returns:

  • result ( ndarray ) \u2013

    NUFFT of the input field.

Source code in odak/tools/matrix.py
def nuifft2(field, fx, fy, size=None, sign=1, eps=10**(-12)):\n    \"\"\"\n    A definition to take 2D Adjoint Non-Uniform Fast Fourier Transform (NUFFT).\n\n    Parameters\n    ----------\n    field       : ndarray\n                  Input field.\n    fx          : ndarray\n                  Frequencies along x axis.\n    fy          : ndarray\n                  Frequencies along y axis.\n    size        : list or ndarray\n                  Shape of the NUFFT calculated for an input field.\n    sign        : float\n                  Sign of the exponential used in NUFFT kernel.\n    eps         : float\n                  Accuracy of NUFFT.\n\n    Returns\n    ----------\n    result      : ndarray\n                  NUFFT of the input field.\n    \"\"\"\n    try:\n        import finufft\n    except:\n        print('odak.tools.nuifft2 requires finufft to be installed: pip install finufft')\n    image = np.copy(field).astype(np.complex128)\n    if type(size) == type(None):\n        result = finufft.nufft2d1(\n            fx.flatten(),\n            fy.flatten(),\n            image.flatten(),\n            image.shape,\n            eps=eps,\n            isign=sign\n        )\n    else:\n        result = finufft.nufft2d1(\n            fx.flatten(),\n            fy.flatten(),\n            image.flatten(),\n            (size[0], size[1]),\n            eps=eps,\n            isign=sign\n        )\n    result = np.asarray(result)\n    return result\n
"},{"location":"odak/wave/#odak.wave.prism_phase_function","title":"prism_phase_function(nx, ny, k, angle, dx=0.001, axis='x')","text":"

A definition to generate a 2D phase function that represents a prism. See Goodman's Introduction to Fourier Optics for more.

Parameters:

  • nx \u2013
         Size of the output along X.\n
  • ny \u2013
         Size of the output along Y.\n
  • k \u2013
         See odak.wave.wavenumber for more.\n
  • angle \u2013
         Tilt angle of the prism in degrees.\n
  • dx \u2013
         Pixel pitch.\n
  • axis \u2013
         Axis of the prism.\n

Returns:

  • prism ( ndarray ) \u2013

    Generated phase function for a prism.

Source code in odak/wave/lens.py
def prism_phase_function(nx, ny, k, angle, dx=0.001, axis='x'):\n    \"\"\"\n    A definition to generate 2D phase function that represents a prism. See Goodman's Introduction to Fourier Optics book for more.\n\n    Parameters\n    ----------\n    nx         : int\n                 Size of the output along X.\n    ny         : int\n                 Size of the output along Y.\n    k          : odak.wave.wavenumber\n                 See odak.wave.wavenumber for more.\n    angle      : float\n                 Tilt angle of the prism in degrees.\n    dx         : float\n                 Pixel pitch.\n    axis       : str\n                 Axis of the prism.\n\n    Returns\n    -------\n    prism      : ndarray\n                 Generated phase function for a prism.\n    \"\"\"\n    angle = np.radians(angle)\n    size = [ny, nx]\n    x = np.linspace(-size[0]*dx/2, size[0]*dx/2, size[0])\n    y = np.linspace(-size[1]*dx/2, size[1]*dx/2, size[1])\n    X, Y = np.meshgrid(x, y)\n    if axis == 'y':\n        prism = np.exp(-1j*k*np.sin(angle)*Y)\n    elif axis == 'x':\n        prism = np.exp(-1j*k*np.sin(angle)*X)\n    return prism\n
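A minimal usage sketch with illustrative values; multiplying an existing field with the returned term tilts it along the chosen axis:

import numpy as np
from odak.wave import wavenumber, prism_phase_function

k = wavenumber(532e-9)                    # keep units consistent with dx
prism = prism_phase_function(512, 512, k, angle=0.1, dx=8e-6, axis='x')
# field = field * prism                   # apply to an existing 512x512 complex field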
"},{"location":"odak/wave/#odak.wave.produce_phase_only_slm_pattern","title":"produce_phase_only_slm_pattern(hologram, slm_range, filename=None, bits=8, default_range=6.28, illumination=None)","text":"

Definition for producing a pattern for a phase only Spatial Light Modulator (SLM) using a given field.

Parameters:

  • hologram \u2013
                 Input holographic field.\n
  • slm_range \u2013
                 Range of the phase only SLM in radians for a working wavelength (i.e. two pi). See odak.wave.adjust_phase_only_slm_range() for more.\n
  • filename \u2013
                  Optional variable; if provided, the pattern will be saved to the given location.\n
  • bits \u2013
                 Quantization bits.\n
  • default_range \u2013
                 Default range of phase only SLM.\n
  • illumination \u2013
                 Spatial illumination distribution.\n

Returns:

  • pattern ( complex64 ) \u2013

    Adjusted phase only pattern.

  • hologram_digital ( int ) \u2013

    Digital representation of the hologram.

Source code in odak/wave/__init__.py
def produce_phase_only_slm_pattern(hologram, slm_range, filename=None, bits=8, default_range=6.28, illumination=None):\n    \"\"\"\n    Definition for producing a pattern for a phase only Spatial Light Modulator (SLM) using a given field.\n\n    Parameters\n    ----------\n    hologram           : np.complex64\n                         Input holographic field.\n    slm_range          : float\n                         Range of the phase only SLM in radians for a working wavelength (i.e. two pi). See odak.wave.adjust_phase_only_slm_range() for more.\n    filename           : str\n                         Optional variable, if provided the patterns will be save to given location.\n    bits               : int\n                         Quantization bits.\n    default_range      : float \n                         Default range of phase only SLM.\n    illumination       : np.ndarray\n                         Spatial illumination distribution.\n\n    Returns\n    -------\n    pattern            : np.complex64\n                         Adjusted phase only pattern.\n    hologram_digital   : np.int\n                         Digital representation of the hologram.\n    \"\"\"\n    #hologram_phase   = calculate_phase(hologram) % default_range\n    hologram_phase = calculate_phase(hologram)\n    hologram_phase = hologram_phase % slm_range\n    hologram_phase /= slm_range\n    hologram_phase *= 2**bits\n    hologram_phase = hologram_phase.astype(np.int32)\n    hologram_digital = np.copy(hologram_phase)\n    if type(filename) != type(None):\n        save_image(\n            filename,\n            hologram_phase,\n            cmin=0,\n            cmax=2**bits\n        )\n    hologram_phase = hologram_phase.astype(np.float64)\n    hologram_phase *= slm_range/2**bits\n    if type(illumination) == type(None):\n        A = 1.\n    else:\n        A = illumination\n    return A*np.cos(hologram_phase)+A*1j*np.sin(hologram_phase), hologram_digital\n
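A minimal usage sketch; the random-phase hologram below is only a stand-in for a hologram computed elsewhere (e.g. with gerchberg_saxton):

import numpy as np
from odak.wave import produce_phase_only_slm_pattern

hologram = np.exp(1j * np.random.uniform(0, 2 * np.pi, (256, 256))).astype(np.complex64)
pattern, hologram_digital = produce_phase_only_slm_pattern(hologram, 2 * np.pi, bits=8)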
"},{"location":"odak/wave/#odak.wave.propagate_beam","title":"propagate_beam(field, k, distance, dx, wavelength, propagation_type='IR Fresnel')","text":"

Definitions for Fresnel Impulse Response (IR), Angular Spectrum (AS), Bandlimited Angular Spectrum (BAS), Fresnel Transfer Function (TF), and Fraunhofer diffraction in accordance with \"Computational Fourier Optics\" by David Voelz. For more on the bandlimited Fresnel impulse response, also known as the Bandlimited Angular Spectrum method, see \"Band-limited Angular Spectrum Method for Numerical Simulation of Free-Space Propagation in Far and Near Fields\".

Parameters:

  • field \u2013
               Complex field (MxN).\n
  • k \u2013
               Wave number of a wave, see odak.wave.wavenumber for more.\n
  • distance \u2013
               Propagation distance.\n
  • dx \u2013
               Size of one single pixel in the field grid (in meters).\n
  • wavelength \u2013
               Wavelength of the electric field.\n
  • propagation_type (str, default: 'IR Fresnel' ) \u2013
               Type of the propagation (IR Fresnel, Angular Spectrum, Bandlimited Angular Spectrum, TR Fresnel, Fraunhofer).\n

Returns:

  • result ( complex ) \u2013

    Final complex field (MxN).

Source code in odak/wave/classical.py
def propagate_beam(field, k, distance, dx, wavelength, propagation_type='IR Fresnel'):\n    \"\"\"\n    Definitions for Fresnel Impulse Response (IR), Angular Spectrum (AS), Bandlimited Angular Spectrum (BAS), Fresnel Transfer Function (TF), Fraunhofer diffraction in accordence with \"Computational Fourier Optics\" by David Vuelz. For more on Bandlimited Fresnel impulse response also known as Bandlimited Angular Spectrum method see \"Band-limited Angular Spectrum Method for Numerical Simulation of Free-Space Propagation in Far and Near Fields\".\n\n    Parameters\n    ----------\n    field            : np.complex\n                       Complex field (MxN).\n    k                : odak.wave.wavenumber\n                       Wave number of a wave, see odak.wave.wavenumber for more.\n    distance         : float\n                       Propagation distance.\n    dx               : float\n                       Size of one single pixel in the field grid (in meters).\n    wavelength       : float\n                       Wavelength of the electric field.\n    propagation_type : str\n                       Type of the propagation (IR Fresnel, Angular Spectrum, Bandlimited Angular Spectrum, TR Fresnel, Fraunhofer).\n\n    Returns\n    -------\n    result           : np.complex\n                       Final complex field (MxN).\n    \"\"\"\n    if propagation_type == 'Rayleigh-Sommerfeld':\n        result = rayleigh_sommerfeld(field, k, distance, dx, wavelength)\n    elif propagation_type == 'Angular Spectrum':\n        result = angular_spectrum(field, k, distance, dx, wavelength)\n    elif propagation_type == 'Impulse Response Fresnel':\n        result = impulse_response_fresnel(field, k, distance, dx, wavelength)\n    elif propagation_type == 'Bandlimited Angular Spectrum':\n        result = band_limited_angular_spectrum(\n            field, k, distance, dx, wavelength)\n    elif propagation_type == 'Bandextended Angular Spectrum':\n        result = band_extended_angular_spectrum(\n            field, k, distance, dx, wavelength)\n    elif propagation_type == 'Adaptive Sampling Angular Spectrum':\n        result = adaptive_sampling_angular_spectrum(\n            field, k, distance, dx, wavelength)\n    elif propagation_type == 'Transfer Function Fresnel':\n        result = transfer_function_fresnel(field, k, distance, dx, wavelength)\n    elif propagation_type == 'Fraunhofer':\n        result = fraunhofer(field, k, distance, dx, wavelength)\n    elif propagation_type == 'Fraunhofer Inverse':\n        result = fraunhofer_inverse(field, k, distance, dx, wavelength)\n    else:\n        raise Exception(\"Unknown propagation type selected.\")\n    return result\n
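A minimal usage sketch with illustrative values, selecting one of the propagation types listed above:

import numpy as np
from odak.wave import wavenumber, propagate_beam

wavelength = 532e-9
dx = 8e-6
k = wavenumber(wavelength)
field = np.zeros((512, 512), dtype=np.complex64)
field[236:276, 236:276] = 1.
result = propagate_beam(field, k, 0.1, dx, wavelength,
                        propagation_type='Bandlimited Angular Spectrum')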
"},{"location":"odak/wave/#odak.wave.propagate_field","title":"propagate_field(points0, points1, field0, wave_number, direction=1)","text":"

Definition to propagate a field from one set of points to another in space: propagate a given array of spherical sources to a given set of points in space.

Parameters:

  • points0 \u2013
            Start points (i.e. odak.tools.grid_sample).\n
  • points1 \u2013
             End points (i.e. odak.tools.grid_sample).\n
  • field0 \u2013
            Field for given starting points.\n
  • wave_number \u2013
            Wave number of a wave, see odak.wave.wavenumber for more.\n
  • direction \u2013
            For propagating in forward direction set as 1, otherwise -1.\n

Returns:

  • field1 ( ndarray ) \u2013

    Field for given end points.

Source code in odak/wave/vector.py
def propagate_field(points0, points1, field0, wave_number, direction=1):\n    \"\"\"\n    Definition to propagate a field from points to an another points in space: propagate a given array of spherical sources to given set of points in space.\n\n    Parameters\n    ----------\n    points0       : ndarray\n                    Start points (i.e. odak.tools.grid_sample).\n    points1       : ndarray\n                    End points (ie. odak.tools.grid_sample).\n    field0        : ndarray\n                    Field for given starting points.\n    wave_number   : float\n                    Wave number of a wave, see odak.wave.wavenumber for more.\n    direction     : float\n                    For propagating in forward direction set as 1, otherwise -1.\n\n    Returns\n    -------\n    field1        : ndarray\n                    Field for given end points.\n    \"\"\"\n    field1 = np.zeros(points1.shape[0], dtype=np.complex64)\n    for point_id in range(points0.shape[0]):\n        point = points0[point_id]\n        distances = distance_between_two_points(\n            point,\n            points1\n        )\n        field1 += electric_field_per_plane_wave(\n            calculate_amplitude(field0[point_id]),\n            distances*direction,\n            wave_number,\n            phase=calculate_phase(field0[point_id])\n        )\n    return field1\n
"},{"location":"odak/wave/#odak.wave.propagate_plane_waves","title":"propagate_plane_waves(field, opd, k, w=0, t=0)","text":"

Definition to propagate a field representing a plane wave at a particular distance and time.

Parameters:

  • field \u2013
           Complex field.\n
  • opd \u2013
           Optical path difference in mm.\n
  • k \u2013
           Wave number of a wave, see odak.wave.parameters.wavenumber for more.\n
  • w \u2013
           Rotation speed of a wave, see odak.wave.parameters.rotationspeed for more.\n
  • t \u2013
           Time in seconds.\n

Returns:

  • new_field ( complex ) \u2013

    A complex number that provides the resultant field in the complex form A*e^(j(wt+phi)).

Source code in odak/wave/vector.py
def propagate_plane_waves(field, opd, k, w=0, t=0):\n    \"\"\"\n    Definition to propagate a field representing a plane wave at a particular distance and time.\n\n    Parameters\n    ----------\n    field        : complex\n                   Complex field.\n    opd          : float\n                   Optical path difference in mm.\n    k            : float\n                   Wave number of a wave, see odak.wave.parameters.wavenumber for more.\n    w            : float\n                   Rotation speed of a wave, see odak.wave.parameters.rotationspeed for more.\n    t            : float\n                   Time in seconds.\n\n    Returns\n    -------\n    new_field     : complex\n                    A complex number that provides the resultant field in the complex form A*e^(j(wt+phi)).\n    \"\"\"\n    new_field = field*np.exp(1j*(-w*t+opd*k))/opd**2\n    return new_field\n
"},{"location":"odak/wave/#odak.wave.quadratic_phase_function","title":"quadratic_phase_function(nx, ny, k, focal=0.4, dx=0.001, offset=[0, 0])","text":"

A definition to generate a 2D quadratic phase function, which is typically used to represent lenses.

Parameters:

  • nx \u2013
         Size of the output along X.\n
  • ny \u2013
         Size of the output along Y.\n
  • k \u2013
         See odak.wave.wavenumber for more.\n
  • focal \u2013
         Focal length of the quadratic phase function.\n
  • dx \u2013
         Pixel pitch.\n
  • offset \u2013
         Deviation from the center along X and Y axes.\n

Returns:

  • function ( ndarray ) \u2013

    Generated quadratic phase function.

Source code in odak/wave/lens.py
def quadratic_phase_function(nx, ny, k, focal=0.4, dx=0.001, offset=[0, 0]):\n    \"\"\" \n    A definition to generate 2D quadratic phase function, which is typically use to represent lenses.\n\n    Parameters\n    ----------\n    nx         : int\n                 Size of the output along X.\n    ny         : int\n                 Size of the output along Y.\n    k          : odak.wave.wavenumber\n                 See odak.wave.wavenumber for more.\n    focal      : float\n                 Focal length of the quadratic phase function.\n    dx         : float\n                 Pixel pitch.\n    offset     : list\n                 Deviation from the center along X and Y axes.\n\n    Returns\n    -------\n    function   : ndarray\n                 Generated quadratic phase function.\n    \"\"\"\n    size = [nx, ny]\n    x = np.linspace(-size[0]*dx/2, size[0]*dx/2, size[0])-offset[1]*dx\n    y = np.linspace(-size[1]*dx/2, size[1]*dx/2, size[1])-offset[0]*dx\n    X, Y = np.meshgrid(x, y)\n    Z = X**2+Y**2\n    qwf = np.exp(1j*k*0.5*np.sin(Z/focal))\n    return qwf\n
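A minimal usage sketch with illustrative values; the returned term acts like a thin lens when multiplied with a field of the same shape:

import numpy as np
from odak.wave import wavenumber, quadratic_phase_function

k = wavenumber(532e-9)
lens = quadratic_phase_function(512, 512, k, focal=0.2, dx=8e-6)
# field = field * lens   # apply to an existing 512x512 complex field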
"},{"location":"odak/wave/#odak.wave.rayleigh_resolution","title":"rayleigh_resolution(diameter, focal=None, wavelength=0.0005)","text":"

Definition to calculate the Rayleigh resolution limit of a lens with a certain focal length and aperture. The lens is assumed to be focusing a plane wave at its focal distance.

Parameters:

  • diameter \u2013
           Diameter of a lens.\n
  • focal \u2013
           Focal length of a lens. When a focal length is provided, the spatial resolution at the focal plane is returned; when it isn't, the angular resolution is returned.\n
  • wavelength \u2013
           Wavelength of light.\n

Returns:

  • resolution ( float ) \u2013

    Resolvable angular or spatial spot size, see focal in parameters to know what to expect.

Source code in odak/wave/__init__.py
def rayleigh_resolution(diameter, focal=None, wavelength=0.0005):\n    \"\"\"\n    Definition to calculate rayleigh resolution limit of a lens with a certain focal length and an aperture. Lens is assumed to be focusing a plane wave at a focal distance.\n\n    Parameter\n    ---------\n    diameter    : float\n                  Diameter of a lens.\n    focal       : float\n                  Focal length of a lens, when focal length is provided, spatial resolution is provided at the focal plane. When focal length isn't provided angular resolution is provided.\n    wavelength  : float\n                  Wavelength of light.\n\n    Returns\n    --------\n    resolution  : float\n                  Resolvable angular or spatial spot size, see focal in parameters to know what to expect.\n\n    \"\"\"\n    resolution = 1.22*wavelength/diameter\n    if type(focal) != type(None):\n        resolution *= focal\n    return resolution\n
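A minimal usage sketch using the millimeter units implied by the default wavelength:

from odak.wave import rayleigh_resolution

# 10 mm aperture, 50 mm focal length, 0.0005 mm (500 nm) wavelength.
spot = rayleigh_resolution(10., focal=50., wavelength=0.0005)   # ~0.00305 mm at the focal plane
angle = rayleigh_resolution(10., wavelength=0.0005)             # ~6.1e-5 rad when no focal length is given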
"},{"location":"odak/wave/#odak.wave.rayleigh_sommerfeld","title":"rayleigh_sommerfeld(field, k, distance, dx, wavelength)","text":"

Definition to compute beam propagation using the Rayleigh-Sommerfeld diffraction formula (Huygens-Fresnel principle). For more, see Section 3.5.2 in Goodman, Joseph W. Introduction to Fourier Optics. Roberts and Company Publishers, 2005.

Parameters:

  • field \u2013
               Complex field (MxN).\n
  • k \u2013
               Wave number of a wave, see odak.wave.wavenumber for more.\n
  • distance \u2013
               Propagation distance.\n
  • dx \u2013
               Size of one single pixel in the field grid (in meters).\n
  • wavelength \u2013
               Wavelength of the electric field.\n

Returns:

  • result ( complex ) \u2013

    Final complex field (MxN).

Source code in odak/wave/classical.py
def rayleigh_sommerfeld(field, k, distance, dx, wavelength):\n    \"\"\"\n    Definition to compute beam propagation using Rayleigh-Sommerfeld's diffraction formula (Huygens-Fresnel Principle). For more see Section 3.5.2 in Goodman, Joseph W. Introduction to Fourier optics. Roberts and Company Publishers, 2005.\n\n    Parameters\n    ----------\n    field            : np.complex\n                       Complex field (MxN).\n    k                : odak.wave.wavenumber\n                       Wave number of a wave, see odak.wave.wavenumber for more.\n    distance         : float\n                       Propagation distance.\n    dx               : float\n                       Size of one single pixel in the field grid (in meters).\n    wavelength       : float\n                       Wavelength of the electric field.\n\n    Returns\n    -------\n    result           : np.complex\n                       Final complex field (MxN).\n    \"\"\"\n    nv, nu = field.shape\n    x = np.linspace(-nv * dx / 2, nv * dx / 2, nv)\n    y = np.linspace(-nu * dx / 2, nu * dx / 2, nu)\n    X, Y = np.meshgrid(x, y)\n    Z = X ** 2 + Y ** 2\n    result = np.zeros(field.shape, dtype=np.complex64)\n    direction = int(distance/np.abs(distance))\n    for i in range(nu):\n        for j in range(nv):\n            if field[i, j] != 0:\n                r01 = np.sqrt(distance ** 2 + (X - X[i, j]) ** 2 + (Y - Y[i, j]) ** 2) * direction\n                cosnr01 = np.cos(distance / r01)\n                result += field[i, j] * np.exp(1j * k * r01) / r01 * cosnr01\n    result *= 1. / (1j * wavelength)\n    return result\n
"},{"location":"odak/wave/#odak.wave.rotationspeed","title":"rotationspeed(wavelength, c=3 * 10 ** 11)","text":"

Definition for calculating the rotation speed (angular frequency) of a wave (w in A*e^(j(wt+phi))).

Parameters:

  • wavelength \u2013
           Wavelength of a wave in mm.\n
  • c \u2013
            Speed of the wave in mm/s. Default is the speed of light in vacuum.\n

Returns:

  • w ( float ) \u2013

    Rotation speed.

Source code in odak/wave/__init__.py
def rotationspeed(wavelength, c=3*10**11):\n    \"\"\"\n    Definition for calculating rotation speed of a wave (w in A*e^(j(wt+phi))).\n\n    Parameters\n    ----------\n    wavelength   : float\n                   Wavelength of a wave in mm.\n    c            : float\n                   Speed of wave in mm/seconds. Default is the speed of light in the void!\n\n    Returns\n    -------\n    w            : float\n                   Rotation speed.\n\n    \"\"\"\n    f = c*wavelength\n    w = 2*np.pi*f\n    return w\n
"},{"location":"odak/wave/#odak.wave.set_amplitude","title":"set_amplitude(field, amplitude)","text":"

Definition to keep phase as is and change the amplitude of a given field.

Parameters:

  • field \u2013
           Complex field.\n
  • amplitude \u2013
           Amplitudes.\n

Returns:

  • new_field ( complex64 ) \u2013

    Complex field.

Source code in odak/wave/__init__.py
def set_amplitude(field, amplitude):\n    \"\"\"\n    Definition to keep phase as is and change the amplitude of a given field.\n\n    Parameters\n    ----------\n    field        : np.complex64\n                   Complex field.\n    amplitude    : np.array or np.complex64\n                   Amplitudes.\n\n    Returns\n    -------\n    new_field    : np.complex64\n                   Complex field.\n    \"\"\"\n    amplitude = calculate_amplitude(amplitude)\n    phase = calculate_phase(field)\n    new_field = amplitude*np.cos(phase)+1j*amplitude*np.sin(phase)\n    return new_field\n
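A short sketch (an assumption of this edit, not original documentation: `set_amplitude` is importable from `odak.wave` and the amplitude/phase helpers from `odak.wave.utils`, as the entry locations indicate) showing that the amplitude is replaced while the phase is preserved.

```python
import numpy as np
from odak.wave import set_amplitude
from odak.wave.utils import calculate_amplitude, calculate_phase

# A random non-zero complex field and a flat unit target amplitude.
field = (0.5 + np.random.rand(4, 4)) * np.exp(1j * np.random.uniform(0., 2. * np.pi, (4, 4)))
target_amplitude = np.ones((4, 4))

new_field = set_amplitude(field, target_amplitude)
print(np.allclose(calculate_amplitude(new_field), 1.))                   # True: amplitude replaced
print(np.allclose(calculate_phase(new_field), calculate_phase(field)))   # True: phase kept
```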
"},{"location":"odak/wave/#odak.wave.transfer_function_fresnel","title":"transfer_function_fresnel(field, k, distance, dx, wavelength)","text":"

A definition to calculate the convolution-based (transfer function) Fresnel approximation for beam propagation.

Parameters:

  • field \u2013
               Complex field (MxN).\n
  • k \u2013
               Wave number of a wave, see odak.wave.wavenumber for more.\n
  • distance \u2013
               Propagation distance.\n
  • dx \u2013
               Size of one single pixel in the field grid (in meters).\n
  • wavelength \u2013
               Wavelength of the electric field.\n

Returns:

  • result ( complex ) \u2013

    Final complex field (MxN).

Source code in odak/wave/classical.py
def transfer_function_fresnel(field, k, distance, dx, wavelength):\n    \"\"\"\n    A definition to calculate convolution based Fresnel approximation for beam propagation.\n\n    Parameters\n    ----------\n    field            : np.complex\n                       Complex field (MxN).\n    k                : odak.wave.wavenumber\n                       Wave number of a wave, see odak.wave.wavenumber for more.\n    distance         : float\n                       Propagation distance.\n    dx               : float\n                       Size of one single pixel in the field grid (in meters).\n    wavelength       : float\n                       Wavelength of the electric field.\n\n    Returns\n    -------\n    result           : np.complex\n                       Final complex field (MxN).\n\n    \"\"\"\n    nv, nu = field.shape\n    fx = np.linspace(-1. / 2. /dx, 1. /2. /dx, nu)\n    fy = np.linspace(-1. / 2. /dx, 1. /2. /dx, nv)\n    FX, FY = np.meshgrid(fx, fy)\n    H = np.exp(1j * k * distance * (1 - (FX * wavelength) ** 2 - (FY * wavelength) ** 2) ** 0.5)\n    U1 = np.fft.fftshift(np.fft.fft2(np.fft.fftshift(field)))\n    U2 = H * U1\n    result = np.fft.ifftshift(np.fft.ifft2(np.fft.ifftshift(U2)))\n    return result\n
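A hedged usage sketch with illustrative SI values; it relies on `transfer_function_fresnel` and `wavenumber` being importable from `odak.wave` as documented above.

```python
import numpy as np
from odak.wave import wavenumber, transfer_function_fresnel

wavelength = 532e-9   # meters (illustrative)
dx = 6.4e-6           # pixel pitch in meters (illustrative)
distance = 0.1        # propagation distance in meters (illustrative)
k = wavenumber(wavelength)

# A centered square aperture as the input field.
field = np.zeros((256, 256), dtype=np.complex64)
field[112:144, 112:144] = 1.

result = transfer_function_fresnel(field, k, distance, dx, wavelength)
intensity = np.abs(result) ** 2   # diffraction pattern at the target plane
print(intensity.shape)            # (256, 256)
```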
"},{"location":"odak/wave/#odak.wave.wavenumber","title":"wavenumber(wavelength)","text":"

Definition for calculating the wavenumber of a plane wave.

Parameters:

  • wavelength \u2013
           Wavelength of a wave in mm.\n

Returns:

  • k ( float ) \u2013

    Wave number for a given wavelength.

Source code in odak/wave/__init__.py
def wavenumber(wavelength):\n    \"\"\"\n    Definition for calculating the wavenumber of a plane wave.\n\n    Parameters\n    ----------\n    wavelength   : float\n                   Wavelength of a wave in mm.\n\n    Returns\n    -------\n    k            : float\n                   Wave number for a given wavelength.\n    \"\"\"\n    k = 2*np.pi/wavelength\n    return k\n
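A one-line sanity check (not from the original documentation) that the returned value matches k = 2*pi/wavelength; the millimeter unit follows the docstring above.

```python
import numpy as np
from odak.wave import wavenumber

wavelength = 0.0005                            # 500 nm in mm, per the docstring's unit
k = wavenumber(wavelength)
print(np.isclose(k, 2 * np.pi / wavelength))   # True; k is roughly 12566 rad/mm
```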
"},{"location":"odak/wave/#odak.wave.zero_pad","title":"zero_pad(field, size=None, method='center')","text":"

Definition to zero pad an MxN array to a 2Mx2N array.

Parameters:

  • field \u2013
                Input field MxN array.\n
  • size \u2013
                 Size of the zero-padded output.\n
  • method \u2013
                 Zero pad either by placing the content at the center or by aligning it to the left.\n

Returns:

  • field_zero_padded ( ndarray ) \u2013

    Zeropadded version of the input field.

Source code in odak/tools/matrix.py
def zero_pad(field, size=None, method='center'):\n    \"\"\"\n    Definition to zero pad a MxN array to 2Mx2N array.\n\n    Parameters\n    ----------\n    field             : ndarray\n                        Input field MxN array.\n    size              : list\n                        Size to be zeropadded.\n    method            : str\n                        Zeropad either by placing the content to center or to the left.\n\n    Returns\n    ----------\n    field_zero_padded : ndarray\n                        Zeropadded version of the input field.\n    \"\"\"\n    if type(size) == type(None):\n        hx = int(np.ceil(field.shape[0])/2)\n        hy = int(np.ceil(field.shape[1])/2)\n    else:\n        hx = int(np.ceil((size[0]-field.shape[0])/2))\n        hy = int(np.ceil((size[1]-field.shape[1])/2))\n    if method == 'center':\n        field_zero_padded = np.pad(\n            field, ([hx, hx], [hy, hy]), constant_values=(0, 0))\n    elif method == 'left aligned':\n        field_zero_padded = np.pad(\n            field, ([0, 2*hx], [0, 2*hy]), constant_values=(0, 0))\n    if type(size) != type(None):\n        field_zero_padded = field_zero_padded[0:size[0], 0:size[1]]\n    return field_zero_padded\n
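A small sketch of the two padding modes (assuming `zero_pad` is importable from `odak.wave` as the entry location suggests, even though the source lives in odak/tools/matrix.py).

```python
import numpy as np
from odak.wave import zero_pad

field = np.ones((4, 4))

padded = zero_pad(field)                                        # 4x4 -> 8x8, content centered
custom = zero_pad(field, size=[10, 10], method='left aligned')  # pad to an explicit output size

print(padded.shape, custom.shape)   # (8, 8) (10, 10)
print(padded[2:6, 2:6].sum())       # 16.0: the original content sits in the middle
```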
"},{"location":"odak/wave/#odak.wave.classical.adaptive_sampling_angular_spectrum","title":"adaptive_sampling_angular_spectrum(field, k, distance, dx, wavelength)","text":"

A definition to calculate adaptive sampling angular spectrum based beam propagation. For more, see Zhang, Wenhui, Hao Zhang, and Guofan Jin. \"Adaptive-sampling angular spectrum method with full utilization of space-bandwidth product.\" Optics Letters 45.16 (2020): 4416-4419.

Parameters:

  • field \u2013
               Complex field (MxN).\n
  • k \u2013
               Wave number of a wave, see odak.wave.wavenumber for more.\n
  • distance \u2013
               Propagation distance.\n
  • dx \u2013
               Size of one single pixel in the field grid (in meters).\n
  • wavelength \u2013
               Wavelength of the electric field.\n

Returns:

  • result ( complex ) \u2013

    Final complex field (MxN).

Source code in odak/wave/classical.py
def adaptive_sampling_angular_spectrum(field, k, distance, dx, wavelength):\n    \"\"\"\n    A definition to calculate adaptive sampling angular spectrum based beam propagation. For more Zhang, Wenhui, Hao Zhang, and Guofan Jin. \"Adaptive-sampling angular spectrum method with full utilization of space-bandwidth product.\" Optics Letters 45.16 (2020): 4416-4419.\n\n    Parameters\n    ----------\n    field            : np.complex\n                       Complex field (MxN).\n    k                : odak.wave.wavenumber\n                       Wave number of a wave, see odak.wave.wavenumber for more.\n    distance         : float\n                       Propagation distance.\n    dx               : float\n                       Size of one single pixel in the field grid (in meters).\n    wavelength       : float\n                       Wavelength of the electric field.\n\n    Returns\n    -------\n    result           : np.complex\n                       Final complex field (MxN).\n    \"\"\"\n    iflag = -1\n    eps = 10**(-12)\n    nv, nu = field.shape\n    l = nu*dx\n    x = np.linspace(-l/2, l/2, nu)\n    y = np.linspace(-l/2, l/2, nv)\n    X, Y = np.meshgrid(x, y)\n    fx = np.linspace(-1./2./dx, 1./2./dx, nu)\n    fy = np.linspace(-1./2./dx, 1./2./dx, nv)\n    FX, FY = np.meshgrid(fx, fy)\n    forig = 1./2./dx\n    fc2 = 1./2*(nu/wavelength/np.abs(distance))**0.5\n    ss = np.abs(fc2)/forig\n    zc = nu*dx**2/wavelength\n    K = nu/2/np.amax(np.abs(fx))\n    m = 2\n    nnu2 = m*nu\n    nnv2 = m*nv\n    fxn = np.linspace(-1./2./dx, 1./2./dx, nnu2)\n    fyn = np.linspace(-1./2./dx, 1./2./dx, nnv2)\n    if np.abs(distance) > zc*2:\n        fxn = fxn*ss\n        fyn = fyn*ss\n    FXN, FYN = np.meshgrid(fxn, fyn)\n    Hn = np.exp(1j*k*distance*(1-(FXN*wavelength)**2-(FYN*wavelength)**2)**0.5)\n    FX = FXN/np.amax(FXN)*np.pi\n    FY = FYN/np.amax(FYN)*np.pi\n    t_2 = nufft2(field, FX*ss, FY*ss, size=[nnv2, nnu2], sign=iflag, eps=eps)\n    FX = FX/np.amax(FX)*np.pi\n    FY = FY/np.amax(FY)*np.pi\n    result = nuifft2(Hn*t_2, FX*ss, FY*ss, size=[nv, nu], sign=-iflag, eps=eps)\n    return result\n
"},{"location":"odak/wave/#odak.wave.classical.angular_spectrum","title":"angular_spectrum(field, k, distance, dx, wavelength)","text":"

A definition to calculate angular spectrum based beam propagation.

Parameters:

  • field \u2013
               Complex field (MxN).\n
  • k \u2013
               Wave number of a wave, see odak.wave.wavenumber for more.\n
  • distance \u2013
               Propagation distance.\n
  • dx \u2013
               Size of one single pixel in the field grid (in meters).\n
  • wavelength \u2013
               Wavelength of the electric field.\n

Returns:

  • result ( complex ) \u2013

    Final complex field (MxN).

Source code in odak/wave/classical.py
def angular_spectrum(field, k, distance, dx, wavelength):\n    \"\"\"\n    A definition to calculate angular spectrum based beam propagation.\n\n    Parameters\n    ----------\n    field            : np.complex\n                       Complex field (MxN).\n    k                : odak.wave.wavenumber\n                       Wave number of a wave, see odak.wave.wavenumber for more.\n    distance         : float\n                       Propagation distance.\n    dx               : float\n                       Size of one single pixel in the field grid (in meters).\n    wavelength       : float\n                       Wavelength of the electric field.\n\n    Returns\n    -------\n    result           : np.complex\n                       Final complex field (MxN).\n    \"\"\"\n    nv, nu = field.shape\n    x = np.linspace(-nu/2*dx, nu/2*dx, nu)\n    y = np.linspace(-nv/2*dx, nv/2*dx, nv)\n    X, Y = np.meshgrid(x, y)\n    Z = X**2+Y**2\n    h = 1./(1j*wavelength*distance)*np.exp(1j*k*(distance+Z/2/distance))\n    h = np.fft.fft2(np.fft.fftshift(h))*dx**2\n    U1 = np.fft.fft2(np.fft.fftshift(field))\n    U2 = h*U1\n    result = np.fft.ifftshift(np.fft.ifft2(U2))\n    return result\n
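A hedged usage sketch sharing the same call signature as the other classical propagators; wavelength, pixel pitch, and distance are illustrative SI values, and the import paths follow the entry locations above.

```python
import numpy as np
from odak.wave import wavenumber
from odak.wave.classical import angular_spectrum

wavelength = 532e-9   # meters (illustrative)
dx = 8e-6             # pixel pitch in meters (illustrative)
distance = 5e-2       # propagation distance in meters (illustrative)
k = wavenumber(wavelength)

# A small aperture on a square grid.
field = np.zeros((128, 128), dtype=np.complex64)
field[60:68, 60:68] = 1.

result = angular_spectrum(field, k, distance, dx, wavelength)
print(result.shape, np.abs(result).max())   # propagated complex field and its peak amplitude
```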
"},{"location":"odak/wave/#odak.wave.classical.band_extended_angular_spectrum","title":"band_extended_angular_spectrum(field, k, distance, dx, wavelength)","text":"

A definition to calculate band-extended angular spectrum based beam propagation. For more, see Zhang, Wenhui, Hao Zhang, and Guofan Jin. \"Band-extended angular spectrum method for accurate diffraction calculation in a wide propagation range.\" Optics Letters 45.6 (2020): 1543-1546.

Parameters:

  • field \u2013
               Complex field (MxN).\n
  • k \u2013
               Wave number of a wave, see odak.wave.wavenumber for more.\n
  • distance \u2013
               Propagation distance.\n
  • dx \u2013
               Size of one single pixel in the field grid (in meters).\n
  • wavelength \u2013
               Wavelength of the electric field.\n

Returns:

  • result ( complex ) \u2013

    Final complex field (MxN).

Source code in odak/wave/classical.py
def band_extended_angular_spectrum(field, k, distance, dx, wavelength):\n    \"\"\"\n    A definition to calculate bandextended angular spectrum based beam propagation. For more Zhang, Wenhui, Hao Zhang, and Guofan Jin. \"Band-extended angular spectrum method for accurate diffraction calculation in a wide propagation range.\" Optics Letters 45.6 (2020): 1543-1546.\n\n    Parameters\n    ----------\n    field            : np.complex\n                       Complex field (MxN).\n    k                : odak.wave.wavenumber\n                       Wave number of a wave, see odak.wave.wavenumber for more.\n    distance         : float\n                       Propagation distance.\n    dx               : float\n                       Size of one single pixel in the field grid (in meters).\n    wavelength       : float\n                       Wavelength of the electric field.\n\n    Returns\n    -------\n    result           : np.complex\n                       Final complex field (MxN).\n    \"\"\"\n    iflag = -1\n    eps = 10**(-12)\n    nv, nu = field.shape\n    l = nu*dx\n    x = np.linspace(-l/2, l/2, nu)\n    y = np.linspace(-l/2, l/2, nv)\n    X, Y = np.meshgrid(x, y)\n    Z = X**2+Y**2\n    fx = np.linspace(-1./2./dx, 1./2./dx, nu)\n    fy = np.linspace(-1./2./dx, 1./2./dx, nv)\n    FX, FY = np.meshgrid(fx, fy)\n    K = nu/2/np.amax(fx)\n    fcn = 1./2*(nu/wavelength/np.abs(distance))**0.5\n    ss = np.abs(fcn)/np.amax(np.abs(fx))\n    zc = nu*dx**2/wavelength\n    if np.abs(distance) < zc:\n        fxn = fx\n        fyn = fy\n    else:\n        fxn = fx*ss\n        fyn = fy*ss\n    FXN, FYN = np.meshgrid(fxn, fyn)\n    Hn = np.exp(1j*k*distance*(1-(FXN*wavelength)**2-(FYN*wavelength)**2)**0.5)\n    X = X/np.amax(X)*np.pi\n    Y = Y/np.amax(Y)*np.pi\n    t_asmNUFT = nufft2(field, X*ss, Y*ss, sign=iflag, eps=eps)\n    result = nuifft2(Hn*t_asmNUFT, X*ss, Y*ss, sign=-iflag, eps=eps)\n    return result\n
"},{"location":"odak/wave/#odak.wave.classical.band_limited_angular_spectrum","title":"band_limited_angular_spectrum(field, k, distance, dx, wavelength)","text":"

A definition to calculate band-limited angular spectrum based beam propagation. For more, see Matsushima, Kyoji, and Tomoyoshi Shimobaba. \"Band-limited angular spectrum method for numerical simulation of free-space propagation in far and near fields.\" Optics Express 17.22 (2009): 19662-19673.

Parameters:

  • field \u2013
               Complex field (MxN).\n
  • k \u2013
               Wave number of a wave, see odak.wave.wavenumber for more.\n
  • distance \u2013
               Propagation distance.\n
  • dx \u2013
               Size of one single pixel in the field grid (in meters).\n
  • wavelength \u2013
               Wavelength of the electric field.\n

Returns:

  • result ( complex ) \u2013

    Final complex field (MxN).

Source code in odak/wave/classical.py
def band_limited_angular_spectrum(field, k, distance, dx, wavelength):\n    \"\"\"\n    A definition to calculate bandlimited angular spectrum based beam propagation. For more Matsushima, Kyoji, and Tomoyoshi Shimobaba. \"Band-limited angular spectrum method for numerical simulation of free-space propagation in far and near fields.\" Optics express 17.22 (2009): 19662-19673.\n\n    Parameters\n    ----------\n    field            : np.complex\n                       Complex field (MxN).\n    k                : odak.wave.wavenumber\n                       Wave number of a wave, see odak.wave.wavenumber for more.\n    distance         : float\n                       Propagation distance.\n    dx               : float\n                       Size of one single pixel in the field grid (in meters).\n    wavelength       : float\n                       Wavelength of the electric field.\n\n    Returns\n    -------\n    result           : np.complex\n                       Final complex field (MxN).\n    \"\"\"\n    nv, nu = field.shape\n    x = np.linspace(-nu/2*dx, nu/2*dx, nu)\n    y = np.linspace(-nv/2*dx, nv/2*dx, nv)\n    X, Y = np.meshgrid(x, y)\n    Z = X**2+Y**2\n    h = 1./(1j*wavelength*distance)*np.exp(1j*k*(distance+Z/2/distance))\n    h = np.fft.fft2(np.fft.fftshift(h))*dx**2\n    flimx = np.ceil(1/(((2*distance*(1./(nu)))**2+1)**0.5*wavelength))\n    flimy = np.ceil(1/(((2*distance*(1./(nv)))**2+1)**0.5*wavelength))\n    mask = np.zeros((nu, nv), dtype=np.complex64)\n    mask = (np.abs(X) < flimx) & (np.abs(Y) < flimy)\n    mask = set_amplitude(h, mask)\n    U1 = np.fft.fft2(np.fft.fftshift(field))\n    U2 = mask*U1\n    result = np.fft.ifftshift(np.fft.ifft2(U2))\n    return result\n
"},{"location":"odak/wave/#odak.wave.classical.fraunhofer","title":"fraunhofer(field, k, distance, dx, wavelength)","text":"

A definition to calculate Fraunhofer based beam propagation.

Parameters:

  • field \u2013
               Complex field (MxN).\n
  • k \u2013
               Wave number of a wave, see odak.wave.wavenumber for more.\n
  • distance \u2013
               Propagation distance.\n
  • dx \u2013
               Size of one single pixel in the field grid (in meters).\n
  • wavelength \u2013
               Wavelength of the electric field.\n

Returns:

  • result ( complex ) \u2013

    Final complex field (MxN).

Source code in odak/wave/classical.py
def fraunhofer(field, k, distance, dx, wavelength):\n    \"\"\"\n    A definition to calculate Fraunhofer based beam propagation.\n\n    Parameters\n    ----------\n    field            : np.complex\n                       Complex field (MxN).\n    k                : odak.wave.wavenumber\n                       Wave number of a wave, see odak.wave.wavenumber for more.\n    distance         : float\n                       Propagation distance.\n    dx               : float\n                       Size of one single pixel in the field grid (in meters).\n    wavelength       : float\n                       Wavelength of the electric field.\n\n    Returns\n    -------\n    result           : np.complex\n                       Final complex field (MxN).\n    \"\"\"\n    nv, nu = field.shape\n    l = nu*dx\n    l2 = wavelength*distance/dx\n    dx2 = wavelength*distance/l\n    fx = np.linspace(-l2/2., l2/2., nu)\n    fy = np.linspace(-l2/2., l2/2., nv)\n    FX, FY = np.meshgrid(fx, fy)\n    FZ = FX**2+FY**2\n    c = np.exp(1j*k*distance)/(1j*wavelength*distance) * \\\n        np.exp(1j*k/(2*distance)*FZ)\n    result = c*np.fft.ifftshift(np.fft.fft2(np.fft.fftshift(field)))*dx**2\n    return result\n
"},{"location":"odak/wave/#odak.wave.classical.fraunhofer_equal_size_adjust","title":"fraunhofer_equal_size_adjust(field, distance, dx, wavelength)","text":"

A definition to match the physical size of the original field with the propagated field.

Parameters:

  • field \u2013
               Complex field (MxN).\n
  • distance \u2013
               Propagation distance.\n
  • dx \u2013
               Size of one single pixel in the field grid (in meters).\n
  • wavelength \u2013
               Wavelength of the electric field.\n

Returns:

  • new_field ( complex ) \u2013

    Final complex field (MxN).

Source code in odak/wave/classical.py
def fraunhofer_equal_size_adjust(field, distance, dx, wavelength):\n    \"\"\"\n    A definition to match the physical size of the original field with the propagated field.\n\n    Parameters\n    ----------\n    field            : np.complex\n                       Complex field (MxN).\n    distance         : float\n                       Propagation distance.\n    dx               : float\n                       Size of one single pixel in the field grid (in meters).\n    wavelength       : float\n                       Wavelength of the electric field.\n\n    Returns\n    -------\n    new_field        : np.complex\n                       Final complex field (MxN).\n    \"\"\"\n    nv, nu = field.shape\n    l1 = nu*dx\n    l2 = wavelength*distance/dx\n    m = l1/l2\n    px = int(m*nu)\n    py = int(m*nv)\n    nx = int(field.shape[0]/2-px/2)\n    ny = int(field.shape[1]/2-py/2)\n    new_field = np.copy(field[nx:nx+px, ny:ny+py])\n    return new_field\n
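A sketch (with illustrative values, not from the original documentation) of how this crop pairs with fraunhofer so that the returned window corresponds to the physical extent of the input field; imports follow the entry locations above.

```python
import numpy as np
from odak.wave import wavenumber
from odak.wave.classical import fraunhofer, fraunhofer_equal_size_adjust

wavelength = 532e-9   # meters (illustrative)
dx = 8e-6             # pixel pitch in meters (illustrative)
distance = 2.0        # a distance deep in the Fraunhofer regime (illustrative)
k = wavenumber(wavelength)

field = np.zeros((512, 512), dtype=np.complex64)
field[248:264, 248:264] = 1.

far_field = fraunhofer(field, k, distance, dx, wavelength)
matched = fraunhofer_equal_size_adjust(far_field, distance, dx, wavelength)
print(far_field.shape, matched.shape)   # the second array is cropped to the input's physical size
```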
"},{"location":"odak/wave/#odak.wave.classical.fraunhofer_inverse","title":"fraunhofer_inverse(field, k, distance, dx, wavelength)","text":"

A definition to calculate Inverse Fraunhofer based beam propagation.

Parameters:

  • field \u2013
               Complex field (MxN).\n
  • k \u2013
               Wave number of a wave, see odak.wave.wavenumber for more.\n
  • distance \u2013
               Propagation distance.\n
  • dx \u2013
               Size of one single pixel in the field grid (in meters).\n
  • wavelength \u2013
               Wavelength of the electric field.\n

Returns:

  • result ( complex ) \u2013

    Final complex field (MxN).

Source code in odak/wave/classical.py
def fraunhofer_inverse(field, k, distance, dx, wavelength):\n    \"\"\"\n    A definition to calculate Inverse Fraunhofer based beam propagation.\n\n    Parameters\n    ----------\n    field            : np.complex\n                       Complex field (MxN).\n    k                : odak.wave.wavenumber\n                       Wave number of a wave, see odak.wave.wavenumber for more.\n    distance         : float\n                       Propagation distance.\n    dx               : float\n                       Size of one single pixel in the field grid (in meters).\n    wavelength       : float\n                       Wavelength of the electric field.\n\n    Returns\n    -------\n    result           : np.complex\n                       Final complex field (MxN).\n    \"\"\"\n    distance = np.abs(distance)\n    nv, nu = field.shape\n    l = nu*dx\n    l2 = wavelength*distance/dx\n    dx2 = wavelength*distance/l\n    fx = np.linspace(-l2/2., l2/2., nu)\n    fy = np.linspace(-l2/2., l2/2., nv)\n    FX, FY = np.meshgrid(fx, fy)\n    FZ = FX**2+FY**2\n    c = np.exp(1j*k*distance)/(1j*wavelength*distance) * \\\n        np.exp(1j*k/(2*distance)*FZ)\n    result = np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(field/dx**2/c)))\n    return result\n
"},{"location":"odak/wave/#odak.wave.classical.gerchberg_saxton","title":"gerchberg_saxton(field, n_iterations, distance, dx, wavelength, slm_range=6.28, propagation_type='IR Fresnel', initial_phase=None)","text":"

Definition to compute a hologram using an iterative method called Gerchberg-Saxton phase retrieval algorithm. For more on the method, see: Gerchberg, Ralph W. \"A practical algorithm for the determination of phase from image and diffraction plane pictures.\" Optik 35 (1972): 237-246.

Parameters:

  • field \u2013
               Complex field (MxN).\n
  • distance \u2013
               Propagation distance.\n
  • dx \u2013
               Size of one single pixel in the field grid (in meters).\n
  • wavelength \u2013
               Wavelength of the electric field.\n
  • slm_range \u2013
               Typically this is equal to two pi. See odak.wave.adjust_phase_only_slm_range() for more.\n
  • propagation_type (str, default: 'IR Fresnel' ) \u2013
               Type of the propagation (IR Fresnel, TR Fresnel, Fraunhofer).\n
  • initial_phase \u2013
               Phase to be added to the initial value.\n

Returns:

  • hologram ( complex ) \u2013

    Calculated complex hologram.

  • reconstruction ( complex ) \u2013

    Calculated reconstruction using calculated hologram.

Source code in odak/wave/classical.py
def gerchberg_saxton(field, n_iterations, distance, dx, wavelength, slm_range=6.28, propagation_type='IR Fresnel', initial_phase=None):\n    \"\"\"\n    Definition to compute a hologram using an iterative method called Gerchberg-Saxton phase retrieval algorithm. For more on the method, see: Gerchberg, Ralph W. \"A practical algorithm for the determination of phase from image and diffraction plane pictures.\" Optik 35 (1972): 237-246.\n\n    Parameters\n    ----------\n    field            : np.complex64\n                       Complex field (MxN).\n    distance         : float\n                       Propagation distance.\n    dx               : float\n                       Size of one single pixel in the field grid (in meters).\n    wavelength       : float\n                       Wavelength of the electric field.\n    slm_range        : float\n                       Typically this is equal to two pi. See odak.wave.adjust_phase_only_slm_range() for more.\n    propagation_type : str\n                       Type of the propagation (IR Fresnel, TR Fresnel, Fraunhofer).\n    initial_phase    : np.complex64\n                       Phase to be added to the initial value.\n\n    Returns\n    -------\n    hologram         : np.complex\n                       Calculated complex hologram.\n    reconstruction   : np.complex\n                       Calculated reconstruction using calculated hologram. \n    \"\"\"\n    k = wavenumber(wavelength)\n    target = calculate_amplitude(field)\n    hologram = generate_complex_field(np.ones(field.shape), 0)\n    hologram = zero_pad(hologram)\n    if type(initial_phase) == type(None):\n        hologram = add_random_phase(hologram)\n    else:\n        initial_phase = zero_pad(initial_phase)\n        hologram = add_phase(hologram, initial_phase)\n    center = [int(hologram.shape[0]/2.), int(hologram.shape[1]/2.)]\n    orig_shape = [int(field.shape[0]/2.), int(field.shape[1]/2.)]\n    for i in tqdm(range(n_iterations), leave=False):\n        reconstruction = propagate_beam(\n            hologram, k, distance, dx, wavelength, propagation_type)\n        new_target = calculate_amplitude(reconstruction)\n        new_target[\n            center[0]-orig_shape[0]:center[0]+orig_shape[0],\n            center[1]-orig_shape[1]:center[1]+orig_shape[1]\n        ] = target\n        reconstruction = generate_complex_field(\n            new_target, calculate_phase(reconstruction))\n        hologram = propagate_beam(\n            reconstruction, k, -distance, dx, wavelength, propagation_type)\n        hologram = generate_complex_field(1, calculate_phase(hologram))\n        hologram = hologram[\n            center[0]-orig_shape[0]:center[0]+orig_shape[0],\n            center[1]-orig_shape[1]:center[1]+orig_shape[1],\n        ]\n        hologram = zero_pad(hologram)\n    reconstruction = propagate_beam(\n        hologram, k, distance, dx, wavelength, propagation_type)\n    hologram = hologram[\n        center[0]-orig_shape[0]:center[0]+orig_shape[0],\n        center[1]-orig_shape[1]:center[1]+orig_shape[1]\n    ]\n    reconstruction = reconstruction[\n        center[0]-orig_shape[0]:center[0]+orig_shape[0],\n        center[1]-orig_shape[1]:center[1]+orig_shape[1]\n    ]\n    return hologram, reconstruction\n
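A minimal phase-retrieval sketch, not from the original documentation. The target is a synthetic amplitude image, the numerical values are illustrative, and the propagation_type string is one of the branches listed in propagate_beam's source further below (the documented default 'IR Fresnel' does not appear among those branches, so a type is passed explicitly here).

```python
import numpy as np
from odak.wave.classical import gerchberg_saxton
from odak.wave.utils import calculate_amplitude

wavelength = 515e-9   # meters (illustrative)
dx = 8e-6             # pixel pitch in meters (illustrative)
distance = 0.15       # propagation distance in meters (illustrative)
n_iterations = 5      # keep small for a quick test

# Synthetic target: a bright square on a dark background, treated as the target amplitude.
target = np.zeros((128, 128), dtype=np.complex64)
target[48:80, 48:80] = 1.

hologram, reconstruction = gerchberg_saxton(
    target, n_iterations, distance, dx, wavelength,
    propagation_type='Transfer Function Fresnel'
)
print(hologram.shape, calculate_amplitude(reconstruction).max())
```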
"},{"location":"odak/wave/#odak.wave.classical.gerchberg_saxton_3d","title":"gerchberg_saxton_3d(fields, n_iterations, distances, dx, wavelength, slm_range=6.28, propagation_type='IR Fresnel', initial_phase=None, target_type='no constraint', coefficients=None)","text":"

Definition to compute a multi plane hologram using an iterative method called Gerchberg-Saxton phase retrieval algorithm. For more on the method, see: Zhou, Pengcheng, et al. \"30.4: Multi\u2010plane holographic display with a uniform 3D Gerchberg\u2010Saxton algorithm.\" SID Symposium Digest of Technical Papers. Vol. 46. No. 1. 2015.

Parameters:

  • fields \u2013
               Complex fields (MxN).\n
  • distances \u2013
               Propagation distances.\n
  • dx \u2013
               Size of one single pixel in the field grid (in meters).\n
  • wavelength \u2013
               Wavelength of the electric field.\n
  • slm_range \u2013
               Typically this is equal to two pi. See odak.wave.adjust_phase_only_slm_range() for more.\n
  • propagation_type (str, default: 'IR Fresnel' ) \u2013
               Type of the propagation (IR Fresnel, TR Fresnel, Fraunhofer).\n
  • initial_phase \u2013
               Phase to be added to the initial value.\n
  • target_type \u2013
               Target type. `No constraint` targets the input target as is. `Double constraint` follows the idea in this paper, which claims to suppress speckle: Chang, Chenliang, et al. \"Speckle-suppressed phase-only holographic three-dimensional display based on double-constraint Gerchberg\u2013Saxton algorithm.\" Applied optics 54.23 (2015): 6994-7001.\n

Returns:

  • hologram ( complex ) \u2013

    Calculated complex hologram.

Source code in odak/wave/classical.py
def gerchberg_saxton_3d(fields, n_iterations, distances, dx, wavelength, slm_range=6.28, propagation_type='IR Fresnel', initial_phase=None, target_type='no constraint', coefficients=None):\n    \"\"\"\n    Definition to compute a multi plane hologram using an iterative method called Gerchberg-Saxton phase retrieval algorithm. For more on the method, see: Zhou, Pengcheng, et al. \"30.4: Multi\u2010plane holographic display with a uniform 3D Gerchberg\u2010Saxton algorithm.\" SID Symposium Digest of Technical Papers. Vol. 46. No. 1. 2015.\n\n    Parameters\n    ----------\n    fields           : np.complex64\n                       Complex fields (MxN).\n    distances        : list\n                       Propagation distances.\n    dx               : float\n                       Size of one single pixel in the field grid (in meters).\n    wavelength       : float\n                       Wavelength of the electric field.\n    slm_range        : float\n                       Typically this is equal to two pi. See odak.wave.adjust_phase_only_slm_range() for more.\n    propagation_type : str\n                       Type of the propagation (IR Fresnel, TR Fresnel, Fraunhofer).\n    initial_phase    : np.complex64\n                       Phase to be added to the initial value.\n    target_type      : str\n                       Target type. `No constraint` targets the input target as is. `Double constraint` follows the idea in this paper, which claims to suppress speckle: Chang, Chenliang, et al. \"Speckle-suppressed phase-only holographic three-dimensional display based on double-constraint Gerchberg\u2013Saxton algorithm.\" Applied optics 54.23 (2015): 6994-7001. \n\n    Returns\n    -------\n    hologram         : np.complex\n                       Calculated complex hologram.\n    \"\"\"\n    k = wavenumber(wavelength)\n    targets = calculate_amplitude(np.asarray(fields)).astype(np.float64)\n    hologram = generate_complex_field(np.ones(targets[0].shape), 0)\n    hologram = zero_pad(hologram)\n    if type(initial_phase) == type(None):\n        hologram = add_random_phase(hologram)\n    else:\n        initial_phase = zero_pad(initial_phase)\n        hologram = add_phase(hologram, initial_phase)\n    center = [int(hologram.shape[0]/2.), int(hologram.shape[1]/2.)]\n    orig_shape = [int(fields[0].shape[0]/2.), int(fields[0].shape[1]/2.)]\n    holograms = np.zeros(\n        (len(distances), hologram.shape[0], hologram.shape[1]), dtype=np.complex64)\n    for i in tqdm(range(n_iterations), leave=False):\n        for distance_id in tqdm(range(len(distances)), leave=False):\n            distance = distances[distance_id]\n            reconstruction = propagate_beam(\n                hologram, k, distance, dx, wavelength, propagation_type)\n            if target_type == 'double constraint':\n                if type(coefficients) == type(None):\n                    raise Exception(\n                        \"Provide coeeficients of alpha,beta and gamma for double constraint.\")\n                alpha = coefficients[0]\n                beta = coefficients[1]\n                gamma = coefficients[2]\n                target_current = 2*alpha * \\\n                    np.copy(targets[distance_id])-beta * \\\n                    calculate_amplitude(reconstruction)\n                target_current[target_current == 0] = gamma * \\\n                    np.abs(reconstruction[target_current == 0])\n            elif target_type == 'no constraint':\n                target_current = np.abs(targets[distance_id])\n  
          new_target = calculate_amplitude(reconstruction)\n            new_target[\n                center[0]-orig_shape[0]:center[0]+orig_shape[0],\n                center[1]-orig_shape[1]:center[1]+orig_shape[1]\n            ] = target_current\n            reconstruction = generate_complex_field(\n                new_target, calculate_phase(reconstruction))\n            hologram_layer = propagate_beam(\n                reconstruction, k, -distance, dx, wavelength, propagation_type)\n            hologram_layer = generate_complex_field(\n                1., calculate_phase(hologram_layer))\n            hologram_layer = hologram_layer[\n                center[0]-orig_shape[0]:center[0]+orig_shape[0],\n                center[1]-orig_shape[1]:center[1]+orig_shape[1]\n            ]\n            hologram_layer = zero_pad(hologram_layer)\n            holograms[distance_id] = hologram_layer\n        hologram = np.sum(holograms, axis=0)\n    hologram = hologram[\n        center[0]-orig_shape[0]:center[0]+orig_shape[0],\n        center[1]-orig_shape[1]:center[1]+orig_shape[1]\n    ]\n    return hologram\n
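A hedged multi-plane sketch along the same lines, with one synthetic amplitude target per propagation distance; all numerical values are illustrative and the propagation_type string again follows the branches listed in propagate_beam's source below.

```python
import numpy as np
from odak.wave.classical import gerchberg_saxton_3d

wavelength = 515e-9        # meters (illustrative)
dx = 8e-6                  # pixel pitch in meters (illustrative)
distances = [0.10, 0.15]   # one propagation distance per target plane (illustrative)
n_iterations = 3

# Two synthetic amplitude targets, one per plane.
target_0 = np.zeros((128, 128), dtype=np.complex64)
target_0[40:60, 40:60] = 1.
target_1 = np.zeros((128, 128), dtype=np.complex64)
target_1[70:90, 70:90] = 1.

hologram = gerchberg_saxton_3d(
    [target_0, target_1], n_iterations, distances, dx, wavelength,
    propagation_type='Transfer Function Fresnel'
)
print(hologram.shape)      # (128, 128) complex multi-plane hologram
```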
"},{"location":"odak/wave/#odak.wave.classical.impulse_response_fresnel","title":"impulse_response_fresnel(field, k, distance, dx, wavelength)","text":"

A definition to calculate impulse response based Fresnel approximation for beam propagation.

Parameters:

  • field \u2013
               Complex field (MxN).\n
  • k \u2013
               Wave number of a wave, see odak.wave.wavenumber for more.\n
  • distance \u2013
               Propagation distance.\n
  • dx \u2013
               Size of one single pixel in the field grid (in meters).\n
  • wavelength \u2013
               Wavelength of the electric field.\n

Returns:

  • result ( complex ) \u2013

    Final complex field (MxN).

Source code in odak/wave/classical.py
def impulse_response_fresnel(field, k, distance, dx, wavelength):\n    \"\"\"\n    A definition to calculate impulse response based Fresnel approximation for beam propagation.\n\n    Parameters\n    ----------\n    field            : np.complex\n                       Complex field (MxN).\n    k                : odak.wave.wavenumber\n                       Wave number of a wave, see odak.wave.wavenumber for more.\n    distance         : float\n                       Propagation distance.\n    dx               : float\n                       Size of one single pixel in the field grid (in meters).\n    wavelength       : float\n                       Wavelength of the electric field.\n\n    Returns\n    -------\n    result           : np.complex\n                       Final complex field (MxN).\n\n    \"\"\"\n    nv, nu = field.shape\n    x = np.linspace(-nu / 2 * dx, nu / 2 * dx, nu)\n    y = np.linspace(-nv / 2 * dx, nv / 2 * dx, nv)\n    X, Y = np.meshgrid(x, y)\n    h = 1. / (1j * wavelength * distance) * np.exp(1j * k / (2 * distance) * (X ** 2 + Y ** 2))\n    H = np.fft.fft2(np.fft.fftshift(h))\n    U1 = np.fft.fft2(np.fft.fftshift(field))\n    U2 = H * U1\n    result = np.fft.ifftshift(np.fft.ifft2(U2))\n    result = np.roll(result, shift = (1, 1), axis = (0, 1))\n    return result\n
"},{"location":"odak/wave/#odak.wave.classical.propagate_beam","title":"propagate_beam(field, k, distance, dx, wavelength, propagation_type='IR Fresnel')","text":"

Definitions for Fresnel Impulse Response (IR), Angular Spectrum (AS), Bandlimited Angular Spectrum (BAS), Fresnel Transfer Function (TF), and Fraunhofer diffraction in accordance with \"Computational Fourier Optics\" by David Voelz. For more on the bandlimited Fresnel impulse response, also known as the Bandlimited Angular Spectrum method, see \"Band-limited Angular Spectrum Method for Numerical Simulation of Free-Space Propagation in Far and Near Fields\".

Parameters:

  • field \u2013
               Complex field (MxN).\n
  • k \u2013
               Wave number of a wave, see odak.wave.wavenumber for more.\n
  • distance \u2013
               Propagation distance.\n
  • dx \u2013
               Size of one single pixel in the field grid (in meters).\n
  • wavelength \u2013
               Wavelength of the electric field.\n
  • propagation_type (str, default: 'IR Fresnel' ) \u2013
               Type of the propagation (IR Fresnel, Angular Spectrum, Bandlimited Angular Spectrum, TR Fresnel, Fraunhofer).\n

Returns:

  • result ( complex ) \u2013

    Final complex field (MxN).

Source code in odak/wave/classical.py
def propagate_beam(field, k, distance, dx, wavelength, propagation_type='IR Fresnel'):\n    \"\"\"\n    Definitions for Fresnel Impulse Response (IR), Angular Spectrum (AS), Bandlimited Angular Spectrum (BAS), Fresnel Transfer Function (TF), Fraunhofer diffraction in accordence with \"Computational Fourier Optics\" by David Vuelz. For more on Bandlimited Fresnel impulse response also known as Bandlimited Angular Spectrum method see \"Band-limited Angular Spectrum Method for Numerical Simulation of Free-Space Propagation in Far and Near Fields\".\n\n    Parameters\n    ----------\n    field            : np.complex\n                       Complex field (MxN).\n    k                : odak.wave.wavenumber\n                       Wave number of a wave, see odak.wave.wavenumber for more.\n    distance         : float\n                       Propagation distance.\n    dx               : float\n                       Size of one single pixel in the field grid (in meters).\n    wavelength       : float\n                       Wavelength of the electric field.\n    propagation_type : str\n                       Type of the propagation (IR Fresnel, Angular Spectrum, Bandlimited Angular Spectrum, TR Fresnel, Fraunhofer).\n\n    Returns\n    -------\n    result           : np.complex\n                       Final complex field (MxN).\n    \"\"\"\n    if propagation_type == 'Rayleigh-Sommerfeld':\n        result = rayleigh_sommerfeld(field, k, distance, dx, wavelength)\n    elif propagation_type == 'Angular Spectrum':\n        result = angular_spectrum(field, k, distance, dx, wavelength)\n    elif propagation_type == 'Impulse Response Fresnel':\n        result = impulse_response_fresnel(field, k, distance, dx, wavelength)\n    elif propagation_type == 'Bandlimited Angular Spectrum':\n        result = band_limited_angular_spectrum(\n            field, k, distance, dx, wavelength)\n    elif propagation_type == 'Bandextended Angular Spectrum':\n        result = band_extended_angular_spectrum(\n            field, k, distance, dx, wavelength)\n    elif propagation_type == 'Adaptive Sampling Angular Spectrum':\n        result = adaptive_sampling_angular_spectrum(\n            field, k, distance, dx, wavelength)\n    elif propagation_type == 'Transfer Function Fresnel':\n        result = transfer_function_fresnel(field, k, distance, dx, wavelength)\n    elif propagation_type == 'Fraunhofer':\n        result = fraunhofer(field, k, distance, dx, wavelength)\n    elif propagation_type == 'Fraunhofer Inverse':\n        result = fraunhofer_inverse(field, k, distance, dx, wavelength)\n    else:\n        raise Exception(\"Unknown propagation type selected.\")\n    return result\n
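A short dispatch sketch, not from the original documentation. It loops over a few of the propagation_type strings that the source above actually accepts; note that the documented default 'IR Fresnel' is not one of the listed branches, so a type is always passed explicitly here. All numerical values are illustrative.

```python
import numpy as np
from odak.wave import wavenumber
from odak.wave.classical import propagate_beam

wavelength = 532e-9   # meters (illustrative)
dx = 8e-6             # pixel pitch in meters (illustrative)
distance = 0.1        # propagation distance in meters (illustrative)
k = wavenumber(wavelength)

field = np.zeros((256, 256), dtype=np.complex64)
field[120:136, 120:136] = 1.

# Any of the branches listed in the source above can be selected by name.
for method in ['Angular Spectrum', 'Bandlimited Angular Spectrum', 'Transfer Function Fresnel']:
    result = propagate_beam(field, k, distance, dx, wavelength, propagation_type=method)
    print(method, np.abs(result).max())
```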
"},{"location":"odak/wave/#odak.wave.classical.rayleigh_sommerfeld","title":"rayleigh_sommerfeld(field, k, distance, dx, wavelength)","text":"

Definition to compute beam propagation using Rayleigh-Sommerfeld's diffraction formula (Huygens-Fresnel Principle). For more see Section 3.5.2 in Goodman, Joseph W. Introduction to Fourier optics. Roberts and Company Publishers, 2005.

Parameters:

  • field \u2013
               Complex field (MxN).\n
  • k \u2013
               Wave number of a wave, see odak.wave.wavenumber for more.\n
  • distance \u2013
               Propagation distance.\n
  • dx \u2013
               Size of one single pixel in the field grid (in meters).\n
  • wavelength \u2013
               Wavelength of the electric field.\n

Returns:

  • result ( complex ) \u2013

    Final complex field (MxN).

Source code in odak/wave/classical.py
def rayleigh_sommerfeld(field, k, distance, dx, wavelength):\n    \"\"\"\n    Definition to compute beam propagation using Rayleigh-Sommerfeld's diffraction formula (Huygens-Fresnel Principle). For more see Section 3.5.2 in Goodman, Joseph W. Introduction to Fourier optics. Roberts and Company Publishers, 2005.\n\n    Parameters\n    ----------\n    field            : np.complex\n                       Complex field (MxN).\n    k                : odak.wave.wavenumber\n                       Wave number of a wave, see odak.wave.wavenumber for more.\n    distance         : float\n                       Propagation distance.\n    dx               : float\n                       Size of one single pixel in the field grid (in meters).\n    wavelength       : float\n                       Wavelength of the electric field.\n\n    Returns\n    -------\n    result           : np.complex\n                       Final complex field (MxN).\n    \"\"\"\n    nv, nu = field.shape\n    x = np.linspace(-nv * dx / 2, nv * dx / 2, nv)\n    y = np.linspace(-nu * dx / 2, nu * dx / 2, nu)\n    X, Y = np.meshgrid(x, y)\n    Z = X ** 2 + Y ** 2\n    result = np.zeros(field.shape, dtype=np.complex64)\n    direction = int(distance/np.abs(distance))\n    for i in range(nu):\n        for j in range(nv):\n            if field[i, j] != 0:\n                r01 = np.sqrt(distance ** 2 + (X - X[i, j]) ** 2 + (Y - Y[i, j]) ** 2) * direction\n                cosnr01 = np.cos(distance / r01)\n                result += field[i, j] * np.exp(1j * k * r01) / r01 * cosnr01\n    result *= 1. / (1j * wavelength)\n    return result\n
"},{"location":"odak/wave/#odak.wave.classical.transfer_function_fresnel","title":"transfer_function_fresnel(field, k, distance, dx, wavelength)","text":"

A definition to calculate the convolution-based (transfer function) Fresnel approximation for beam propagation.

Parameters:

  • field \u2013
               Complex field (MxN).\n
  • k \u2013
               Wave number of a wave, see odak.wave.wavenumber for more.\n
  • distance \u2013
               Propagation distance.\n
  • dx \u2013
               Size of one single pixel in the field grid (in meters).\n
  • wavelength \u2013
               Wavelength of the electric field.\n

Returns:

  • result ( complex ) \u2013

    Final complex field (MxN).

Source code in odak/wave/classical.py
def transfer_function_fresnel(field, k, distance, dx, wavelength):\n    \"\"\"\n    A definition to calculate convolution based Fresnel approximation for beam propagation.\n\n    Parameters\n    ----------\n    field            : np.complex\n                       Complex field (MxN).\n    k                : odak.wave.wavenumber\n                       Wave number of a wave, see odak.wave.wavenumber for more.\n    distance         : float\n                       Propagation distance.\n    dx               : float\n                       Size of one single pixel in the field grid (in meters).\n    wavelength       : float\n                       Wavelength of the electric field.\n\n    Returns\n    -------\n    result           : np.complex\n                       Final complex field (MxN).\n\n    \"\"\"\n    nv, nu = field.shape\n    fx = np.linspace(-1. / 2. /dx, 1. /2. /dx, nu)\n    fy = np.linspace(-1. / 2. /dx, 1. /2. /dx, nv)\n    FX, FY = np.meshgrid(fx, fy)\n    H = np.exp(1j * k * distance * (1 - (FX * wavelength) ** 2 - (FY * wavelength) ** 2) ** 0.5)\n    U1 = np.fft.fftshift(np.fft.fft2(np.fft.fftshift(field)))\n    U2 = H * U1\n    result = np.fft.ifftshift(np.fft.ifft2(np.fft.ifftshift(U2)))\n    return result\n
"},{"location":"odak/wave/#odak.wave.lens.double_convergence","title":"double_convergence(nx, ny, k, r, dx)","text":"

A definition to generate initial phase for a Gerchberg-Saxton method. For more details consult Sun, Peng, et al. \"Holographic near-eye display system based on double-convergence light Gerchberg-Saxton algorithm.\" Optics express 26.8 (2018): 10140-10151.

Parameters:

  • nx \u2013
         Size of the output along X.\n
  • ny \u2013
         Size of the output along Y.\n
  • k \u2013
         See odak.wave.wavenumber for more.\n
  • r \u2013
         The distance between location of a light source and an image plane.\n
  • dx \u2013
         Pixel pitch.\n

Returns:

  • function ( ndarray ) \u2013

    Generated phase pattern for a Gerchberg-Saxton method.

Source code in odak/wave/lens.py
def double_convergence(nx, ny, k, r, dx):\n    \"\"\"\n    A definition to generate initial phase for a Gerchberg-Saxton method. For more details consult Sun, Peng, et al. \"Holographic near-eye display system based on double-convergence light Gerchberg-Saxton algorithm.\" Optics express 26.8 (2018): 10140-10151.\n\n    Parameters\n    ----------\n    nx         : int\n                 Size of the output along X.\n    ny         : int\n                 Size of the output along Y.\n    k          : odak.wave.wavenumber\n                 See odak.wave.wavenumber for more.\n    r          : float\n                 The distance between location of a light source and an image plane.\n    dx         : float\n                 Pixel pitch.\n\n    Returns\n    -------\n    function   : ndarray\n                 Generated phase pattern for a Gerchberg-Saxton method.\n    \"\"\"\n    size = [ny, nx]\n    x = np.linspace(-size[0]*dx/2, size[0]*dx/2, size[0])\n    y = np.linspace(-size[1]*dx/2, size[1]*dx/2, size[1])\n    X, Y = np.meshgrid(x, y)\n    Z = X**2+Y**2\n    w = np.exp(1j*k*Z/r)\n    return w\n
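A small sketch generating such an initial phase with illustrative millimeter values; per the docstring, the result is meant to seed the Gerchberg-Saxton routines documented above through their initial_phase argument.

```python
import numpy as np
from odak.wave import wavenumber
from odak.wave.lens import double_convergence

wavelength = 0.000515   # 515 nm in mm (illustrative)
dx = 0.008              # 8 um pixel pitch in mm (illustrative)
r = 300.                # source-to-image-plane distance in mm (illustrative)
k = wavenumber(wavelength)

initial_phase = double_convergence(128, 128, k, r, dx)
print(initial_phase.shape, np.abs(initial_phase[0, 0]))   # (128, 128), unit-amplitude phase pattern
```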
"},{"location":"odak/wave/#odak.wave.lens.linear_grating","title":"linear_grating(nx, ny, every=2, add=3.14, axis='x')","text":"

A definition to generate a linear grating.

Parameters:

  • nx \u2013
         Size of the output along X.\n
  • ny \u2013
         Size of the output along Y.\n
  • every \u2013
          Add the `add` phase value once every `every` pixels.\n
  • add \u2013
         Angle to be added.\n
  • axis \u2013
          Axis, either x, y, or both (xy).\n

Returns:

  • field ( ndarray ) \u2013

    Linear grating term.

Source code in odak/wave/lens.py
def linear_grating(nx, ny, every=2, add=3.14, axis='x'):\n    \"\"\"\n    A definition to generate a linear grating.\n\n    Parameters\n    ----------\n    nx         : int\n                 Size of the output along X.\n    ny         : int\n                 Size of the output along Y.\n    every      : int\n                 Add the add value at every given number.\n    add        : float\n                 Angle to be added.\n    axis       : string\n                 Axis eiter X,Y or both.\n\n    Returns\n    -------\n    field      : ndarray\n                 Linear grating term.\n    \"\"\"\n    grating = np.zeros((nx, ny), dtype=np.complex64)\n    if axis == 'x':\n        grating[::every, :] = np.exp(1j*add)\n    if axis == 'y':\n        grating[:, ::every] = np.exp(1j*add)\n    if axis == 'xy':\n        checker = np.indices((nx, ny)).sum(axis=0) % every\n        checker += 1\n        checker = checker % 2\n        grating = np.exp(1j*checker*add)\n    return grating\n
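A brief sketch of the three axis modes (grid size, period, and phase step are illustrative):

```python
import numpy as np
from odak.wave.lens import linear_grating

grating_x = linear_grating(256, 256, every=2, add=np.pi, axis='x')   # pi added on every other row
grating_y = linear_grating(256, 256, every=2, add=np.pi, axis='y')   # pi added on every other column
checker = linear_grating(256, 256, every=2, add=np.pi, axis='xy')    # checkerboard of 0 and pi

print(np.angle(grating_x[0, 0]), np.angle(grating_x[1, 0]))          # pi on set rows, 0 elsewhere
print(np.angle(checker[0, 0]), np.angle(checker[0, 1]))              # phases alternate across the checkerboard
```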
"},{"location":"odak/wave/#odak.wave.lens.prism_phase_function","title":"prism_phase_function(nx, ny, k, angle, dx=0.001, axis='x')","text":"

A definition to generate a 2D phase function that represents a prism. See Goodman's Introduction to Fourier Optics book for more.

Parameters:

  • nx \u2013
         Size of the output along X.\n
  • ny \u2013
         Size of the output along Y.\n
  • k \u2013
         See odak.wave.wavenumber for more.\n
  • angle \u2013
         Tilt angle of the prism in degrees.\n
  • dx \u2013
         Pixel pitch.\n
  • axis \u2013
         Axis of the prism.\n

Returns:

  • prism ( ndarray ) \u2013

    Generated phase function for a prism.

Source code in odak/wave/lens.py
def prism_phase_function(nx, ny, k, angle, dx=0.001, axis='x'):\n    \"\"\"\n    A definition to generate 2D phase function that represents a prism. See Goodman's Introduction to Fourier Optics book for more.\n\n    Parameters\n    ----------\n    nx         : int\n                 Size of the output along X.\n    ny         : int\n                 Size of the output along Y.\n    k          : odak.wave.wavenumber\n                 See odak.wave.wavenumber for more.\n    angle      : float\n                 Tilt angle of the prism in degrees.\n    dx         : float\n                 Pixel pitch.\n    axis       : str\n                 Axis of the prism.\n\n    Returns\n    -------\n    prism      : ndarray\n                 Generated phase function for a prism.\n    \"\"\"\n    angle = np.radians(angle)\n    size = [ny, nx]\n    x = np.linspace(-size[0]*dx/2, size[0]*dx/2, size[0])\n    y = np.linspace(-size[1]*dx/2, size[1]*dx/2, size[1])\n    X, Y = np.meshgrid(x, y)\n    if axis == 'y':\n        prism = np.exp(-1j*k*np.sin(angle)*Y)\n    elif axis == 'x':\n        prism = np.exp(-1j*k*np.sin(angle)*X)\n    return prism\n
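A hedged sketch with illustrative millimeter values (the import path follows the entry location above):

```python
import numpy as np
from odak.wave import wavenumber
from odak.wave.lens import prism_phase_function

wavelength = 0.000532   # 532 nm in mm (illustrative)
k = wavenumber(wavelength)

# A 1-degree tilt along the x axis with a 1 um pixel pitch (dx given in mm).
prism = prism_phase_function(512, 512, k, angle=1., dx=0.001, axis='x')
print(prism.shape, np.abs(prism[0, 0]))   # (512, 512), unit amplitude everywhere
```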
"},{"location":"odak/wave/#odak.wave.lens.quadratic_phase_function","title":"quadratic_phase_function(nx, ny, k, focal=0.4, dx=0.001, offset=[0, 0])","text":"

A definition to generate a 2D quadratic phase function, which is typically used to represent lenses.

Parameters:

  • nx \u2013
         Size of the output along X.\n
  • ny \u2013
         Size of the output along Y.\n
  • k \u2013
         See odak.wave.wavenumber for more.\n
  • focal \u2013
         Focal length of the quadratic phase function.\n
  • dx \u2013
         Pixel pitch.\n
  • offset \u2013
         Deviation from the center along X and Y axes.\n

Returns:

  • function ( ndarray ) \u2013

    Generated quadratic phase function.

Source code in odak/wave/lens.py
def quadratic_phase_function(nx, ny, k, focal=0.4, dx=0.001, offset=[0, 0]):\n    \"\"\" \n    A definition to generate 2D quadratic phase function, which is typically use to represent lenses.\n\n    Parameters\n    ----------\n    nx         : int\n                 Size of the output along X.\n    ny         : int\n                 Size of the output along Y.\n    k          : odak.wave.wavenumber\n                 See odak.wave.wavenumber for more.\n    focal      : float\n                 Focal length of the quadratic phase function.\n    dx         : float\n                 Pixel pitch.\n    offset     : list\n                 Deviation from the center along X and Y axes.\n\n    Returns\n    -------\n    function   : ndarray\n                 Generated quadratic phase function.\n    \"\"\"\n    size = [nx, ny]\n    x = np.linspace(-size[0]*dx/2, size[0]*dx/2, size[0])-offset[1]*dx\n    y = np.linspace(-size[1]*dx/2, size[1]*dx/2, size[1])-offset[0]*dx\n    X, Y = np.meshgrid(x, y)\n    Z = X**2+Y**2\n    qwf = np.exp(1j*k*0.5*np.sin(Z/focal))\n    return qwf\n
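A hedged sketch with illustrative millimeter values, generating a lens-like phase pattern:

```python
import numpy as np
from odak.wave import wavenumber
from odak.wave.lens import quadratic_phase_function

wavelength = 0.000515   # 515 nm in mm (illustrative)
k = wavenumber(wavelength)

# A lens-like phase with a 400 mm focal length and 8 um pixels (all lengths in mm).
lens_phase = quadratic_phase_function(512, 512, k, focal=400., dx=0.008)
print(lens_phase.shape, np.abs(lens_phase).max())   # (512, 512), unit amplitude
```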
"},{"location":"odak/wave/#odak.wave.utils.calculate_amplitude","title":"calculate_amplitude(field)","text":"

Definition to calculate amplitude of a single or multiple given electric field(s).

Parameters:

  • field \u2013
           Electric fields or an electric field.\n

Returns:

  • amplitude ( float ) \u2013

    Amplitude or amplitudes of electric field(s).

Source code in odak/wave/utils.py
def calculate_amplitude(field):\n    \"\"\" \n    Definition to calculate amplitude of a single or multiple given electric field(s).\n\n    Parameters\n    ----------\n    field        : ndarray.complex or complex\n                   Electric fields or an electric field.\n\n    Returns\n    -------\n    amplitude    : float\n                   Amplitude or amplitudes of electric field(s).\n    \"\"\"\n    amplitude = np.abs(field)\n    return amplitude\n
"},{"location":"odak/wave/#odak.wave.utils.calculate_phase","title":"calculate_phase(field, deg=False)","text":"

Definition to calculate phase of a single or multiple given electric field(s).

Parameters:

  • field \u2013
           Electric fields or an electric field.\n
  • deg \u2013
           If set True, the angles will be returned in degrees.\n

Returns:

  • phase ( float ) \u2013

    Phase or phases of electric field(s) in radians.

Source code in odak/wave/utils.py
def calculate_phase(field, deg=False):\n    \"\"\" \n    Definition to calculate phase of a single or multiple given electric field(s).\n\n    Parameters\n    ----------\n    field        : ndarray.complex or complex\n                   Electric fields or an electric field.\n    deg          : bool\n                   If set True, the angles will be returned in degrees.\n\n    Returns\n    -------\n    phase        : float\n                   Phase or phases of electric field(s) in radians.\n    \"\"\"\n    phase = np.angle(field)\n    if deg == True:\n        phase *= 180./np.pi\n    return phase\n
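A tiny sketch exercising the two utilities above on a single complex sample:

```python
import numpy as np
from odak.wave.utils import calculate_amplitude, calculate_phase

field = 3. * np.exp(1j * np.pi / 4.)      # a single complex sample

print(calculate_amplitude(field))         # 3.0
print(calculate_phase(field))             # about 0.785 radians
print(calculate_phase(field, deg=True))   # 45.0 degrees
```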
"},{"location":"odak/wave/#odak.wave.vector.electric_field_per_plane_wave","title":"electric_field_per_plane_wave(amplitude, opd, k, phase=0, w=0, t=0)","text":"

Definition to return state of a plane wave at a particular distance and time.

Parameters:

  • amplitude \u2013
           Amplitude of a wave.\n
  • opd \u2013
           Optical path difference in mm.\n
  • k \u2013
           Wave number of a wave, see odak.wave.parameters.wavenumber for more.\n
  • phase \u2013
           Initial phase of a wave.\n
  • w \u2013
           Rotation speed of a wave, see odak.wave.parameters.rotationspeed for more.\n
  • t \u2013
           Time in seconds.\n

Returns:

  • field ( complex ) \u2013

    A complex number that provides the resultant field in the complex form A*e^(j(wt+phi)).

Source code in odak/wave/vector.py
def electric_field_per_plane_wave(amplitude, opd, k, phase=0, w=0, t=0):\n    \"\"\"\n    Definition to return state of a plane wave at a particular distance and time.\n\n    Parameters\n    ----------\n    amplitude    : float\n                   Amplitude of a wave.\n    opd          : float\n                   Optical path difference in mm.\n    k            : float\n                   Wave number of a wave, see odak.wave.parameters.wavenumber for more.\n    phase        : float\n                   Initial phase of a wave.\n    w            : float\n                   Rotation speed of a wave, see odak.wave.parameters.rotationspeed for more.\n    t            : float\n                   Time in seconds.\n\n    Returns\n    -------\n    field        : complex\n                   A complex number that provides the resultant field in the complex form A*e^(j(wt+phi)).\n    \"\"\"\n    field = amplitude*np.exp(1j*(-w*t+opd*k+phase))/opd**2\n    return field\n
"},{"location":"odak/wave/#odak.wave.vector.propagate_field","title":"propagate_field(points0, points1, field0, wave_number, direction=1)","text":"

Definition to propagate a field from one set of points to another set of points in space: propagate a given array of spherical sources to a given set of points in space.

Parameters:

  • points0 \u2013
             Start points (e.g., odak.tools.grid_sample).\n
  • points1 \u2013
             End points (e.g., odak.tools.grid_sample).\n
  • field0 \u2013
            Field for given starting points.\n
  • wave_number \u2013
            Wave number of a wave, see odak.wave.wavenumber for more.\n
  • direction \u2013
            For propagating in forward direction set as 1, otherwise -1.\n

Returns:

  • field1 ( ndarray ) \u2013

    Field for given end points.

Source code in odak/wave/vector.py
def propagate_field(points0, points1, field0, wave_number, direction=1):\n    \"\"\"\n    Definition to propagate a field from points to an another points in space: propagate a given array of spherical sources to given set of points in space.\n\n    Parameters\n    ----------\n    points0       : ndarray\n                    Start points (i.e. odak.tools.grid_sample).\n    points1       : ndarray\n                    End points (ie. odak.tools.grid_sample).\n    field0        : ndarray\n                    Field for given starting points.\n    wave_number   : float\n                    Wave number of a wave, see odak.wave.wavenumber for more.\n    direction     : float\n                    For propagating in forward direction set as 1, otherwise -1.\n\n    Returns\n    -------\n    field1        : ndarray\n                    Field for given end points.\n    \"\"\"\n    field1 = np.zeros(points1.shape[0], dtype=np.complex64)\n    for point_id in range(points0.shape[0]):\n        point = points0[point_id]\n        distances = distance_between_two_points(\n            point,\n            points1\n        )\n        field1 += electric_field_per_plane_wave(\n            calculate_amplitude(field0[point_id]),\n            distances*direction,\n            wave_number,\n            phase=calculate_phase(field0[point_id])\n        )\n    return field1\n
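A hedged sketch that propagates three point sources to a small grid of target points; it assumes points are given as Nx3 arrays (as odak.tools.grid_sample would produce) and uses illustrative millimeter units throughout.

```python
import numpy as np
from odak.wave import wavenumber
from odak.wave.vector import propagate_field

wavelength = 0.000515   # 515 nm in mm (illustrative)
k = wavenumber(wavelength)

# Three point sources on the z = 0 plane and a 5x5 grid of target points at z = 10 mm.
points0 = np.array([[-1., 0., 0.], [0., 0., 0.], [1., 0., 0.]])
xs, ys = np.meshgrid(np.linspace(-2., 2., 5), np.linspace(-2., 2., 5))
points1 = np.stack([xs.ravel(), ys.ravel(), np.full(xs.size, 10.)], axis=1)

field0 = np.ones(points0.shape[0], dtype=np.complex64)   # unit-amplitude sources
field1 = propagate_field(points0, points1, field0, k, direction=1)
print(field1.shape)                                      # (25,) complex field at the target points
```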
"},{"location":"odak/wave/#odak.wave.vector.propagate_plane_waves","title":"propagate_plane_waves(field, opd, k, w=0, t=0)","text":"

Definition to propagate a field representing a plane wave at a particular distance and time.

Parameters:

  • field \u2013
           Complex field.\n
  • opd \u2013
           Optical path difference in mm.\n
  • k \u2013
           Wave number of a wave, see odak.wave.parameters.wavenumber for more.\n
  • w \u2013
           Rotation speed of a wave, see odak.wave.parameters.rotationspeed for more.\n
  • t \u2013
           Time in seconds.\n

Returns:

  • new_field ( complex ) \u2013

    A complex number that provides the resultant field in the complex form A*e^(j(wt+phi)).

Source code in odak/wave/vector.py
def propagate_plane_waves(field, opd, k, w=0, t=0):\n    \"\"\"\n    Definition to propagate a field representing a plane wave at a particular distance and time.\n\n    Parameters\n    ----------\n    field        : complex\n                   Complex field.\n    opd          : float\n                   Optical path difference in mm.\n    k            : float\n                   Wave number of a wave, see odak.wave.parameters.wavenumber for more.\n    w            : float\n                   Rotation speed of a wave, see odak.wave.parameters.rotationspeed for more.\n    t            : float\n                   Time in seconds.\n\n    Returns\n    -------\n    new_field     : complex\n                    A complex number that provides the resultant field in the complex form A*e^(j(wt+phi)).\n    \"\"\"\n    new_field = field*np.exp(1j*(-w*t+opd*k))/opd**2\n    return new_field\n
"}]} \ No newline at end of file diff --git a/sitemap.xml b/sitemap.xml new file mode 100644 index 00000000..0bae990d --- /dev/null +++ b/sitemap.xml @@ -0,0 +1,131 @@ + + + + https://kaanaksit.github.io/odak/ + 2024-11-05 + + + https://kaanaksit.github.io/odak/beginning/ + 2024-11-05 + + + https://kaanaksit.github.io/odak/cgh/ + 2024-11-05 + + + https://kaanaksit.github.io/odak/contributing/ + 2024-11-05 + + + https://kaanaksit.github.io/odak/installation/ + 2024-11-05 + + + https://kaanaksit.github.io/odak/lensless/ + 2024-11-05 + + + https://kaanaksit.github.io/odak/machine_learning/ + 2024-11-05 + + + https://kaanaksit.github.io/odak/perception/ + 2024-11-05 + + + https://kaanaksit.github.io/odak/raytracing/ + 2024-11-05 + + + https://kaanaksit.github.io/odak/toolkit/ + 2024-11-05 + + + https://kaanaksit.github.io/odak/course/ + 2024-11-05 + + + https://kaanaksit.github.io/odak/course/computational_displays/ + 2024-11-05 + + + https://kaanaksit.github.io/odak/course/computational_imaging/ + 2024-11-05 + + + https://kaanaksit.github.io/odak/course/computational_light/ + 2024-11-05 + + + https://kaanaksit.github.io/odak/course/computer_generated_holography/ + 2024-11-05 + + + https://kaanaksit.github.io/odak/course/fundamentals/ + 2024-11-05 + + + https://kaanaksit.github.io/odak/course/geometric_optics/ + 2024-11-05 + + + https://kaanaksit.github.io/odak/course/photonic_computers/ + 2024-11-05 + + + https://kaanaksit.github.io/odak/course/visual_perception/ + 2024-11-05 + + + https://kaanaksit.github.io/odak/notes/holographic_light_transport/ + 2024-11-05 + + + https://kaanaksit.github.io/odak/notes/optimizing_holograms_using_odak/ + 2024-11-05 + + + https://kaanaksit.github.io/odak/notes/using_metameric_loss/ + 2024-11-05 + + + https://kaanaksit.github.io/odak/odak/fit/ + 2024-11-05 + + + https://kaanaksit.github.io/odak/odak/learn_lensless/ + 2024-11-05 + + + https://kaanaksit.github.io/odak/odak/learn_models/ + 2024-11-05 + + + https://kaanaksit.github.io/odak/odak/learn_perception/ + 2024-11-05 + + + https://kaanaksit.github.io/odak/odak/learn_raytracing/ + 2024-11-05 + + + https://kaanaksit.github.io/odak/odak/learn_tools/ + 2024-11-05 + + + https://kaanaksit.github.io/odak/odak/learn_wave/ + 2024-11-05 + + + https://kaanaksit.github.io/odak/odak/raytracing/ + 2024-11-05 + + + https://kaanaksit.github.io/odak/odak/tools/ + 2024-11-05 + + + https://kaanaksit.github.io/odak/odak/wave/ + 2024-11-05 + + \ No newline at end of file diff --git a/sitemap.xml.gz b/sitemap.xml.gz new file mode 100644 index 00000000..96fce852 Binary files /dev/null and b/sitemap.xml.gz differ diff --git a/stylesheets/extra.css b/stylesheets/extra.css new file mode 100644 index 00000000..8b0ed6ea --- /dev/null +++ b/stylesheets/extra.css @@ -0,0 +1,16 @@ +[data-md-color-scheme="odak"] { + --md-primary-fg-color: #00577D; + --md-primary-fg-color--light: #00577D; + --md-primary-fg-color--dark: #00577D; +} + +[data-md-color-scheme="slate"] { + --md-hue: 210; + --md-primary-fg-color: #00577D; +} + +.md-grid { + max-width: 1440px; + + +} diff --git a/toolkit/index.html b/toolkit/index.html new file mode 100644 index 00000000..672b20b6 --- /dev/null +++ b/toolkit/index.html @@ -0,0 +1,1768 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Introduction - Odak + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

General toolkit.

Odak provides a set of functions that can be used for general purpose work, such as saving an image file or loading a three-dimensional point cloud of an object. These functions are helpful for general use and provide consistency across routine loading and saving operations. When working with odak, we strongly suggest sticking to the general toolkit to provide a coherent solution to your task.
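As an aside, a hypothetical sketch of such routine work is shown below; the names save_image and load_image and their keyword arguments are assumptions drawn from the "Working with images" engineering note rather than verified signatures.

import numpy as np
import odak

# Hypothetical sketch; function names and keywords are assumptions, not verified API.
image = np.random.randint(0, 255, size = (256, 256, 3))         # A random 8-bit RGB image.
odak.tools.save_image('test_image.png', image, cmin = 0, cmax = 255)
loaded_image = odak.tools.load_image('test_image.png')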


Engineering notes

Note                        Description
Working with images         This engineering note will give you an idea about how to read and write images using odak.
Working with dictionaries   This engineering note will give you an idea about how to read and write dictionaries using odak (see the sketch after this table).
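For the dictionaries note, a similarly hypothetical sketch is given below; save_dictionary, load_dictionary, the filename, and the settings keys are assumed names for illustration, not confirmed API.

import odak

# Hypothetical sketch; save_dictionary/load_dictionary and the keys below are assumptions.
settings = {
            'wavelength' : 0.000515,
            'resolution' : [1080, 1920]
           }
odak.tools.save_dictionary(settings, 'settings.txt')
loaded_settings = odak.tools.load_dictionary('settings.txt')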
\ No newline at end of file
Computational Displays

The Computational Displays topic deals with inventing next-generation display technology for the future of human-computer interaction. Common examples of emerging Computational Displays are near-eye displays such as Virtual Reality headsets and Augmented Reality glasses. Today, we all use displays as a core component for visual tasks such as working, entertainment, education, and many more.
