Skip to content

Commit

Permalink
Merge pull request #213 from NathanielF/feature_instrumental_variables
Browse files Browse the repository at this point in the history
Add Bayesian instrumental variable estimation
  • Loading branch information
drbenvincent authored Aug 24, 2023
2 parents 31e0039 + cb68c78 commit 3070bc9
Show file tree
Hide file tree
Showing 14 changed files with 1,239 additions and 3 deletions.
65 changes: 65 additions & 0 deletions causalpy/data/AJR2001.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
longname,shortnam,logmort0,risk,loggdp,campaign,source0,slave,latitude,neoeuro,asia,africa,other,edes1975,campaignsj,campaignsj2,mortnaval1,logmortnaval1,mortnaval2,logmortnaval2,mortjam,logmortjam,logmortcap250,logmortjam250,wandcafrica,malfal94,wacacontested,mortnaval2250,logmortnaval2250,mortnaval1250,logmortnaval1250
Angola,AGO,5.6347895,5.3600001,7.77,1,0,0,0.1367,0,0,1,0,0.0,1,1,280.0,5.6347895,280.0,5.6347895,280.0,5.6347895,5.521461,5.521461,1,0.94999999,1,250.0,5.521461,250.0,5.521461
Argentina,ARG,4.232656,6.3899999,9.1300001,1,0,0,0.37779999,0,0,0,0,90.0,1,1,15.07,2.7127061,30.5,3.4177268,56.5,4.0342407,4.232656,4.0342407,0,0.0,0,30.5,3.4177268,15.07,2.7127061
Australia,AUS,2.1459312,9.3199997,9.8999996,0,0,0,0.30000001,1,0,0,1,99.0,0,1,8.5500002,2.1459312,8.5500002,2.1459312,8.5500002,2.1459312,2.1459312,2.1459312,0,0.0,0,8.5500002,2.1459312,8.5500002,2.1459312
Burkina Faso,BFA,5.6347895,4.4499998,6.8499999,1,0,0,0.1444,0,0,1,0,0.0,1,1,280.0,5.6347895,280.0,5.6347895,280.0,5.6347895,5.521461,5.521461,1,0.94999999,1,250.0,5.521461,250.0,5.521461
Bangladesh,BGD,4.2684379,5.1399999,6.8800001,1,1,0,0.2667,0,1,0,0,0.0,1,1,71.410004,4.2684379,71.410004,4.2684379,71.410004,4.2684379,4.2684379,4.2684379,0,0.12008,0,71.410004,4.2684379,71.410004,4.2684379
Bahamas,BHS,4.4426513,7.5,9.29,0,0,0,0.2683,0,0,0,0,10.0,0,0,85.0,4.4426513,85.0,4.4426513,85.0,4.4426513,4.4426513,4.4426513,0,,0,85.0,4.4426513,85.0,4.4426513
Bolivia,BOL,4.2626801,5.6399999,7.9299998,1,0,0,0.18889999,0,0,0,0,30.000002,1,1,,,93.25,4.535284,56.5,4.0342407,4.2626801,4.0342407,0,0.00165,0,93.25,4.535284,,
Brazil,BRA,4.2626801,7.9099998,8.7299995,1,0,0,0.1111,0,0,0,0,55.0,1,1,15.07,2.7127061,30.5,3.4177268,56.5,4.0342407,4.2626801,4.0342407,0,0.035999998,0,30.5,3.4177268,15.07,2.7127061
Canada,CAN,2.7788193,9.7299995,9.9899998,0,1,0,0.66670001,1,0,0,0,98.0,0,0,16.1,2.7788193,16.1,2.7788193,16.1,2.7788193,2.7788193,2.7788193,0,0.0,0,16.1,2.7788193,16.1,2.7788193
Chile,CHL,4.232656,7.8200002,9.3400002,1,0,0,0.33329999,0,0,0,0,50.0,1,1,15.07,2.7127061,30.5,3.4177268,56.5,4.0342407,4.232656,4.0342407,0,0.0,0,30.5,3.4177268,15.07,2.7127061
Cote d'Ivoire,CIV,6.5042882,7.0,7.4400001,1,0,0,0.0889,0,0,1,0,0.0,1,1,668.0,6.5042882,668.0,6.5042882,668.0,6.5042882,5.521461,5.521461,1,0.94999999,1,250.0,5.521461,250.0,5.521461
Cameroon,CMR,5.6347895,6.4499998,7.5,1,0,0,0.066699997,0,0,1,0,0.0,1,1,280.0,5.6347895,280.0,5.6347895,280.0,5.6347895,5.521461,5.521461,1,0.94999999,1,250.0,5.521461,250.0,5.521461
Congo,COG,5.480639,4.6799998,7.4200001,0,1,1,0.0111,0,0,1,0,0.0,0,0,240.0,5.480639,240.0,5.480639,240.0,5.480639,5.480639,5.480639,1,0.94999999,0,240.0,5.480639,240.0,5.480639
Colombia,COL,4.2626801,7.3200002,8.8100004,1,0,0,0.044399999,0,0,0,0,25.0,1,1,,,93.25,4.535284,56.5,4.0342407,4.2626801,4.0342407,0,0.14637001,0,93.25,4.535284,,
Costa Rica,CRI,4.3579903,7.0500002,8.79,1,0,0,0.1111,0,0,0,0,20.0,1,1,,,93.25,4.535284,62.200001,4.1303549,4.3579903,4.1303549,0,0.0,0,93.25,4.535284,,
Dominican Re,DOM,4.8675346,6.1799998,8.3599997,0,0,0,0.2111,0,0,0,0,25.0,0,0,130.0,4.8675346,130.0,4.8675346,130.0,4.8675346,4.8675346,4.8675346,0,0.0,0,130.0,4.8675346,130.0,4.8675346
Algeria,DZA,4.3592696,6.5,8.3900003,1,1,0,0.31110001,0,0,1,0,0.0,1,1,78.199997,4.3592696,78.199997,4.3592696,78.199997,4.3592696,4.3592696,4.3592696,0,0.0,0,78.199997,4.3592696,78.199997,4.3592696
Ecuador,ECU,4.2626801,6.5500002,8.4700003,1,0,0,0.0222,0,0,0,0,30.000002,1,1,,,93.25,4.535284,56.5,4.0342407,4.2626801,4.0342407,0,0.11894999,0,93.25,4.535284,,
Egypt,EGY,4.2165623,6.77,7.9499998,1,1,0,0.30000001,0,0,1,0,0.0,1,1,67.800003,4.2165623,67.800003,4.2165623,67.800003,4.2165623,4.2165623,4.2165623,0,0.0,0,67.800003,4.2165623,67.800003,4.2165623
Ethiopia,ETH,3.2580965,5.73,6.1100001,1,1,0,0.0889,0,0,1,0,0.0,1,1,26.0,3.2580965,26.0,3.2580965,26.0,3.2580965,3.2580965,3.2580965,1,0.551,0,26.0,3.2580965,26.0,3.2580965
Gabon,GAB,5.6347895,7.8200002,8.8999996,1,0,0,0.0111,0,0,1,0,0.0,1,1,280.0,5.6347895,280.0,5.6347895,280.0,5.6347895,5.521461,5.521461,1,0.94050002,1,250.0,5.521461,250.0,5.521461
Ghana,GHA,6.5042882,6.27,7.3699999,1,1,0,0.0889,0,0,1,0,0.0,1,1,668.0,6.5042882,668.0,6.5042882,668.0,6.5042882,5.521461,5.521461,1,0.94999999,0,250.0,5.521461,250.0,5.521461
Guinea,GIN,6.1800165,6.5500002,7.4899998,1,0,0,0.1222,0,0,1,0,0.0,1,1,483.0,6.1800165,483.0,6.1800165,483.0,6.1800165,5.521461,5.521461,1,0.94999999,1,250.0,5.521461,250.0,5.521461
Gambia,GMB,7.2930179,8.2700005,7.27,1,1,0,0.1476,0,0,1,0,0.0,1,1,1470.0,7.2930179,1470.0,7.2930179,1470.0,7.2930179,5.521461,5.521461,1,0.94999999,0,250.0,5.521461,250.0,5.521461
Guatemala,GTM,4.2626801,5.1399999,8.29,1,0,0,0.17,0,0,0,0,20.0,1,1,,,93.25,4.535284,56.5,4.0342407,4.2626801,4.0342407,0,0.0036000002,0,93.25,4.535284,,
Guyana,GUY,3.4713452,5.8899999,7.9000001,0,0,0,0.055599999,0,0,0,0,2.0,0,0,32.18,3.4713452,32.18,3.4713452,32.18,3.4713452,3.4713452,3.4713452,0,0.49503002,0,32.18,3.4713452,32.18,3.4713452
Hong Kong,HKG,2.7013612,8.1400003,10.05,0,0,0,0.24609999,0,1,0,0,0.0,1,1,14.9,2.7013612,14.9,2.7013612,14.9,2.7013612,2.7013612,2.7013612,0,0.0,0,14.9,2.7013612,14.9,2.7013612
Honduras,HND,4.3579903,5.3200002,7.6900001,1,0,0,0.16670001,0,0,0,0,20.0,1,1,,,93.25,4.535284,62.200001,4.1303549,4.3579903,4.1303549,0,0.012,0,93.25,4.535284,,
Haiti,HTI,4.8675346,3.73,7.1500001,0,0,0,0.2111,0,0,0,0,0.0,0,0,130.0,4.8675346,130.0,4.8675346,130.0,4.8675346,4.8675346,4.8675346,0,1.0,0,130.0,4.8675346,130.0,4.8675346
Indonesia,IDN,5.1357985,7.5900002,7.3299999,1,1,0,0.055599999,0,1,0,0,0.0,1,1,170.0,5.1357985,170.0,5.1357985,170.0,5.1357985,5.1357985,5.1357985,0,0.17873,0,170.0,5.1357985,170.0,5.1357985
India,IND,3.8842406,8.2700005,7.3299999,0,1,0,0.22220001,0,1,0,0,0.0,0,0,48.630001,3.8842406,48.630001,3.8842406,48.630001,3.8842406,3.8842406,3.8842406,0,0.23596001,0,48.630001,3.8842406,48.630001,3.8842406
Jamaica,JAM,4.8675346,7.0900002,8.1899996,0,1,0,0.2017,0,0,0,0,10.0,0,1,130.0,4.8675346,130.0,4.8675346,130.0,4.8675346,4.8675346,4.8675346,0,0.0,0,130.0,4.8675346,130.0,4.8675346
Kenya,KEN,4.9767337,6.0500002,7.0599999,0,1,1,0.0111,0,0,1,0,0.0,0,0,145.0,4.9767337,145.0,4.9767337,145.0,4.9767337,4.9767337,4.9767337,1,0.79799998,0,145.0,4.9767337,145.0,4.9767337
Sri Lanka,LKA,4.2456341,6.0500002,7.73,0,1,0,0.077799998,0,1,0,0,0.0,0,1,69.800003,4.2456341,69.800003,4.2456341,69.800003,4.2456341,4.2456341,4.2456341,0,0.138,0,69.800003,4.2456341,69.800003,4.2456341
Morocco,MAR,4.3592696,7.0900002,8.04,1,0,0,0.3556,0,0,1,0,1.0,1,1,78.199997,4.3592696,78.199997,4.3592696,78.199997,4.3592696,4.3592696,4.3592696,0,0.0,0,78.199997,4.3592696,78.199997,4.3592696
Madagascar,MDG,6.2842088,4.4499998,6.8400002,1,1,0,0.22220001,0,0,1,0,0.0,1,1,536.03998,6.2842088,536.03998,6.2842088,536.03998,6.2842088,5.521461,5.521461,1,0.94999999,0,250.0,5.521461,250.0,5.521461
Mexico,MEX,4.2626801,7.5,8.9399996,1,1,0,0.25560001,0,0,0,0,15.000001,1,1,71.0,4.2626801,71.0,4.2626801,71.0,4.2626801,4.2626801,4.2626801,0,0.00042,0,71.0,4.2626801,71.0,4.2626801
Mali,MLI,7.986165,4.0,6.5700002,1,1,0,0.18889999,0,0,1,0,0.0,1,1,2940.0,7.986165,2940.0,7.986165,2940.0,7.986165,5.521461,5.521461,1,0.94050002,0,250.0,5.521461,250.0,5.521461
Malta,MLT,2.7911651,7.23,9.4300003,0,1,0,0.3944,0,0,0,1,100.0,0,0,16.299999,2.7911651,16.299999,2.7911651,16.299999,2.7911651,2.7911651,2.7911651,0,,0,16.299999,2.7911651,16.299999,2.7911651
Malaysia,MYS,2.8735647,7.9499998,8.8900003,0,1,0,0.025599999,0,1,0,0,0.0,0,1,17.700001,2.8735647,17.700001,2.8735647,17.700001,2.8735647,2.8735647,2.8735647,0,0.23331,0,17.700001,2.8735647,17.700001,2.8735647
Niger,NER,5.9914646,5.0,6.73,1,0,0,0.1778,0,0,1,0,0.0,1,1,400.0,5.9914646,400.0,5.9914646,400.0,5.9914646,5.521461,5.521461,1,0.94050002,1,250.0,5.521461,250.0,5.521461
Nigeria,NGA,7.6029005,5.5500002,6.8099999,1,1,0,0.1111,0,0,1,0,0.0,1,1,2004.0,7.6029005,2004.0,7.6029005,2004.0,7.6029005,5.521461,5.521461,1,0.94999999,0,250.0,5.521461,250.0,5.521461
Nicaragua,NIC,5.0955892,5.23,7.54,1,0,0,0.1444,0,0,0,0,20.0,1,1,,,93.25,4.535284,130.0,4.8675346,5.0955892,4.8675346,0,0.044,0,93.25,4.535284,,
New Zealand,NZL,2.1459312,9.7299995,9.7600002,0,1,0,0.45559999,1,0,0,1,91.699997,1,1,8.5500002,2.1459312,8.5500002,2.1459312,8.5500002,2.1459312,2.1459312,2.1459312,0,0.0,0,8.5500002,2.1459312,8.5500002,2.1459312
Pakistan,PAK,3.6106477,6.0500002,7.3499999,1,0,0,0.33329999,0,1,0,0,0.0,1,1,36.990002,3.6106477,36.990002,3.6106477,36.990002,3.6106477,3.6106477,3.6106477,0,0.53757,0,36.990002,3.6106477,36.990002,3.6106477
Panama,PAN,5.0955892,5.9099998,8.8400002,1,0,0,0.1,0,0,0,0,20.0,1,1,15.07,2.7127061,30.5,3.4177268,130.0,4.8675346,5.0955892,4.8675346,0,0.08004,0,30.5,3.4177268,15.07,2.7127061
Peru,PER,4.2626801,5.77,8.3999996,1,0,0,0.1111,0,0,0,0,30.000002,1,1,15.07,2.7127061,30.5,3.4177268,56.5,4.0342407,4.2626801,4.0342407,0,0.00050000002,0,30.5,3.4177268,15.07,2.7127061
Paraguay,PRY,4.3579903,6.9499998,8.21,1,0,0,0.25560001,0,0,0,0,25.0,1,1,,,93.25,4.535284,62.200001,4.1303549,4.3579903,4.1303549,0,0.0,0,93.25,4.535284,,
Sudan,SDN,4.4796071,4.0,7.3099999,1,1,0,0.16670001,0,0,1,0,0.0,1,1,88.199997,4.4796071,88.199997,4.4796071,88.199997,4.4796071,4.4796071,4.4796071,1,0.93099999,0,88.199997,4.4796071,88.199997,4.4796071
Senegal,SEN,5.1038828,6.0,7.4000001,0,1,0,0.1556,0,0,1,0,0.0,0,1,164.66,5.1038828,164.66,5.1038828,164.66,5.1038828,5.1038828,5.1038828,1,0.94999999,0,164.66,5.1038828,164.66,5.1038828
Singapore,SGP,2.8735647,9.3199997,10.15,0,0,0,0.0136,0,1,0,0,0.0,0,1,17.700001,2.8735647,17.700001,2.8735647,17.700001,2.8735647,2.8735647,2.8735647,0,0.0,0,17.700001,2.8735647,17.700001,2.8735647
Sierra Leone,SLE,6.1800165,5.8200002,6.25,1,1,0,0.092200004,0,0,1,0,0.0,1,1,483.0,6.1800165,483.0,6.1800165,483.0,6.1800165,5.521461,5.521461,1,0.94999999,0,250.0,5.521461,250.0,5.521461
El Salvador,SLV,4.3579903,5.0,7.9499998,1,0,0,0.15000001,0,0,0,0,20.0,1,1,,,93.25,4.535284,62.200001,4.1303549,4.3579903,4.1303549,0,0.0,0,93.25,4.535284,,
Togo,TGO,6.5042882,6.9099998,7.2199998,1,0,0,0.0889,0,0,1,0,0.0,1,1,668.0,6.5042882,668.0,6.5042882,668.0,6.5042882,5.521461,5.521461,1,0.94999999,1,250.0,5.521461,250.0,5.521461
Trinidad and Tobago,TTO,4.4426513,7.4499998,8.7700005,0,1,0,0.1222,0,0,0,0,40.0,0,1,85.0,4.4426513,85.0,4.4426513,85.0,4.4426513,4.4426513,4.4426513,0,0.0,0,85.0,4.4426513,85.0,4.4426513
Tunisia,TUN,4.1431346,6.4499998,8.4799995,1,1,0,0.37779999,0,0,1,0,0.0,1,1,63.0,4.1431346,63.0,4.1431346,63.0,4.1431346,4.1431346,4.1431346,0,0.0,0,63.0,4.1431346,63.0,4.1431346
Tanzania,TZA,4.9767337,6.6399999,6.25,0,0,1,0.066699997,0,0,1,0,0.0,0,0,145.0,4.9767337,145.0,4.9767337,145.0,4.9767337,4.9767337,4.9767337,1,0.92150003,1,145.0,4.9767337,145.0,4.9767337
Uganda,UGA,5.6347895,4.4499998,6.9699998,1,0,0,0.0111,0,0,1,0,0.0,1,1,280.0,5.6347895,280.0,5.6347895,280.0,5.6347895,5.521461,5.521461,1,0.94999999,1,250.0,5.521461,250.0,5.521461
Uruguary,URY,4.2626801,7.0,9.0299997,1,0,0,0.36669999,0,0,0,0,90.0,1,1,,,93.25,4.535284,56.5,4.0342407,4.2626801,4.0342407,0,0.0,0,93.25,4.535284,,
USA,USA,2.7080503,10.0,10.22,0,1,0,0.42219999,1,0,0,0,83.600006,0,1,15.0,2.7080503,15.0,2.7080503,15.0,2.7080503,2.7080503,2.7080503,0,0.0,0,15.0,2.7080503,15.0,2.7080503
Venezuela,VEN,4.3579903,7.1399999,9.0699997,1,0,0,0.0889,0,0,0,0,20.0,1,1,,,93.25,4.535284,62.200001,4.1303549,4.3579903,4.1303549,0,0.0070400001,0,93.25,4.535284,,
Vietnam,VNM,4.9416423,6.4099998,7.2800002,1,1,0,0.1778,0,1,0,0,0.0,1,1,140.0,4.9416423,140.0,4.9416423,140.0,4.9416423,4.9416423,4.9416423,0,0.70109999,0,140.0,4.9416423,140.0,4.9416423
South Africa,ZAF,2.74084,6.8600001,8.8900003,0,1,0,0.3222,0,0,1,0,16.0,0,1,15.5,2.74084,15.5,2.74084,15.5,2.74084,2.74084,2.74084,0,0.1045,0,15.5,2.74084,15.5,2.74084
Zaire,ZAR,5.480639,3.5,6.8699999,0,0,1,0.0,0,0,1,0,0.0,0,0,240.0,5.480639,240.0,5.480639,240.0,5.480639,5.480639,5.480639,1,0.94999999,1,240.0,5.480639,240.0,5.480639
1 change: 1 addition & 0 deletions causalpy/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"sc": {"filename": "synthetic_control.csv"},
"anova1": {"filename": "ancova_generated.csv"},
"geolift1": {"filename": "geolift1.csv"},
"risk": {"filename": "AJR2001.csv"},
}


Expand Down
128 changes: 128 additions & 0 deletions causalpy/pymc_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import seaborn as sns
import xarray as xr
from patsy import build_design_matrices, dmatrices
from sklearn.linear_model import LinearRegression as sk_lin_reg

from causalpy.custom_exceptions import BadIndexException # NOQA
from causalpy.custom_exceptions import DataException, FormulaException
Expand Down Expand Up @@ -883,3 +884,130 @@ def _get_treatment_effect_coeff(self) -> str:
return label

raise NameError("Unable to find coefficient name for the treatment effect")


class InstrumentalVariable(ExperimentalDesign):
"""
A class to analyse instrumental variable style experiments.
:param instruments_data: A pandas dataframe of instruments
for our treatment variable. Should contain
instruments Z, and treatment t
:param data: A pandas dataframe of covariates for fitting
the focal regression of interest. Should contain covariates X
including treatment t and outcome y
:param instruments_formula: A statistical model formula for
the instrumental stage regression
e.g. t ~ 1 + z1 + z2 + z3
:param formula: A statistical model formula for the \n
focal regression e.g. y ~ 1 + t + x1 + x2 + x3
:param model: A PyMC model
:param priors: An optional dictionary of priors for the
mus and sigmas of both regressions. If priors are not
specified we will substitue MLE estimates for the beta
coefficients. Greater control can be achieved
by specifying the priors directly e.g. priors = {
"mus": [0, 0],
"sigmas": [1, 1],
"eta": 2,
"lkj_sd": 2,
}
"""

def __init__(
self,
instruments_data: pd.DataFrame,
data: pd.DataFrame,
instruments_formula: str,
formula: str,
model=None,
priors=None,
**kwargs,
):
super().__init__(model=model, **kwargs)
self.expt_type = "Instrumental Variable Regression"
self.data = data
self.instruments_data = instruments_data
self.formula = formula
self.instruments_formula = instruments_formula
self.model = model
self._input_validation()

y, X = dmatrices(formula, self.data)
self._y_design_info = y.design_info
self._x_design_info = X.design_info
self.labels = X.design_info.column_names
self.y, self.X = np.asarray(y), np.asarray(X)
self.outcome_variable_name = y.design_info.column_names[0]

t, Z = dmatrices(instruments_formula, self.instruments_data)
self._t_design_info = t.design_info
self._z_design_info = Z.design_info
self.labels_instruments = Z.design_info.column_names
self.t, self.Z = np.asarray(t), np.asarray(Z)
self.instrument_variable_name = t.design_info.column_names[0]

self.get_naive_OLS_fit()
self.get_2SLS_fit()

# fit the model to the data
COORDS = {"instruments": self.labels_instruments, "covariates": self.labels}
self.coords = COORDS
if priors is None:
priors = {
"mus": [self.ols_beta_first_params, self.ols_beta_second_params],
"sigmas": [1, 1],
"eta": 2,
"lkj_sd": 2,
}
self.priors = priors
self.model.fit(
X=self.X, Z=self.Z, y=self.y, t=self.t, coords=COORDS, priors=self.priors
)

def get_2SLS_fit(self):
first_stage_reg = sk_lin_reg().fit(self.Z, self.t)
fitted_Z_values = first_stage_reg.predict(self.Z)
X2 = self.data.copy(deep=True)
X2[self.instrument_variable_name] = fitted_Z_values
_, X2 = dmatrices(self.formula, X2)
second_stage_reg = sk_lin_reg().fit(X=X2, y=self.y)
betas_first = list(first_stage_reg.coef_[0][1:])
betas_first.insert(0, first_stage_reg.intercept_[0])
betas_second = list(second_stage_reg.coef_[0][1:])
betas_second.insert(0, second_stage_reg.intercept_[0])
self.ols_beta_first_params = betas_first
self.ols_beta_second_params = betas_second
self.first_stage_reg = first_stage_reg
self.second_stage_reg = second_stage_reg

def get_naive_OLS_fit(self):
ols_reg = sk_lin_reg().fit(self.X, self.y)
beta_params = list(ols_reg.coef_[0][1:])
beta_params.insert(0, ols_reg.intercept_[0])
self.ols_beta_params = dict(zip(self._x_design_info.column_names, beta_params))
self.ols_reg = ols_reg

def _input_validation(self):
"""Validate the input data and model formula for correctness"""
treatment = self.instruments_formula.split("~")[0]
test = treatment.strip() in self.instruments_data.columns
test = test & (treatment.strip() in self.data.columns)
if not test:
raise DataException(
f"""
The treatment variable:
{treatment} must appear in the instrument_data to be used
as an outcome variable and in the data object to be used as a covariate.
"""
)
Z = self.data[treatment.strip()]
check_binary = len(np.unique(Z)) > 2
if check_binary:
warnings.warn(
"""Warning. The treatment variable is not Binary.
This is not necessarily a problem but it violates
the assumption of a simple IV experiment.
The coefficients should be interpreted appropriately."""
)
Loading

0 comments on commit 3070bc9

Please sign in to comment.